diff options
| author | Stefan Boberg <[email protected]> | 2023-05-02 10:01:47 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-05-02 10:01:47 +0200 |
| commit | 075d17f8ada47e990fe94606c3d21df409223465 (patch) | |
| tree | e50549b766a2f3c354798a54ff73404217b4c9af /src | |
| parent | fix: bundle shouldn't append content zip to zen (diff) | |
| download | zen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz zen-075d17f8ada47e990fe94606c3d21df409223465.zip | |
moved source directories into `/src` (#264)
* moved source directories into `/src`
* updated bundle.lua for new `src` path
* moved some docs, icon
* removed old test trees
Diffstat (limited to 'src')
249 files changed, 81583 insertions, 0 deletions
diff --git a/src/UnrealEngine.ico b/src/UnrealEngine.ico Binary files differnew file mode 100644 index 000000000..1cfa301a2 --- /dev/null +++ b/src/UnrealEngine.ico diff --git a/src/zen/chunk/chunk.cpp b/src/zen/chunk/chunk.cpp new file mode 100644 index 000000000..d3591f8ca --- /dev/null +++ b/src/zen/chunk/chunk.cpp @@ -0,0 +1,1216 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "chunk.h" + +#if 0 +# include <gsl/gsl-lite.hpp> + +# include <zencore/filesystem.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zencore/refcount.h> +# include <zencore/scopeguard.h> +# include <zencore/sha1.h> +# include <zencore/string.h> +# include <zencore/testing.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zenstore/gc.h> + +# include "../internalfile.h" + +# include <lz4.h> +# include <zstd.h> + +# if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +# include <ppltasks.h> +# endif // ZEN_PLATFORM_WINDOWS + +# include <cmath> +# include <filesystem> +# include <random> +# include <vector> + +////////////////////////////////////////////////////////////////////////// + +# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +namespace Concurrency { + +template<typename IterType, typename LambdaType> +void +parallel_for_each(IterType Cursor, IterType End, const LambdaType& Lambda) +{ + for (; Cursor < End; ++Cursor) + { + Lambda(*Cursor); + } +} + +template<typename T> +struct combinable +{ + T& local() { return Value; } + + template<typename LambdaType> + void combine_each(const LambdaType& Lambda) + { + Lambda(Value); + } + + T Value = {}; +}; + +struct task_group +{ + template<class Function> + void run(const Function& Func) + { + Func(); + } + + void wait() {} +}; + +} // namespace Concurrency + +# endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +////////////////////////////////////////////////////////////////////////// + +namespace detail { +static const uint32_t buzhashTable[] = { + 0x458be752, 0xc10748cc, 0xfbbcdbb8, 0x6ded5b68, 0xb10a82b5, 0x20d75648, 0xdfc5665f, 0xa8428801, 0x7ebf5191, 0x841135c7, 0x65cc53b3, + 0x280a597c, 0x16f60255, 0xc78cbc3e, 0x294415f5, 0xb938d494, 0xec85c4e6, 0xb7d33edc, 0xe549b544, 0xfdeda5aa, 0x882bf287, 0x3116737c, + 0x05569956, 0xe8cc1f68, 0x0806ac5e, 0x22a14443, 0x15297e10, 0x50d090e7, 0x4ba60f6f, 0xefd9f1a7, 0x5c5c885c, 0x82482f93, 0x9bfd7c64, + 0x0b3e7276, 0xf2688e77, 0x8fad8abc, 0xb0509568, 0xf1ada29f, 0xa53efdfe, 0xcb2b1d00, 0xf2a9e986, 0x6463432b, 0x95094051, 0x5a223ad2, + 0x9be8401b, 0x61e579cb, 0x1a556a14, 0x5840fdc2, 0x9261ddf6, 0xcde002bb, 0x52432bb0, 0xbf17373e, 0x7b7c222f, 0x2955ed16, 0x9f10ca59, + 0xe840c4c9, 0xccabd806, 0x14543f34, 0x1462417a, 0x0d4a1f9c, 0x087ed925, 0xd7f8f24c, 0x7338c425, 0xcf86c8f5, 0xb19165cd, 0x9891c393, + 0x325384ac, 0x0308459d, 0x86141d7e, 0xc922116a, 0xe2ffa6b6, 0x53f52aed, 0x2cd86197, 0xf5b9f498, 0xbf319c8f, 0xe0411fae, 0x977eb18c, + 0xd8770976, 0x9833466a, 0xc674df7f, 0x8c297d45, 0x8ca48d26, 0xc49ed8e2, 0x7344f874, 0x556f79c7, 0x6b25eaed, 0xa03e2b42, 0xf68f66a4, + 0x8e8b09a2, 0xf2e0e62a, 0x0d3a9806, 0x9729e493, 0x8c72b0fc, 0x160b94f6, 0x450e4d3d, 0x7a320e85, 0xbef8f0e1, 0x21d73653, 0x4e3d977a, + 0x1e7b3929, 0x1cc6c719, 0xbe478d53, 0x8d752809, 0xe6d8c2c6, 0x275f0892, 0xc8acc273, 0x4cc21580, 0xecc4a617, 0xf5f7be70, 0xe795248a, + 0x375a2fe9, 0x425570b6, 0x8898dcf8, 0xdc2d97c4, 0x0106114b, 0x364dc22f, 0x1e0cad1f, 0xbe63803c, 0x5f69fac2, 0x4d5afa6f, 0x1bc0dfb5, + 0xfb273589, 0x0ea47f7b, 0x3c1c2b50, 0x21b2a932, 0x6b1223fd, 0x2fe706a8, 0xf9bd6ce2, 0xa268e64e, 0xe987f486, 0x3eacf563, 0x1ca2018c, + 0x65e18228, 0x2207360a, 0x57cf1715, 0x34c37d2b, 0x1f8f3cde, 0x93b657cf, 0x31a019fd, 0xe69eb729, 0x8bca7b9b, 0x4c9d5bed, 0x277ebeaf, + 0xe0d8f8ae, 0xd150821c, 0x31381871, 0xafc3f1b0, 0x927db328, 0xe95effac, 0x305a47bd, 0x426ba35b, 0x1233af3f, 0x686a5b83, 0x50e072e5, + 0xd9d3bb2a, 0x8befc475, 0x487f0de6, 0xc88dff89, 0xbd664d5e, 0x971b5d18, 0x63b14847, 0xd7d3c1ce, 0x7f583cf3, 0x72cbcb09, 0xc0d0a81c, + 0x7fa3429b, 0xe9158a1b, 0x225ea19a, 0xd8ca9ea3, 0xc763b282, 0xbb0c6341, 0x020b8293, 0xd4cd299d, 0x58cfa7f8, 0x91b4ee53, 0x37e4d140, + 0x95ec764c, 0x30f76b06, 0x5ee68d24, 0x679c8661, 0xa41979c2, 0xf2b61284, 0x4fac1475, 0x0adb49f9, 0x19727a23, 0x15a7e374, 0xc43a18d5, + 0x3fb1aa73, 0x342fc615, 0x924c0793, 0xbee2d7f0, 0x8a279de9, 0x4aa2d70c, 0xe24dd37f, 0xbe862c0b, 0x177c22c2, 0x5388e5ee, 0xcd8a7510, + 0xf901b4fd, 0xdbc13dbc, 0x6c0bae5b, 0x64efe8c7, 0x48b02079, 0x80331a49, 0xca3d8ae6, 0xf3546190, 0xfed7108b, 0xc49b941b, 0x32baf4a9, + 0xeb833a4a, 0x88a3f1a5, 0x3a91ce0a, 0x3cc27da1, 0x7112e684, 0x4a3096b1, 0x3794574c, 0xa3c8b6f3, 0x1d213941, 0x6e0a2e00, 0x233479f1, + 0x0f4cd82f, 0x6093edd2, 0x5d7d209e, 0x464fe319, 0xd4dcac9e, 0x0db845cb, 0xfb5e4bc3, 0xe0256ce1, 0x09fb4ed1, 0x0914be1e, 0xa5bdb2c3, + 0xc6eb57bb, 0x30320350, 0x3f397e91, 0xa67791bc, 0x86bc0e2c, 0xefa0a7e2, 0xe9ff7543, 0xe733612c, 0xd185897b, 0x329e5388, 0x91dd236b, + 0x2ecb0d93, 0xf4d82a3d, 0x35b5c03f, 0xe4e606f0, 0x05b21843, 0x37b45964, 0x5eff22f4, 0x6027f4cc, 0x77178b3c, 0xae507131, 0x7bf7cabc, + 0xf9c18d66, 0x593ade65, 0xd95ddf11, +}; + +// ROL operation (compiler turns this into a ROL when optimizing) +static inline uint32_t +Rotate32(uint32_t Value, size_t RotateCount) +{ + RotateCount &= 31; + + return ((Value) << (RotateCount)) | ((Value) >> (32 - RotateCount)); +} +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +class ZenChunker +{ +public: + void SetChunkSize(size_t MinSize, size_t MaxSize, size_t AvgSize); + size_t ScanChunk(const void* DataBytes, size_t ByteCount); + void Reset(); + + // This controls which chunking approach is used - threshold or + // modulo based. Threshold is faster and generates similarly sized + // chunks + void SetUseThreshold(bool NewState) { m_useThreshold = NewState; } + + inline size_t ChunkSizeMin() const { return m_chunkSizeMin; } + inline size_t ChunkSizeMax() const { return m_chunkSizeMax; } + inline size_t ChunkSizeAvg() const { return m_chunkSizeAvg; } + inline uint64_t BytesScanned() const { return m_bytesScanned; } + + static constexpr size_t NoBoundaryFound = size_t(~0ull); + +private: + size_t m_chunkSizeMin = 0; + size_t m_chunkSizeMax = 0; + size_t m_chunkSizeAvg = 0; + + uint32_t m_discriminator = 0; // Computed in SetChunkSize() + uint32_t m_threshold = 0; // Computed in SetChunkSize() + + bool m_useThreshold = true; + + static constexpr size_t kChunkSizeLimitMax = 64 * 1024 * 1024; + static constexpr size_t kChunkSizeLimitMin = 1024; + + static constexpr size_t kDefaultAverageChunkSize = 64 * 1024; + + static constexpr int kWindowSize = 48; + uint8_t m_window[kWindowSize]; + uint32_t m_windowSize = 0; + + uint32_t m_currentHash = 0; + uint32_t m_currentChunkSize = 0; + + uint64_t m_bytesScanned = 0; + + size_t InternalScanChunk(const void* DataBytes, size_t ByteCount); + void InternalReset(); +}; + +void +ZenChunker::Reset() +{ + InternalReset(); + + m_bytesScanned = 0; +} + +void +ZenChunker::InternalReset() +{ + m_currentHash = 0; + m_currentChunkSize = 0; + m_windowSize = 0; +} + +void +ZenChunker::SetChunkSize(size_t MinSize, size_t MaxSize, size_t AvgSize) +{ + if (m_windowSize) + return; // Already started + + static_assert(kChunkSizeLimitMin > kWindowSize); + + if (AvgSize) + { + // TODO: Validate AvgSize range + } + else + { + if (MinSize && MaxSize) + { + AvgSize = lrint(pow(2, (log2(MinSize) + log2(MaxSize)) / 2)); + } + else if (MinSize) + { + AvgSize = MinSize * 4; + } + else if (MaxSize) + { + AvgSize = MaxSize / 4; + } + else + { + AvgSize = kDefaultAverageChunkSize; + } + } + + if (MinSize) + { + // TODO: Validate MinSize range + } + else + { + MinSize = std::max(AvgSize / 4, kChunkSizeLimitMin); + } + + if (MaxSize) + { + // TODO: Validate MaxSize range + } + else + { + MaxSize = std::min(AvgSize * 4, kChunkSizeLimitMax); + } + + m_discriminator = gsl::narrow<uint32_t>(AvgSize - MinSize); + + if (m_discriminator < MinSize) + { + m_discriminator = gsl::narrow<uint32_t>(MinSize); + } + + if (m_discriminator > MaxSize) + { + m_discriminator = gsl::narrow<uint32_t>(MaxSize); + } + + m_threshold = gsl::narrow<uint32_t>((uint64_t(std::numeric_limits<uint32_t>::max()) + 1) / m_discriminator); + + m_chunkSizeMin = MinSize; + m_chunkSizeMax = MaxSize; + m_chunkSizeAvg = AvgSize; +} + +size_t +ZenChunker::ScanChunk(const void* DataBytesIn, size_t ByteCount) +{ + size_t Result = InternalScanChunk(DataBytesIn, ByteCount); + + if (Result == NoBoundaryFound) + { + m_bytesScanned += ByteCount; + } + else + { + m_bytesScanned += Result; + } + + return Result; +} + +size_t +ZenChunker::InternalScanChunk(const void* DataBytesIn, size_t ByteCount) +{ + size_t CurrentOffset = 0; + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(DataBytesIn); + + // There's no point in updating the hash if we know we're not + // going to have a cut point, so just skip the data. This logic currently + // provides roughly a 20% speedup on my machine + + const size_t NeedHashOffset = m_chunkSizeMin - kWindowSize; + + if (m_currentChunkSize < NeedHashOffset) + { + const uint32_t SkipBytes = gsl::narrow<uint32_t>(std::min<uint64_t>(ByteCount, NeedHashOffset - m_currentChunkSize)); + + ByteCount -= SkipBytes; + m_currentChunkSize += SkipBytes; + CurrentOffset += SkipBytes; + CursorPtr += SkipBytes; + + m_windowSize = 0; + + if (ByteCount == 0) + { + return NoBoundaryFound; + } + } + + // Fill window first + + if (m_windowSize < kWindowSize) + { + const uint32_t FillBytes = uint32_t(std::min<size_t>(ByteCount, kWindowSize - m_windowSize)); + + memcpy(&m_window[m_windowSize], CursorPtr, FillBytes); + + CursorPtr += FillBytes; + + m_windowSize += FillBytes; + m_currentChunkSize += FillBytes; + + CurrentOffset += FillBytes; + ByteCount -= FillBytes; + + if (m_windowSize < kWindowSize) + { + return NoBoundaryFound; + } + + // We have a full window, initialize hash + + uint32_t CurrentHash = 0; + + for (int i = 1; i < kWindowSize; ++i) + { + CurrentHash ^= detail::Rotate32(detail::buzhashTable[m_window[i - 1]], kWindowSize - i); + } + + m_currentHash = CurrentHash ^ detail::buzhashTable[m_window[kWindowSize - 1]]; + } + + // Scan for boundaries (i.e points where the hash matches the value determined by + // the discriminator) + + uint32_t CurrentHash = m_currentHash; + uint32_t CurrentChunkSize = m_currentChunkSize; + + size_t Index = CurrentChunkSize % kWindowSize; + + if (m_threshold && m_useThreshold) + { + // This is roughly 4x faster than the general modulo approach on my + // TR 3990X (~940MB/sec) and doesn't require any special parameters to + // achieve max performance + + while (ByteCount) + { + const uint8_t NewByte = *CursorPtr; + const uint8_t OldByte = m_window[Index]; + + CurrentHash = detail::Rotate32(CurrentHash, 1) ^ detail::Rotate32(detail::buzhashTable[OldByte], m_windowSize) ^ + detail::buzhashTable[NewByte]; + + CurrentChunkSize++; + CurrentOffset++; + + if (CurrentChunkSize >= m_chunkSizeMin) + { + bool foundBreak; + + if (CurrentChunkSize >= m_chunkSizeMax) + { + foundBreak = true; + } + else + { + foundBreak = CurrentHash <= m_threshold; + } + + if (foundBreak) + { + // Boundary found! + InternalReset(); + + return CurrentOffset; + } + } + + m_window[Index++] = *CursorPtr; + + if (Index == kWindowSize) + { + Index = 0; + } + + ++CursorPtr; + --ByteCount; + } + } + else if ((m_discriminator & (m_discriminator - 1)) == 0) + { + // This is quite a bit faster than the generic modulo path, but + // requires a very specific average chunk size to be used. If you + // pass in an even power-of-two divided by 0.75 as the average + // chunk size you'll hit this path + + const uint32_t Mask = m_discriminator - 1; + + while (ByteCount) + { + const uint8_t NewByte = *CursorPtr; + const uint8_t OldByte = m_window[Index]; + + CurrentHash = detail::Rotate32(CurrentHash, 1) ^ detail::Rotate32(detail::buzhashTable[OldByte], m_windowSize) ^ + detail::buzhashTable[NewByte]; + + CurrentChunkSize++; + CurrentOffset++; + + if (CurrentChunkSize >= m_chunkSizeMin) + { + bool foundBreak; + + if (CurrentChunkSize >= m_chunkSizeMax) + { + foundBreak = true; + } + else + { + foundBreak = (CurrentHash & Mask) == Mask; + } + + if (foundBreak) + { + // Boundary found! + InternalReset(); + + return CurrentOffset; + } + } + + m_window[Index++] = *CursorPtr; + + if (Index == kWindowSize) + { + Index = 0; + } + + ++CursorPtr; + --ByteCount; + } + } + else + { + // This is the slowest path, which caps out around 250MB/sec for large sizes + // on my TR3900X + + while (ByteCount) + { + const uint8_t NewByte = *CursorPtr; + const uint8_t OldByte = m_window[Index]; + + CurrentHash = detail::Rotate32(CurrentHash, 1) ^ detail::Rotate32(detail::buzhashTable[OldByte], m_windowSize) ^ + detail::buzhashTable[NewByte]; + + CurrentChunkSize++; + CurrentOffset++; + + if (CurrentChunkSize >= m_chunkSizeMin) + { + bool foundBreak; + + if (CurrentChunkSize >= m_chunkSizeMax) + { + foundBreak = true; + } + else + { + foundBreak = (CurrentHash % m_discriminator) == (m_discriminator - 1); + } + + if (foundBreak) + { + // Boundary found! + InternalReset(); + + return CurrentOffset; + } + } + + m_window[Index++] = *CursorPtr; + + if (Index == kWindowSize) + { + Index = 0; + } + + ++CursorPtr; + --ByteCount; + } + } + + m_currentChunkSize = CurrentChunkSize; + m_currentHash = CurrentHash; + + return NoBoundaryFound; +} + +////////////////////////////////////////////////////////////////////////// + +class DirectoryScanner +{ +public: + struct FileEntry + { + std::filesystem::path Path; + uint64_t FileSize; + }; + + const std::vector<FileEntry>& Files() { return m_Files; } + std::vector<FileEntry>&& TakeFiles() { return std::move(m_Files); } + uint64_t FileBytes() const { return m_FileBytes; } + + void Scan(std::filesystem::path RootPath) + { + for (const std::filesystem::directory_entry& Entry : std::filesystem::recursive_directory_iterator(RootPath)) + { + if (Entry.is_regular_file()) + { + m_Files.push_back({Entry.path(), Entry.file_size()}); + m_FileBytes += Entry.file_size(); + } + } + } + +private: + std::vector<FileEntry> m_Files; + uint64_t m_FileBytes = 0; +}; + +////////////////////////////////////////////////////////////////////////// + +class BaseChunker +{ +public: + void SetCasStore(zen::CasStore* CasStore) { m_CasStore = CasStore; } + + struct StatsBlock + { + uint64_t TotalBytes = 0; + uint64_t TotalChunks = 0; + uint64_t TotalCompressed = 0; + uint64_t UniqueBytes = 0; + uint64_t UniqueChunks = 0; + uint64_t UniqueCompressed = 0; + uint64_t DuplicateBytes = 0; + uint64_t NewCasChunks = 0; + uint64_t NewCasBytes = 0; + + StatsBlock& operator+=(const StatsBlock& Rhs) + { + TotalBytes += Rhs.TotalBytes; + TotalChunks += Rhs.TotalChunks; + TotalCompressed += Rhs.TotalCompressed; + UniqueBytes += Rhs.UniqueBytes; + UniqueChunks += Rhs.UniqueChunks; + UniqueCompressed += Rhs.UniqueCompressed; + DuplicateBytes += Rhs.DuplicateBytes; + NewCasChunks += Rhs.NewCasChunks; + NewCasBytes += Rhs.NewCasBytes; + return *this; + } + }; + +protected: + Concurrency::combinable<StatsBlock> m_StatsBlock; + +public: + StatsBlock SumStats() + { + StatsBlock _; + m_StatsBlock.combine_each([&](const StatsBlock& Block) { _ += Block; }); + return _; + } + +protected: + struct HashSet + { + bool Add(const zen::IoHash& Hash) + { + const uint8_t ShardNo = Hash.Hash[19]; + + Bucket& Shard = m_Buckets[ShardNo]; + + zen::RwLock::ExclusiveLockScope _(Shard.HashLock); + + auto rv = Shard.Hashes.insert(Hash); + + return rv.second; + } + + private: + struct alignas(64) Bucket + { + zen::RwLock HashLock; + std::unordered_set<zen::IoHash, zen::IoHash::Hasher> Hashes; +# if ZEN_PLATFORM_WINDOWS +# pragma warning(suppress : 4324) // Padding due to alignment +# endif + }; + + Bucket m_Buckets[256]; + }; + + zen::CasStore* m_CasStore = nullptr; +}; + +class FixedBlockSizeChunker : public BaseChunker +{ +public: + FixedBlockSizeChunker(std::filesystem::path InRootPath) : m_RootPath(InRootPath) {} + ~FixedBlockSizeChunker() = default; + + void SetChunkSize(uint64_t ChunkSize) + { + /* TODO: verify validity of chunk size */ + m_ChunkSize = ChunkSize; + } + void SetUseCompression(bool UseCompression) { m_UseCompression = UseCompression; } + void SetPerformValidation(bool PerformValidation) { m_PerformValidation = PerformValidation; } + + void InitCompression() + { + if (!m_CompressionBufferManager) + { + std::call_once(m_CompressionInitFlag, [&] { + // Wasteful, but should only be temporary + m_CompressionBufferManager.reset(new FileBufferManager(m_ChunkSize * 2, 128)); + }); + } + } + + void ChunkFile(const DirectoryScanner::FileEntry& File) + { + InitCompression(); + + std::filesystem::path RelativePath{std::filesystem::relative(File.Path.generic_string(), m_RootPath)}; + + Concurrency::task_group ChunkProcessTasks; + + ZEN_INFO("Chunking {} ({})", RelativePath.generic_string(), zen::NiceBytes(File.FileSize)); + + zen::RefPtr<InternalFile> Zfile = new InternalFile; + Zfile->OpenRead(File.Path); + + size_t FileBytes = Zfile->GetFileSize(); + uint64_t CurrentFileOffset = 0; + + std::vector<zen::IoHash> BlockHashes{(FileBytes + m_ChunkSize - 1) / m_ChunkSize}; + + while (FileBytes) + { + zen::IoBuffer Buffer = m_BufferManager.AllocBuffer(); + + const size_t BytesToRead = std::min(FileBytes, Buffer.Size()); + + Zfile->Read((void*)Buffer.Data(), BytesToRead, CurrentFileOffset); + + auto ProcessChunk = [this, Buffer, &BlockHashes, CurrentFileOffset, BytesToRead] { + StatsBlock& Stats = m_StatsBlock.local(); + for (uint64_t Offset = 0; Offset < BytesToRead; Offset += m_ChunkSize) + { + const uint8_t* DataPointer = reinterpret_cast<const uint8_t*>(Buffer.Data()) + Offset; + const uint64_t DataSize = std::min(BytesToRead - Offset, m_ChunkSize); + const zen::IoHash Hash = zen::IoHash::HashBuffer(DataPointer, DataSize); + + BlockHashes[(CurrentFileOffset + Offset) / m_ChunkSize] = Hash; + + const bool IsNew = m_LocalHashSet.Add(Hash); + + if (IsNew) + { + if (m_UseCompression) + { + if (true) + { + // Compress using ZSTD + + // TODO: use CompressedBuffer format + + const size_t CompressBufferSize = ZSTD_compressBound(DataSize); + + zen::IoBuffer CompressedBuffer = m_CompressionBufferManager->AllocBuffer(); + char* CompressBuffer = (char*)CompressedBuffer.Data(); + + ZEN_ASSERT(CompressedBuffer.Size() >= CompressBufferSize); + + const size_t CompressedSize = ZSTD_compress(CompressBuffer, + CompressBufferSize, + (const char*)DataPointer, + DataSize, + ZSTD_CLEVEL_DEFAULT); + + Stats.UniqueCompressed += CompressedSize; + + if (m_CasStore) + { + const zen::IoHash CompressedHash = zen::IoHash::HashBuffer(CompressBuffer, CompressedSize); + zen::IoBuffer CompressedData = zen::IoBuffer(zen::IoBuffer::Wrap, CompressBuffer, CompressedSize); + zen::CasStore::InsertResult Result = m_CasStore->InsertChunk(CompressedData, CompressedHash); + + if (Result.New) + { + Stats.NewCasChunks += 1; + Stats.NewCasBytes += CompressedSize; + } + } + + m_CompressionBufferManager->ReturnBuffer(CompressedBuffer); + } + else + { + // Compress using LZ4 + const int CompressBufferSize = LZ4_compressBound(gsl::narrow<int>(DataSize)); + + zen::IoBuffer CompressedBuffer = m_CompressionBufferManager->AllocBuffer(); + char* CompressBuffer = (char*)CompressedBuffer.Data(); + + ZEN_ASSERT(CompressedBuffer.Size() >= size_t(CompressBufferSize)); + + const int CompressedSize = LZ4_compress_default((const char*)DataPointer, + CompressBuffer, + gsl::narrow<int>(DataSize), + CompressBufferSize); + + Stats.UniqueCompressed += CompressedSize; + + if (m_CasStore) + { + const zen::IoHash CompressedHash = zen::IoHash::HashBuffer(CompressBuffer, CompressedSize); + zen::IoBuffer CompressedData = zen::IoBuffer(zen::IoBuffer::Wrap, CompressBuffer, CompressedSize); + zen::CasStore::InsertResult Result = m_CasStore->InsertChunk(CompressedData, CompressedHash); + + if (Result.New) + { + Stats.NewCasChunks += 1; + Stats.NewCasBytes += CompressedSize; + } + } + + m_CompressionBufferManager->ReturnBuffer(CompressedBuffer); + } + } + else if (m_CasStore) + { + zen::CasStore::InsertResult Result = m_CasStore->InsertChunk(zen::IoBuffer(Buffer, Offset, DataSize), Hash); + + if (Result.New) + { + Stats.NewCasChunks += 1; + Stats.NewCasBytes += DataSize; + } + } + + Stats.UniqueBytes += DataSize; + Stats.UniqueChunks += 1; + } + else + { + // We've seen this chunk before + Stats.DuplicateBytes += DataSize; + } + + Stats.TotalBytes += DataSize; + Stats.TotalChunks += 1; + } + + m_BufferManager.ReturnBuffer(Buffer); + }; + + ChunkProcessTasks.run(ProcessChunk); + + CurrentFileOffset += BytesToRead; + FileBytes -= BytesToRead; + } + + ChunkProcessTasks.wait(); + + // Verify pass + + if (!m_UseCompression && m_PerformValidation) + { + const uint8_t* FileData = reinterpret_cast<const uint8_t*>(Zfile->MemoryMapFile()); + uint64_t Offset = 0; + const uint64_t BytesToRead = Zfile->GetFileSize(); + + for (zen::IoHash& Hash : BlockHashes) + { + const uint64_t DataSize = std::min(BytesToRead - Offset, m_ChunkSize); + const zen::IoHash CalcHash = zen::IoHash::HashBuffer(FileData + Offset, DataSize); + + ZEN_ASSERT(CalcHash == Hash); + + zen::IoBuffer FoundValue = m_CasStore->FindChunk(CalcHash); + + ZEN_ASSERT(FoundValue); + ZEN_ASSERT(FoundValue.Size() == DataSize); + + Offset += DataSize; + } + } + } + +private: + std::filesystem::path m_RootPath; + FileBufferManager m_BufferManager{128 * 1024, 128}; + uint64_t m_ChunkSize = 64 * 1024; + HashSet m_LocalHashSet; + bool m_UseCompression = true; + bool m_PerformValidation = false; + + std::once_flag m_CompressionInitFlag; + std::unique_ptr<FileBufferManager> m_CompressionBufferManager; +}; + +class VariableBlockSizeChunker : public BaseChunker +{ +public: + VariableBlockSizeChunker(std::filesystem::path InRootPath) : m_RootPath(InRootPath) {} + + void SetAverageChunkSize(uint64_t AverageChunkSize) { m_AverageChunkSize = AverageChunkSize; } + void SetUseCompression(bool UseCompression) { m_UseCompression = UseCompression; } + + void ChunkFile(const DirectoryScanner::FileEntry& File) + { + std::filesystem::path RelativePath{std::filesystem::relative(File.Path.generic_string(), m_RootPath)}; + + ZEN_INFO("Chunking {} ({})", RelativePath.generic_string(), zen::NiceBytes(File.FileSize)); + + zen::RefPtr<InternalFile> Zfile = new InternalFile; + Zfile->OpenRead(File.Path); + + // Could use IoBuffer here to help manage lifetimes of things + // across tasks / threads + + ZenChunker Chunker; + Chunker.SetChunkSize(0, 0, m_AverageChunkSize); + + const size_t DataSize = Zfile->GetFileSize(); + + std::vector<size_t> Boundaries; + + uint64_t CurrentStreamPosition = 0; + uint64_t CurrentChunkSize = 0; + size_t RemainBytes = DataSize; + + zen::IoHashStream IoHashStream; + + while (RemainBytes != 0) + { + zen::IoBuffer Buffer = m_BufferManager.AllocBuffer(); + + size_t BytesToRead = std::min(RemainBytes, Buffer.Size()); + + uint8_t* DataPointer = (uint8_t*)Buffer.Data(); + + Zfile->Read(DataPointer, BytesToRead, CurrentStreamPosition); + + StatsBlock& Stats = m_StatsBlock.local(); + + while (BytesToRead) + { + const size_t Boundary = Chunker.ScanChunk(DataPointer, BytesToRead); + + if (Boundary == ZenChunker::NoBoundaryFound) + { + IoHashStream.Append(DataPointer, BytesToRead); + CurrentStreamPosition += BytesToRead; + CurrentChunkSize += BytesToRead; + RemainBytes -= BytesToRead; + break; + } + + // Boundary found + + IoHashStream.Append(DataPointer, Boundary); + + const zen::IoHash Hash = IoHashStream.GetHash(); + const bool IsNew = m_LocalHashSet.Add(Hash); + + CurrentStreamPosition += Boundary; + CurrentChunkSize += Boundary; + Boundaries.push_back(CurrentStreamPosition); + + if (IsNew) + { + Stats.UniqueBytes += CurrentChunkSize; + } + else + { + // We've seen this chunk before + Stats.DuplicateBytes += CurrentChunkSize; + } + + DataPointer += Boundary; + RemainBytes -= Boundary; + BytesToRead -= Boundary; + CurrentChunkSize = 0; + IoHashStream.Reset(); + } + + m_BufferManager.ReturnBuffer(Buffer); + +# if 0 + Active.AddCount(); // needs fixing + + Concurrency::create_task([this, Zfile, CurrentPosition, DataPointer, &Active] { + const zen::IoHash Hash = zen::IoHash::HashBuffer(DataPointer, CurrentPosition); + + const bool isNew = m_LocalHashSet.Add(Hash); + + const int CompressBufferSize = LZ4_compressBound(gsl::narrow<int>(CurrentPosition)); + char* CompressBuffer = (char*)_aligned_malloc(CompressBufferSize, 16); + + const int CompressedSize = + LZ4_compress_default((const char*)DataPointer, CompressBuffer, gsl::narrow<int>(CurrentPosition), CompressBufferSize); + + m_TotalCompressed.local() += CompressedSize; + + if (isNew) + { + m_UniqueBytes.local() += CurrentPosition; + m_UniqueCompressed.local() += CompressedSize; + + if (m_CasStore) + { + const zen::IoHash CompressedHash = zen::IoHash::HashBuffer(CompressBuffer, CompressedSize); + m_CasStore->InsertChunk(CompressBuffer, CompressedSize, CompressedHash); + } + } + + Active.Signal(); // needs fixing + + _aligned_free(CompressBuffer); + }); +# endif + } + + StatsBlock& Stats = m_StatsBlock.local(); + Stats.TotalBytes += DataSize; + Stats.TotalChunks += Boundaries.size() + 1; + + // TODO: Wait for all compression tasks + + auto ChunkCount = Boundaries.size() + 1; + + ZEN_INFO("Split {} ({}) into {} chunks, avg size {}", + RelativePath.generic_string(), + zen::NiceBytes(File.FileSize), + ChunkCount, + File.FileSize / ChunkCount); + }; + +private: + HashSet m_LocalHashSet; + std::filesystem::path m_RootPath; + uint64_t m_AverageChunkSize = 32 * 1024; + bool m_UseCompression = true; + FileBufferManager m_BufferManager{128 * 1024, 128}; +}; + +////////////////////////////////////////////////////////////////////////// + +ChunkCommand::ChunkCommand() +{ + m_Options.add_options()("r,root", "Root directory for CAS pool", cxxopts::value(m_RootDirectory)); + m_Options.add_options()("d,dir", "Directory to scan", cxxopts::value(m_ScanDirectory)); + m_Options.add_options()("c,chunk-size", "Use fixed chunk size", cxxopts::value(m_ChunkSize)); + m_Options.add_options()("a,average-chunk-size", "Use dynamic chunk size", cxxopts::value(m_AverageChunkSize)); + m_Options.add_options()("compress", "Apply compression to chunks", cxxopts::value(m_UseCompression)); +} + +ChunkCommand::~ChunkCommand() = default; + +int +ChunkCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + bool IsValid = m_ScanDirectory.length(); + + if (!IsValid) + throw cxxopts::OptionParseException("Chunk command requires a directory to scan"); + + if ((m_ChunkSize && m_AverageChunkSize) && (!m_ChunkSize && !m_AverageChunkSize)) + throw cxxopts::OptionParseException("Either of --chunk-size or --average-chunk-size must be used"); + + std::unique_ptr<zen::CasStore> CasStore; + + zen::GcManager Gc; + + if (!m_RootDirectory.empty()) + { + zen::CasStoreConfiguration Config; + Config.RootDirectory = m_RootDirectory; + + CasStore = zen::CreateCasStore(Gc); + CasStore->Initialize(Config); + } + + // Gather list of files to process + + ZEN_INFO("Gathering files from {}", m_ScanDirectory); + + std::filesystem::path RootPath{m_ScanDirectory}; + DirectoryScanner Scanner; + Scanner.Scan(RootPath); + + auto Files = Scanner.TakeFiles(); + uint64_t FileBytes = Scanner.FileBytes(); + + std::sort(begin(Files), end(Files), [](const DirectoryScanner::FileEntry& Lhs, const DirectoryScanner::FileEntry& Rhs) { + return Lhs.FileSize < Rhs.FileSize; + }); + + ZEN_INFO("Gathered {} files, total size {}", Files.size(), zen::NiceBytes(FileBytes)); + + auto ReportSummary = [&](BaseChunker& Chunker, uint64_t ElapsedMs) { + const BaseChunker::StatsBlock& Stats = Chunker.SumStats(); + + const size_t TotalChunkCount = Stats.TotalChunks; + ZEN_INFO("Scanned {} files in {}, generated {} chunks", Files.size(), zen::NiceTimeSpanMs(ElapsedMs), TotalChunkCount); + + const size_t TotalByteCount = Stats.TotalBytes; + const size_t TotalCompressedBytes = Stats.TotalCompressed; + + ZEN_INFO("Total bytes {} ({}), compresses into {}", + zen::NiceBytes(TotalByteCount), + zen::NiceByteRate(TotalByteCount, ElapsedMs), + zen::NiceBytes(TotalCompressedBytes)); + + const size_t TotalUniqueBytes = Stats.UniqueBytes; + const size_t TotalUniqueCompressedBytes = Stats.UniqueCompressed; + const size_t TotalDuplicateBytes = Stats.DuplicateBytes; + + ZEN_INFO("Chunksize average {}, unique bytes = {} (compressed {}), dup bytes = {}", + TotalByteCount / TotalChunkCount, + zen::NiceBytes(TotalUniqueBytes), + zen::NiceBytes(TotalUniqueCompressedBytes), + zen::NiceBytes(TotalDuplicateBytes)); + + ZEN_INFO("New to CAS: {} chunks, {}", Stats.NewCasChunks, zen::NiceBytes(Stats.NewCasBytes)); + }; + + // Process them as quickly as possible + + if (m_AverageChunkSize) + { + VariableBlockSizeChunker Chunker{RootPath}; + Chunker.SetAverageChunkSize(m_AverageChunkSize); + Chunker.SetUseCompression(m_UseCompression); + Chunker.SetCasStore(CasStore.get()); + + zen::Stopwatch timer; + +# if 1 + Concurrency::parallel_for_each(begin(Files), end(Files), [&Chunker](const auto& ThisFile) { Chunker.ChunkFile(ThisFile); }); +# else + for (const auto& ThisFile : Files) + { + Chunker.ChunkFile(ThisFile); + } +# endif + + uint64_t ElapsedMs = timer.GetElapsedTimeMs(); + + ReportSummary(Chunker, ElapsedMs); + } + else if (m_ChunkSize) + { + FixedBlockSizeChunker Chunker{RootPath}; + Chunker.SetChunkSize(m_ChunkSize); + Chunker.SetUseCompression(m_UseCompression); + Chunker.SetCasStore(CasStore.get()); + + zen::Stopwatch timer; + + Concurrency::parallel_for_each(begin(Files), end(Files), [&Chunker](const DirectoryScanner::FileEntry& ThisFile) { + try + { + Chunker.ChunkFile(ThisFile); + } + catch (std::exception& ex) + { + zen::ExtendableStringBuilder<256> Path8; + zen::PathToUtf8(ThisFile.Path, Path8); + ZEN_WARN("Caught exception while chunking '{}': {}", Path8, ex.what()); + } + }); + + uint64_t ElapsedMs = timer.GetElapsedTimeMs(); + + ReportSummary(Chunker, ElapsedMs); + } + else + { + ZEN_ASSERT(false); + } + + // TODO: implement snapshot enumeration and display + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +# if ZEN_WITH_TESTS +TEST_CASE("chunking") +{ + using namespace zen; + + auto test = [](bool UseThreshold, bool Random, int MinBlockSize, int MaxBlockSize) { + std::mt19937_64 mt; + + std::vector<uint64_t> bytes; + bytes.resize(1 * 1024 * 1024); + + if (Random == false) + { + // Generate a single block of randomness + for (auto& w : bytes) + { + w = mt(); + } + } + + for (int i = MinBlockSize; i <= MaxBlockSize; i <<= 1) + { + Stopwatch timer; + + ZenChunker chunker; + chunker.SetUseThreshold(UseThreshold); + chunker.SetChunkSize(0, 0, i); + // chunker.SetChunkSize(i / 4, i * 4, 0); + // chunker.SetChunkSize(i / 8, i * 8, 0); + // chunker.SetChunkSize(i / 16, i * 16, 0); + // chunker.SetChunkSize(0, 0, size_t(i / 0.75)); // Hits the fast modulo path + + std::vector<size_t> boundaries; + + size_t CurrentPosition = 0; + int BoundaryCount = 0; + + do + { + if (Random == true) + { + // Generate a new block of randomness for each pass + for (auto& w : bytes) + { + w = mt(); + } + } + + const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(bytes.data()); + size_t BytesRemain = bytes.size() * sizeof(uint64_t); + + for (;;) + { + const size_t Boundary = chunker.ScanChunk(Ptr, BytesRemain); + + if (Boundary == ZenChunker::NoBoundaryFound) + { + CurrentPosition += BytesRemain; + break; + } + + // Boundary found + + CurrentPosition += Boundary; + + CHECK(CurrentPosition >= chunker.ChunkSizeMin()); + CHECK(CurrentPosition <= chunker.ChunkSizeMax()); + + boundaries.push_back(CurrentPosition); + + CurrentPosition = 0; + Ptr += Boundary; + BytesRemain -= Boundary; + + ++BoundaryCount; + } + } while (BoundaryCount < 5000); + + size_t BoundarySum = 0; + + for (const auto& v : boundaries) + { + BoundarySum += v; + } + + double Avg = double(BoundarySum) / BoundaryCount; + const uint64_t ElapsedTimeMs = timer.GetElapsedTimeMs(); + + ZEN_INFO("{:9} : Avg {:9} - {:2.5} ({:6}, {})", + i, + Avg, + double(i / Avg), + NiceTimeSpanMs(ElapsedTimeMs), + NiceByteRate(chunker.BytesScanned(), ElapsedTimeMs)); + } + }; + + const bool Random = false; + + SUBCASE("threshold method") { test(/* UseThreshold */ true, /* Random */ Random, 2048, 1 * 1024 * 1024); } + + SUBCASE("mod method") { test(/* UseThreshold */ false, /* Random */ Random, 2048, 1 * 1024 * 1024); } +} +# endif +#endif diff --git a/src/zen/chunk/chunk.h b/src/zen/chunk/chunk.h new file mode 100644 index 000000000..e796f4147 --- /dev/null +++ b/src/zen/chunk/chunk.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once +#include <zencore/zencore.h> +#include "../zen.h" + +#if 0 +class ChunkCommand : public ZenCmdBase +{ +public: + ChunkCommand(); + ~ChunkCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"chunk", "Do a chunking pass"}; + std::string m_RootDirectory; + std::string m_ScanDirectory; + size_t m_ChunkSize = 0; + size_t m_AverageChunkSize = 0; + bool m_UseCompression = true; +}; +#endif // 0 diff --git a/src/zen/cmds/cache.cpp b/src/zen/cmds/cache.cpp new file mode 100644 index 000000000..495662d2f --- /dev/null +++ b/src/zen/cmds/cache.cpp @@ -0,0 +1,275 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "cache.h" + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zenhttp/httpcommon.h> +#include <zenutil/zenserverprocess.h> + +#include <memory> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +DropCommand::DropCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "n", "namespace", "Namespace name", cxxopts::value(m_NamespaceName), "<namespacename>"); + m_Options.add_option("", "b", "bucket", "Bucket name", cxxopts::value(m_BucketName), "<bucketname>"); + m_Options.parse_positional({"namespace", "bucket"}); +} + +DropCommand::~DropCommand() = default; + +int +DropCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + if (m_NamespaceName.empty()) + { + throw cxxopts::OptionParseException("Drop command requires a namespace"); + } + + cpr::Session Session; + if (m_BucketName.empty()) + { + ZEN_CONSOLE("Dropping cache namespace '{}' from '{}'", m_NamespaceName, m_HostName); + Session.SetUrl({fmt::format("{}/z$/{}", m_HostName, m_NamespaceName)}); + } + else + { + ZEN_CONSOLE("Dropping cache bucket '{}/{}' from '{}'", m_NamespaceName, m_BucketName, m_HostName); + Session.SetUrl({fmt::format("{}/z$/{}/{}", m_HostName, m_NamespaceName, m_BucketName)}); + } + + cpr::Response Result = Session.Delete(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("OK: drop succeeded"); + + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Drop failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Drop failed: {}", Result.error.message); + } + + return 1; +} + +CacheInfoCommand::CacheInfoCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "n", "namespace", "Namespace name", cxxopts::value(m_NamespaceName), "<namespacename>"); + m_Options.add_option("", "b", "bucket", "Bucket name", cxxopts::value(m_BucketName), "<bucketname>"); + m_Options.parse_positional({"namespace", "bucket"}); +} + +CacheInfoCommand::~CacheInfoCommand() = default; + +int +CacheInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + if (m_HostName.empty()) + { + ZEN_CONSOLE("Info on cache from '{}'", m_HostName); + Session.SetUrl({fmt::format("{}/z$", m_HostName)}); + } + else if (m_BucketName.empty()) + { + ZEN_CONSOLE("Info on cache namespace '{}' from '{}'", m_NamespaceName, m_HostName); + Session.SetUrl({fmt::format("{}/z$/{}", m_HostName, m_NamespaceName)}); + } + else + { + ZEN_CONSOLE("Info on cache bucket '{}/{}' from '{}'", m_NamespaceName, m_BucketName, m_HostName); + Session.SetUrl({fmt::format("{}/z$/{}/{}", m_HostName, m_NamespaceName, m_BucketName)}); + } + + cpr::Response Result = Session.Get(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("{}", Result.text); + + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Info failed: {}", Result.error.message); + } + + return 1; +} + +CacheStatsCommand::CacheStatsCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); +} + +CacheStatsCommand::~CacheStatsCommand() = default; + +int +CacheStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetUrl({fmt::format("{}/stats/z$", m_HostName)}); + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + + cpr::Response Result = Session.Get(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("{}", Result.text); + + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Info failed: {}", Result.error.message); + } + + return 1; +} + +CacheDetailsCommand::CacheDetailsCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "c", "csv", "Info on csv format", cxxopts::value(m_CSV), "<csv>"); + m_Options.add_option("", "d", "details", "Get detailed information about records", cxxopts::value(m_Details), "<details>"); + m_Options.add_option("", + "a", + "attachmentdetails", + "Get detailed information about attachments", + cxxopts::value(m_AttachmentDetails), + "<attachmentdetails>"); + m_Options.add_option("", "n", "namespace", "Namespace name to get info for", cxxopts::value(m_Namespace), "<namespace>"); + m_Options.add_option("", "b", "bucket", "Filter on bucket name", cxxopts::value(m_Bucket), "<bucket>"); + m_Options.add_option("", "v", "valuekey", "Filter on value key hash string", cxxopts::value(m_ValueKey), "<valuekey>"); +} + +CacheDetailsCommand::~CacheDetailsCommand() = default; + +int +CacheDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + cpr::Parameters Parameters; + if (m_Details) + { + Parameters.Add({"details", "true"}); + } + if (m_AttachmentDetails) + { + Parameters.Add({"attachmentdetails", "true"}); + } + if (m_CSV) + { + Parameters.Add({"csv", "true"}); + } + else + { + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + } + + if (!m_ValueKey.empty()) + { + if (m_Namespace.empty() || m_Bucket.empty()) + { + ZEN_ERROR("Provide namespace and bucket name"); + ZEN_CONSOLE("{}", m_Options.help({""}).c_str()); + return 1; + } + Session.SetUrl({fmt::format("{}/z$/details$/{}/{}/{}", m_HostName, m_Namespace, m_Bucket, m_ValueKey)}); + } + else if (!m_Bucket.empty()) + { + if (m_Namespace.empty()) + { + ZEN_ERROR("Provide namespace name"); + ZEN_CONSOLE("{}", m_Options.help({""}).c_str()); + return 1; + } + Session.SetUrl({fmt::format("{}/z$/details$/{}/{}", m_HostName, m_Namespace, m_Bucket)}); + } + else if (!m_Namespace.empty()) + { + Session.SetUrl({fmt::format("{}/z$/details$/{}", m_HostName, m_Namespace)}); + } + else + { + Session.SetUrl({fmt::format("{}/z$/details$", m_HostName)}); + } + Session.SetParameters(Parameters); + + cpr::Response Result = Session.Get(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("{}", Result.text); + + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Info failed: {}", Result.error.message); + } + + return 1; +} diff --git a/src/zen/cmds/cache.h b/src/zen/cmds/cache.h new file mode 100644 index 000000000..1f368bdec --- /dev/null +++ b/src/zen/cmds/cache.h @@ -0,0 +1,68 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +class DropCommand : public ZenCmdBase +{ +public: + DropCommand(); + ~DropCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"drop", "Drop cache namespace or bucket"}; + std::string m_HostName; + std::string m_NamespaceName; + std::string m_BucketName; +}; + +class CacheInfoCommand : public ZenCmdBase +{ +public: + CacheInfoCommand(); + ~CacheInfoCommand(); + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"cache-info", "Info on cache, namespace or bucket"}; + std::string m_HostName; + std::string m_NamespaceName; + std::string m_BucketName; +}; + +class CacheStatsCommand : public ZenCmdBase +{ +public: + CacheStatsCommand(); + ~CacheStatsCommand(); + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"cache-stats", "Stats info on cache"}; + std::string m_HostName; +}; + +class CacheDetailsCommand : public ZenCmdBase +{ +public: + CacheDetailsCommand(); + ~CacheDetailsCommand(); + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"cache-details", "Detailed info on cache"}; + std::string m_HostName; + bool m_CSV; + bool m_Details; + bool m_AttachmentDetails; + std::string m_Namespace; + std::string m_Bucket; + std::string m_ValueKey; +}; diff --git a/src/zen/cmds/copy.cpp b/src/zen/cmds/copy.cpp new file mode 100644 index 000000000..6f6c078d4 --- /dev/null +++ b/src/zen/cmds/copy.cpp @@ -0,0 +1,95 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "copy.h" + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zencore/timer.h> + +namespace zen { + +CopyCommand::CopyCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_options()("no-clone", "Do not perform block clone", cxxopts::value(m_NoClone)->default_value("false")); + m_Options.add_option("", "s", "source", "Copy source", cxxopts::value(m_CopySource), "<file/directory>"); + m_Options.add_option("", "t", "target", "Copy target", cxxopts::value(m_CopyTarget), "<file/directory>"); + m_Options.add_option("", "", "positional", "Positional arguments", cxxopts::value(m_Positional), ""); + m_Options.parse_positional({"source", "target", "positional"}); +} + +CopyCommand::~CopyCommand() = default; + +int +CopyCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ZenCmdBase::ParseOptions(argc, argv)) + { + return 0; + } + + // Validate arguments + + if (m_CopySource.empty()) + throw std::runtime_error("No source specified"); + + if (m_CopyTarget.empty()) + throw std::runtime_error("No target specified"); + + std::filesystem::path FromPath; + std::filesystem::path ToPath; + + FromPath = m_CopySource; + ToPath = m_CopyTarget; + + const bool IsFileCopy = std::filesystem::is_regular_file(m_CopySource); + const bool IsDirCopy = std::filesystem::is_directory(m_CopySource); + + if (!IsFileCopy && !IsDirCopy) + { + throw std::runtime_error("Invalid source specification (neither directory nor file)"); + } + + if (IsFileCopy && IsDirCopy) + { + throw std::runtime_error("Invalid source specification (both directory AND file!?)"); + } + + if (IsDirCopy) + { + if (std::filesystem::exists(ToPath)) + { + const bool IsTargetDir = std::filesystem::is_directory(ToPath); + if (!IsTargetDir) + { + if (std::filesystem::is_regular_file(ToPath)) + { + throw std::runtime_error("Attempted copy of directory into file"); + } + } + } + else + { + std::filesystem::create_directories(ToPath); + } + } + else + { + // Single file copy + + zen::Stopwatch Timer; + + zen::CopyFileOptions CopyOptions; + CopyOptions.EnableClone = !m_NoClone; + zen::CopyFile(FromPath, ToPath, CopyOptions); + + ZEN_CONSOLE("Copy completed in {}", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + } + + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/copy.h b/src/zen/cmds/copy.h new file mode 100644 index 000000000..5527ae9b8 --- /dev/null +++ b/src/zen/cmds/copy.h @@ -0,0 +1,28 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +/** Copy files, possibly using block cloning + */ +class CopyCommand : public ZenCmdBase +{ +public: + CopyCommand(); + ~CopyCommand(); + + virtual cxxopts::Options& Options() override { return m_Options; } + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + +private: + cxxopts::Options m_Options{"copy", "Copy files"}; + std::vector<std::string> m_Positional; + std::string m_CopySource; + std::string m_CopyTarget; + bool m_NoClone = false; +}; + +} // namespace zen diff --git a/src/zen/cmds/dedup.cpp b/src/zen/cmds/dedup.cpp new file mode 100644 index 000000000..b48fb8c2d --- /dev/null +++ b/src/zen/cmds/dedup.cpp @@ -0,0 +1,302 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "dedup.h" + +#include <zencore/blake3.h> +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/timer.h> + +#if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +#endif + +#include <list> + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +namespace Concurrency { + + template<typename T0, typename T1> + inline void parallel_invoke(T0 const& t0, T1 const& t1) + { + t0(); + t1(); + } + +} // namespace Concurrency + +#endif // ZEN_PLATFORM_LINUX/MAC + +//////////////////////////////////////////////////////////////////////////////// + +DedupCommand::DedupCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_options()("size", "Configure size threshold for dedup", cxxopts::value(m_SizeThreshold)->default_value("131072")); + m_Options.add_option("", "s", "source", "Copy source", cxxopts::value(m_DedupSource), "<file/directory>"); + m_Options.add_option("", "t", "target", "Copy target", cxxopts::value(m_DedupTarget), "<file/directory>"); + m_Options.add_option("", "", "positional", "Positional arguments", cxxopts::value(m_Positional), ""); + m_Options.parse_positional({"source", "target", "positional"}); +} + +DedupCommand::~DedupCommand() = default; + +int +DedupCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + // Validate arguments + + const bool SourceGood = zen::SupportsBlockRefCounting(m_DedupSource); + const bool TargetGood = zen::SupportsBlockRefCounting(m_DedupTarget); + + if (!SourceGood) + { + ZEN_ERROR("Source directory '{}' does not support deduplication", m_DedupSource); + + return 0; + } + + if (!TargetGood) + { + ZEN_ERROR("Target directory '{}' does not support deduplication", m_DedupTarget); + + return 0; + } + + ZEN_CONSOLE("Performing dedup operation between {} and {}, size threshold {}", + m_DedupSource, + m_DedupTarget, + zen::NiceBytes(m_SizeThreshold)); + + using DirEntryList_t = std::list<std::filesystem::directory_entry>; + + zen::RwLock MapLock; + std::unordered_map<size_t, DirEntryList_t> FileSizeMap; + size_t CandidateCount = 0; + + auto AddToList = [&](const std::filesystem::directory_entry& Entry) { + if (Entry.is_regular_file()) + { + uintmax_t FileSize = Entry.file_size(); + if (FileSize > m_SizeThreshold) + { + zen::RwLock::ExclusiveLockScope _(MapLock); + FileSizeMap[FileSize].push_back(Entry); + ++CandidateCount; + } + } + }; + + std::filesystem::recursive_directory_iterator DirEnd; + + ZEN_CONSOLE("Gathering file info from source: '{}'", m_DedupSource); + ZEN_CONSOLE("Gathering file info from target: '{}'", m_DedupTarget); + + { + zen::Stopwatch Timer; + + Concurrency::parallel_invoke( + [&] { + for (std::filesystem::recursive_directory_iterator DirIt1(m_DedupSource); DirIt1 != DirEnd; ++DirIt1) + { + AddToList(*DirIt1); + } + }, + [&] { + for (std::filesystem::recursive_directory_iterator DirIt2(m_DedupTarget); DirIt2 != DirEnd; ++DirIt2) + { + AddToList(*DirIt2); + } + }); + + ZEN_CONSOLE("Gathered {} candidates across {} size buckets. Elapsed: {}", + CandidateCount, + FileSizeMap.size(), + zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + } + + ZEN_CONSOLE("Sorting buckets by size"); + + zen::Stopwatch Timer; + + uint64_t DupeBytes = 0; + + struct SizeList + { + size_t Size; + DirEntryList_t* DirEntries; + }; + + std::vector<SizeList> SizeLists{FileSizeMap.size()}; + + { + int i = 0; + + for (auto& kv : FileSizeMap) + { + ZEN_ASSERT(kv.first >= m_SizeThreshold); + SizeLists[i].Size = kv.first; + SizeLists[i].DirEntries = &kv.second; + ++i; + } + } + + std::sort(begin(SizeLists), end(SizeLists), [](const SizeList& Lhs, const SizeList& Rhs) { return Lhs.Size > Rhs.Size; }); + + ZEN_CONSOLE("Bucket summary:"); + + std::vector<size_t> BucketId; + std::vector<size_t> BucketOffsets; + std::vector<size_t> BucketSizes; + std::vector<size_t> BucketFileCounts; + + size_t TotalFileSizes = 0; + size_t TotalFileCount = 0; + + { + size_t CurrentPow2 = 0; + size_t BucketSize = 0; + size_t BucketFileCount = 0; + bool FirstBucket = true; + + for (size_t i = 0; i < SizeLists.size(); ++i) + { + const size_t ThisSize = SizeLists[i].Size; + const size_t Pow2 = zen::NextPow2(ThisSize); + + if (CurrentPow2 != Pow2) + { + CurrentPow2 = Pow2; + + if (!FirstBucket) + { + BucketSizes.push_back(BucketSize); + BucketFileCounts.push_back(BucketFileCount); + } + + BucketId.push_back(Pow2); + BucketOffsets.push_back(i); + + FirstBucket = false; + BucketSize = 0; + BucketFileCount = 0; + } + + BucketSize += ThisSize; + TotalFileSizes += ThisSize; + BucketFileCount += SizeLists[i].DirEntries->size(); + TotalFileCount += SizeLists[i].DirEntries->size(); + } + + if (!FirstBucket) + { + BucketSizes.push_back(BucketSize); + BucketFileCounts.push_back(BucketFileCount); + } + + ZEN_ASSERT(BucketOffsets.size() == BucketSizes.size()); + ZEN_ASSERT(BucketOffsets.size() == BucketFileCounts.size()); + } + + for (size_t i = 0; i < BucketOffsets.size(); ++i) + { + ZEN_CONSOLE(" Bucket {} : {}, {} candidates", zen::NiceBytes(BucketId[i]), zen::NiceBytes(BucketSizes[i]), BucketFileCounts[i]); + } + + ZEN_CONSOLE("Total : {}, {} candidates", zen::NiceBytes(TotalFileSizes), TotalFileCount); + + std::string CurrentNice; + + for (SizeList& Size : SizeLists) + { + std::string CurNice{zen::NiceBytes(zen::NextPow2(Size.Size))}; + + if (CurNice != CurrentNice) + { + CurrentNice = CurNice; + ZEN_CONSOLE("Now scanning bucket: {}", CurrentNice); + } + + std::unordered_map<zen::BLAKE3, const std::filesystem::directory_entry*, zen::BLAKE3::Hasher> DedupMap; + + for (const auto& Entry : *Size.DirEntries) + { + zen::BLAKE3 Hash; + + if constexpr (true) + { + zen::BLAKE3Stream b3s; + + zen::ScanFile(Entry.path(), 64 * 1024, [&](const void* Data, size_t Size) { b3s.Append(Data, Size); }); + + Hash = b3s.GetHash(); + } + else + { + zen::FileContents Contents = zen::ReadFile(Entry.path()); + + zen::BLAKE3Stream b3s; + + for (zen::IoBuffer& Buffer : Contents.Data) + { + b3s.Append(Buffer.Data(), Buffer.Size()); + } + Hash = b3s.GetHash(); + } + + if (const std::filesystem::directory_entry* Dupe = DedupMap[Hash]) + { + std::string FileA = PathToUtf8(Dupe->path()); + std::string FileB = PathToUtf8(Entry.path()); + + size_t MinLen = std::min(FileA.size(), FileB.size()); + auto Its = std::mismatch(FileB.rbegin(), FileB.rbegin() + MinLen, FileA.rbegin()); + + if (Its.first != FileB.rbegin()) + { + if (Its.first[-1] == '\\' || Its.first[-1] == '/') + --Its.first; + + FileB = std::string(FileB.begin(), Its.first.base()) + "..."; + } + + ZEN_INFO("{} {} <-> {}", zen::NiceBytes(Entry.file_size()).c_str(), FileA.c_str(), FileB.c_str()); + + zen::CopyFileOptions Options; + Options.EnableClone = true; + Options.MustClone = true; + + zen::CopyFile(Dupe->path(), Entry.path(), Options); + + DupeBytes += Entry.file_size(); + } + else + { + DedupMap[Hash] = &Entry; + } + } + + Size.DirEntries->clear(); + } + + ZEN_CONSOLE("Elapsed: {} Deduped: {}", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs()), zen::NiceBytes(DupeBytes)); + + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/dedup.h b/src/zen/cmds/dedup.h new file mode 100644 index 000000000..6318704f5 --- /dev/null +++ b/src/zen/cmds/dedup.h @@ -0,0 +1,28 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +/** Deduplicate files in a tree using block cloning + */ +class DedupCommand : public ZenCmdBase +{ +public: + DedupCommand(); + ~DedupCommand(); + + virtual cxxopts::Options& Options() override { return m_Options; } + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + +private: + cxxopts::Options m_Options{"dedup", "Deduplicate files"}; + std::vector<std::string> m_Positional; + std::string m_DedupSource; + std::string m_DedupTarget; + size_t m_SizeThreshold = 1024 * 1024; +}; + +} // namespace zen diff --git a/src/zen/cmds/hash.cpp b/src/zen/cmds/hash.cpp new file mode 100644 index 000000000..7987d7738 --- /dev/null +++ b/src/zen/cmds/hash.cpp @@ -0,0 +1,171 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "hash.h" + +#include <zencore/blake3.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zencore/timer.h> + +#if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +#endif + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + +namespace Concurrency { + + template<typename IterType, typename LambdaType> + void parallel_for_each(IterType Cursor, IterType End, const LambdaType& Lambda) + { + for (; Cursor < End; ++Cursor) + { + Lambda(*Cursor); + } + } + + template<typename T> + struct combinable + { + combinable<T>& local() { return *this; } + + void operator+=(T Rhs) { Value += Rhs; } + + template<typename LambdaType> + void combine_each(const LambdaType& Lambda) + { + Lambda(Value); + } + + T Value = 0; + }; + +} // namespace Concurrency + +#endif // ZEN_PLATFORM_LINUX|MAC + +//////////////////////////////////////////////////////////////////////////////// + +HashCommand::HashCommand() +{ + m_Options.add_options()("d,dir", "Directory to scan", cxxopts::value<std::string>(m_ScanDirectory))( + "o,output", + "Output file", + cxxopts::value<std::string>(m_OutputFile)); +} + +HashCommand::~HashCommand() = default; + +int +HashCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + bool valid = m_ScanDirectory.length(); + + if (!valid) + throw cxxopts::OptionParseException("Hash command requires a directory to scan"); + + // Gather list of files to process + + ZEN_CONSOLE("Gathering files from {}", m_ScanDirectory); + + struct FileEntry + { + std::filesystem::path FilePath; + zen::BLAKE3 FileHash; + }; + + std::vector<FileEntry> FileList; + uint64_t FileBytes = 0; + + std::filesystem::path ScanDirectoryPath{m_ScanDirectory}; + + for (const std::filesystem::directory_entry& Entry : std::filesystem::recursive_directory_iterator(ScanDirectoryPath)) + { + if (Entry.is_regular_file()) + { + FileList.push_back({Entry.path()}); + FileBytes += Entry.file_size(); + } + } + + ZEN_CONSOLE("Gathered {} files, total size {}", FileList.size(), zen::NiceBytes(FileBytes)); + + Concurrency::combinable<uint64_t> TotalBytes; + + auto hashFile = [&](FileEntry& File) { + InternalFile InputFile; + InputFile.OpenRead(File.FilePath); + const uint8_t* DataPointer = (const uint8_t*)InputFile.MemoryMapFile(); + const size_t DataSize = InputFile.GetFileSize(); + + File.FileHash = zen::BLAKE3::HashMemory(DataPointer, DataSize); + + TotalBytes.local() += DataSize; + }; + + // Process them as quickly as possible + + zen::Stopwatch Timer; + +#if 1 + Concurrency::parallel_for_each(begin(FileList), end(FileList), [&](auto& file) { hashFile(file); }); +#else + for (const auto& file : FileList) + { + hashFile(file); + } +#endif + + size_t TotalByteCount = 0; + + TotalBytes.combine_each([&](size_t Total) { TotalByteCount += Total; }); + + const uint64_t ElapsedMs = Timer.GetElapsedTimeMs(); + ZEN_CONSOLE("Scanned {} files in {}", FileList.size(), zen::NiceTimeSpanMs(ElapsedMs)); + ZEN_CONSOLE("Total bytes {} ({})", zen::NiceBytes(TotalByteCount), zen::NiceByteRate(TotalByteCount, ElapsedMs)); + + InternalFile Output; + + if (m_OutputFile.empty()) + { + // TEMPORARY -- should properly open stdout + Output.OpenWrite("CONOUT$", false); + } + else + { + Output.OpenWrite(m_OutputFile, true); + } + + zen::ExtendableStringBuilder<256> Line; + + uint64_t CurrentOffset = 0; + + for (const auto& File : FileList) + { + Line.Append(File.FilePath.generic_u8string().c_str()); + Line.Append(','); + File.FileHash.ToHexString(Line); + Line.Append('\n'); + + Output.Write(Line.Data(), Line.Size(), CurrentOffset); + CurrentOffset += Line.Size(); + + Line.Reset(); + } + + // TODO: implement snapshot enumeration and display + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/hash.h b/src/zen/cmds/hash.h new file mode 100644 index 000000000..e5ee071e9 --- /dev/null +++ b/src/zen/cmds/hash.h @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../internalfile.h" +#include "../zen.h" + +namespace zen { + +/** Generate hash list file + */ +class HashCommand : public ZenCmdBase +{ +public: + HashCommand(); + ~HashCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"hash", "Hash files"}; + std::string m_ScanDirectory; + std::string m_OutputFile; +}; + +} // namespace zen diff --git a/src/zen/cmds/print.cpp b/src/zen/cmds/print.cpp new file mode 100644 index 000000000..67191605c --- /dev/null +++ b/src/zen/cmds/print.cpp @@ -0,0 +1,193 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "print.h" + +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zenhttp/httpshared.h> + +using namespace std::literals; + +namespace zen { + +static void +PrintCbObject(CbObject Object) +{ + zen::StringBuilder<1024> ObjStr; + zen::CompactBinaryToJson(Object, ObjStr); + ZEN_CONSOLE("{}", ObjStr); +} + +static void +PrintCbObject(IoBuffer Data) +{ + zen::CbObject Object{SharedBuffer(Data)}; + + PrintCbObject(Object); +} + +PrintCommand::PrintCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "s", "source", "Object payload file", cxxopts::value(m_Filename), "<file name>"); + m_Options.parse_positional({"source"}); +} + +PrintCommand::~PrintCommand() = default; + +int +PrintCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + // Validate arguments + + if (m_Filename.empty()) + throw std::runtime_error("No file specified"); + + zen::FileContents Fc; + + if (m_Filename == "-") + { + Fc = zen::ReadStdIn(); + } + else + { + Fc = zen::ReadFile(m_Filename); + } + + if (Fc.ErrorCode) + { + ZEN_ERROR("Failed to read file '{}': {}", m_Filename, Fc.ErrorCode.message()); + + return 1; + } + + IoBuffer Data = Fc.Flatten(); + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(Data, RawHash, RawSize)) + { + ZEN_CONSOLE("Compressed binary: size {}, raw size {}, hash: {}", Data.GetSize(), RawSize, RawHash); + } + else if (IsPackageMessage(Data)) + { + CbPackage Package = ParsePackageMessage(Data); + + CbObject Object = Package.GetObject(); + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + ZEN_CONSOLE("Package - {} attachments, object hash {}", Package.GetAttachments().size(), Package.GetObjectHash()); + ZEN_CONSOLE(""); + + int AttachmentIndex = 1; + + for (const CbAttachment& Attachment : Attachments) + { + std::string AttachmentSize = "n/a"; + const char* AttachmentType = "unknown"; + + if (Attachment.IsCompressedBinary()) + { + AttachmentType = "Compressed"; + AttachmentSize = fmt::format("{} ({} uncompressed)", + Attachment.AsCompressedBinary().GetCompressedSize(), + Attachment.AsCompressedBinary().DecodeRawSize()); + } + else if (Attachment.IsBinary()) + { + AttachmentType = "Binary"; + AttachmentSize = fmt::format("{}", Attachment.AsBinary().GetSize()); + } + else if (Attachment.IsObject()) + { + AttachmentType = "Object"; + AttachmentSize = fmt::format("{}", Attachment.AsObject().GetSize()); + } + else if (Attachment.IsNull()) + { + AttachmentType = "null"; + } + + ZEN_CONSOLE("Attachment #{} : {}, {}, size {}", AttachmentIndex, Attachment.GetHash(), AttachmentType, AttachmentSize); + + ++AttachmentIndex; + } + + ZEN_CONSOLE("---8<---"); + + PrintCbObject(Object); + } + else if (CbValidateError Result = ValidateCompactBinary(Data, CbValidateMode::All); Result == CbValidateError::None) + { + PrintCbObject(Data); + } + else + { + ZEN_ERROR("Data in file '{}' does not appear to be compact binary (validation error {:#x})", m_Filename, uint32_t(Result)); + + return 1; + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +PrintPackageCommand::PrintPackageCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "s", "source", "Package payload file", cxxopts::value(m_Filename), "<file name>"); + m_Options.parse_positional({"source"}); +} + +PrintPackageCommand::~PrintPackageCommand() +{ +} + +int +PrintPackageCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + // Validate arguments + + if (m_Filename.empty()) + throw std::runtime_error("No file specified"); + + zen::FileContents Fc = zen::ReadFile(m_Filename); + IoBuffer Data = Fc.Flatten(); + zen::CbPackage Package; + + bool Ok = Package.TryLoad(Data) || zen::legacy::TryLoadCbPackage(Package, Data, &UniqueBuffer::Alloc); + + if (Ok) + { + zen::StringBuilder<1024> ObjStr; + zen::CompactBinaryToJson(Package.GetObject(), ObjStr); + ZEN_CONSOLE("{}", ObjStr); + } + else + { + ZEN_ERROR("error: malformed package?"); + } + + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/print.h b/src/zen/cmds/print.h new file mode 100644 index 000000000..09d91830a --- /dev/null +++ b/src/zen/cmds/print.h @@ -0,0 +1,41 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +/** Print Compact Binary + */ +class PrintCommand : public ZenCmdBase +{ +public: + PrintCommand(); + ~PrintCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"print", "Print compact binary object"}; + std::string m_Filename; +}; + +/** Print Compact Binary Package + */ +class PrintPackageCommand : public ZenCmdBase +{ +public: + PrintPackageCommand(); + ~PrintPackageCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"printpkg", "Print compact binary package"}; + std::string m_Filename; +}; + +} // namespace zen diff --git a/src/zen/cmds/projectstore.cpp b/src/zen/cmds/projectstore.cpp new file mode 100644 index 000000000..fe0dd713e --- /dev/null +++ b/src/zen/cmds/projectstore.cpp @@ -0,0 +1,930 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "projectstore.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/logging.h> +#include <zencore/stream.h> +#include <zenhttp/httpcommon.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace { + +using namespace std::literals; + +const std::string DefaultCloudAccessTokenEnvVariableName( +#if ZEN_PLATFORM_WINDOWS + "UE-CloudDataCacheAccessToken"sv +#endif +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + "UE_CloudDataCacheAccessToken"sv +#endif +); + +} // namespace + +/////////////////////////////////////// + +DropProjectCommand::DropProjectCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); + m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); + m_Options.parse_positional({"project", "oplog"}); +} + +DropProjectCommand::~DropProjectCommand() = default; + +int +DropProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + if (m_ProjectName.empty()) + { + throw cxxopts::OptionParseException("Drop command requires a project"); + } + + cpr::Session Session; + if (m_OplogName.empty()) + { + ZEN_CONSOLE("Dropping project '{}' from '{}'", m_ProjectName, m_HostName); + Session.SetUrl({fmt::format("{}/prj/{}", m_HostName, m_ProjectName)}); + } + else + { + ZEN_CONSOLE("Dropping oplog '{}/{}' from '{}'", m_ProjectName, m_OplogName, m_HostName); + Session.SetUrl({fmt::format("{}/prj/{}/oplog/{}", m_HostName, m_ProjectName, m_OplogName)}); + } + + cpr::Response Result = Session.Delete(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("OK: drop succeeded"); + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Drop failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Drop failed: {}", Result.error.message); + } + + return 1; +} + +/////////////////////////////////////// + +ProjectInfoCommand::ProjectInfoCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); + m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); + m_Options.parse_positional({"project", "oplog"}); +} + +ProjectInfoCommand::~ProjectInfoCommand() = default; + +int +ProjectInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + if (m_ProjectName.empty()) + { + ZEN_CONSOLE("Info from '{}'", m_HostName); + Session.SetUrl({fmt::format("{}/prj", m_HostName)}); + } + else if (m_OplogName.empty()) + { + ZEN_CONSOLE("Info on project '{}' from '{}'", m_ProjectName, m_HostName); + Session.SetUrl({fmt::format("{}/prj/{}", m_HostName, m_ProjectName)}); + } + else + { + ZEN_CONSOLE("Info on oplog '{}/{}' from '{}'", m_ProjectName, m_OplogName, m_HostName); + Session.SetUrl({fmt::format("{}/prj/{}/oplog/{}", m_HostName, m_ProjectName, m_OplogName)}); + } + + cpr::Response Result = Session.Get(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("{}", Result.text); + + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Info failed: {}", Result.error.message); + } + + return 1; +} + +/////////////////////////////////////// + +CreateProjectCommand::CreateProjectCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectId), "<projectid>"); + m_Options.add_option("", "", "rootdir", "Absolute path to root directory", cxxopts::value(m_RootDir), "<root>"); + m_Options.add_option("", "", "enginedir", "Absolute path to engine root directory", cxxopts::value(m_EngineRootDir), "<engineroot>"); + m_Options.add_option("", "", "projectdir", "Absolute path to project directory", cxxopts::value(m_ProjectRootDir), "<projectroot>"); + m_Options.add_option("", "", "projectfile", "Absolute path to .uproject file", cxxopts::value(m_ProjectFile), "<projectfile>"); + m_Options.parse_positional({"project", "rootdir", "enginedir", "projectdir", "projectfile"}); +} + +CreateProjectCommand::~CreateProjectCommand() = default; + +int +CreateProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + using namespace std::literals; + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + + if (m_ProjectId.empty()) + { + ZEN_ERROR("Project name must be given"); + return 1; + } + + Session.SetUrl({fmt::format("{}/prj/{}", m_HostName, m_ProjectId)}); + cpr::Response Response = Session.Get(); + if (zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("Project already exists.\n{}", Response.text); + return 1; + } + + if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound)) + { + zen::CbObjectWriter Project; + Project.AddString("id"sv, m_ProjectId); + Project.AddString("root"sv, m_RootDir); + Project.AddString("engine"sv, m_EngineRootDir); + Project.AddString("project"sv, m_ProjectRootDir); + Project.AddString("projectfile"sv, m_ProjectFile); + zen::IoBuffer ProjectPayload = Project.Save().GetBuffer().AsIoBuffer(); + Session.SetBody(cpr::Body{(const char*)ProjectPayload.GetData(), ProjectPayload.GetSize()}); + Session.SetHeader(cpr::Header{{"Accept", "text"}}); + Response = Session.Post(); + } + + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); +} + +/////////////////////////////////////// + +CreateOplogCommand::CreateOplogCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectId), "<projectid>"); + m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogId), "<oplogid>"); + m_Options.add_option("", "", "gcpath", "Absolute path to oplog lifetime marker file", cxxopts::value(m_GcPath), "<path>"); + m_Options.parse_positional({"project", "oplog", "gcpath"}); +} + +CreateOplogCommand::~CreateOplogCommand() = default; + +int +CreateOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + using namespace std::literals; + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + + if (m_ProjectId.empty()) + { + ZEN_ERROR("Project name must be given"); + return 1; + } + + if (m_OplogId.empty()) + { + ZEN_ERROR("Oplog name must be given"); + return 1; + } + + Session.SetUrl({fmt::format("{}/prj/{}/oplog/{}", m_HostName, m_ProjectId, m_OplogId)}); + cpr::Response Response = Session.Get(); + if (zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("Oplog already exists.\n{}", Response.text); + return 1; + } + + if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound)) + { + Session.SetHeader(cpr::Header{{"Accept", "text"}}); + if (!m_GcPath.empty()) + { + zen::CbObjectWriter Oplog; + Oplog.AddString("gcpath"sv, m_GcPath); + zen::IoBuffer OplogPayload = Oplog.Save().GetBuffer().AsIoBuffer(); + Session.SetBody(cpr::Body{(const char*)OplogPayload.GetData(), OplogPayload.GetSize()}); + Session.SetHeader(cpr::Header{{"Accept", "text"}, {"Content-Type", std::string(ToString(zen::HttpContentType::kCbObject))}}); + } + + Response = Session.Post(); + } + + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + + return MapHttpToCommandReturnCode(Response); +} + +/////////////////////////////////////// + +ExportOplogCommand::ExportOplogCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); + m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); + m_Options.add_option("", "", "maxblocksize", "Max size for bundled attachments", cxxopts::value(m_MaxBlockSize), "<blocksize>"); + m_Options.add_option("", + "", + "maxchunkembedsize", + "Max size for attachment to be bundled", + cxxopts::value(m_MaxChunkEmbedSize), + "<chunksize>"); + m_Options.add_option("", "f", "force", "Force export of all attachments", cxxopts::value(m_Force), "<force>"); + m_Options.add_option("", + "", + "disableblocks", + "Disable block creation and save all attachments individually (applies to file and cloud target)", + cxxopts::value(m_DisableBlocks), + "<disable>"); + + m_Options.add_option("", "", "cloud", "Cloud Storage URL", cxxopts::value(m_CloudUrl), "<url>"); + m_Options.add_option("cloud", "", "namespace", "Cloud Storage namespace", cxxopts::value(m_CloudNamespace), "<namespace>"); + m_Options.add_option("cloud", "", "bucket", "Cloud Storage bucket", cxxopts::value(m_CloudBucket), "<bucket>"); + m_Options.add_option("cloud", "", "key", "Cloud Storage key", cxxopts::value(m_CloudKey), "<key>"); + m_Options + .add_option("cloud", "", "openid-provider", "Cloud Storage openid provider", cxxopts::value(m_CloudOpenIdProvider), "<provider>"); + m_Options.add_option("cloud", "", "access-token", "Cloud Storage access token", cxxopts::value(m_CloudAccessToken), "<accesstoken>"); + m_Options.add_option("cloud", + "", + "access-token-env", + "Name of environment variable that holds the cloud Storage access token", + cxxopts::value(m_CloudAccessTokenEnv)->default_value(DefaultCloudAccessTokenEnvVariableName), + "<envvariable>"); + m_Options.add_option("cloud", + "", + "disabletempblocks", + "Disable temp block creation and upload blocks without waiting for oplog container to be uploaded", + cxxopts::value(m_CloudDisableTempBlocks), + "<disable>"); + + m_Options.add_option("", "", "zen", "Zen service upload address", cxxopts::value(m_ZenUrl), "<url>"); + m_Options.add_option("zen", "", "target-project", "Zen target project name", cxxopts::value(m_ZenProjectName), "<targetprojectid>"); + m_Options.add_option("zen", "", "target-oplog", "Zen target oplog name", cxxopts::value(m_ZenOplogName), "<targetoplogid>"); + m_Options.add_option("zen", "", "clean", "Delete existing target Zen oplog", cxxopts::value(m_ZenClean), "<clean>"); + + m_Options.add_option("", "", "file", "Local folder path", cxxopts::value(m_FileDirectoryPath), "<path>"); + m_Options.add_option("file", "", "name", "Local file name", cxxopts::value(m_FileName), "<filename>"); + m_Options.add_option("file", + "", + "forcetempblocks", + "Force creation of temp attachment blocks", + cxxopts::value(m_FileForceEnableTempBlocks), + "<forcetempblocks>"); + + m_Options.parse_positional({"project", "oplog"}); +} + +ExportOplogCommand::~ExportOplogCommand() = default; + +int +ExportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + using namespace std::literals; + + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + if (m_ProjectName.empty()) + { + ZEN_ERROR("Project name must be given"); + return 1; + } + + if (m_OplogName.empty()) + { + ZEN_ERROR("Oplog name must be given"); + return 1; + } + + size_t TargetCount = 0; + TargetCount += m_CloudUrl.empty() ? 0 : 1; + TargetCount += m_ZenUrl.empty() ? 0 : 1; + TargetCount += m_FileDirectoryPath.empty() ? 0 : 1; + if (TargetCount != 1) + { + ZEN_ERROR("Provide one target only"); + ZEN_CONSOLE("{}", m_Options.help({""}).c_str()); + return 1; + } + + cpr::Session Session; + + if (!m_CloudUrl.empty()) + { + if (m_CloudNamespace.empty() || m_CloudBucket.empty()) + { + ZEN_ERROR("Options for cloud target are missing"); + ZEN_CONSOLE("{}", m_Options.help({"cloud"}).c_str()); + return 1; + } + if (m_CloudKey.empty()) + { + std::string KeyString = fmt::format("{}/{}/{}/{}", m_ProjectName, m_OplogName, m_CloudNamespace, m_CloudBucket); + zen::IoHash Key = zen::IoHash::HashBuffer(KeyString.data(), KeyString.size()); + m_CloudKey = Key.ToHexString(); + ZEN_WARN("Using auto generated cloud key '{}'", m_CloudKey); + } + } + + if (!m_ZenUrl.empty()) + { + if (m_ZenProjectName.empty()) + { + m_ZenProjectName = m_ProjectName; + ZEN_WARN("Using default zen target project id '{}'", m_ZenProjectName); + } + if (m_ZenOplogName.empty()) + { + m_ZenOplogName = m_OplogName; + ZEN_WARN("Using default zen target oplog id '{}'", m_ZenOplogName); + } + + std::string TargetUrlBase = fmt::format("{}/prj", m_ZenUrl); + if (TargetUrlBase.find("://") == std::string::npos) + { + // Assume https URL + TargetUrlBase = fmt::format("http://{}", TargetUrlBase); + } + + Session.SetUrl({fmt::format("{}/{}/oplog/{}", TargetUrlBase, m_ZenProjectName, m_ZenOplogName)}); + cpr::Response Response = Session.Get(); + if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound)) + { + ZEN_WARN("Automatically creating oplog '{}/{}'", m_ZenProjectName, m_ZenOplogName) + Response = Session.Post(); + if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + } + else if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + else if (m_ZenClean) + { + ZEN_WARN("Cleaning oplog '{}/{}'", m_ZenProjectName, m_ZenOplogName) + Response = Session.Delete(); + if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + Response = Session.Post(); + if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + } + } + + if (!m_FileDirectoryPath.empty()) + { + if (m_FileName.empty()) + { + m_FileName = m_OplogName; + ZEN_WARN("Using default file name '{}'", m_FileName); + } + } + + const std::string SourceUrlBase = fmt::format("{}/prj", m_HostName); + std::string TargetDescription; + Session.SetUrl({fmt::format("{}/{}/oplog/{}/rpc", SourceUrlBase, m_ProjectName, m_OplogName)}); + Session.SetHeader({{"Content-Type", std::string(zen::MapContentTypeToString(zen::HttpContentType::kCbObject))}}); + zen::CbObjectWriter Writer; + Writer.AddString("method"sv, "export"sv); + Writer.BeginObject("params"sv); + { + if (m_MaxBlockSize != 0) + { + Writer.AddInteger("maxblocksize"sv, m_MaxBlockSize); + } + if (m_MaxChunkEmbedSize != 0) + { + Writer.AddInteger("maxchunkembedsize"sv, m_MaxChunkEmbedSize); + } + if (m_Force) + { + Writer.AddBool("force"sv, true); + } + if (!m_FileDirectoryPath.empty()) + { + Writer.BeginObject("file"sv); + { + Writer.AddString("file"sv, m_FileDirectoryPath); + Writer.AddString("name"sv, m_FileName); + if (m_DisableBlocks) + { + Writer.AddBool("disableblocks"sv, true); + } + if (m_FileForceEnableTempBlocks) + { + Writer.AddBool("enabletempblocks"sv, true); + } + } + Writer.EndObject(); // "file" + TargetDescription = fmt::format("[file] '{}/{}'", m_FileDirectoryPath, m_FileName); + } + if (!m_CloudUrl.empty()) + { + Writer.BeginObject("cloud"sv); + { + Writer.AddString("url"sv, m_CloudUrl); + Writer.AddString("namespace"sv, m_CloudNamespace); + Writer.AddString("bucket"sv, m_CloudBucket); + Writer.AddString("key"sv, m_CloudKey); + if (!m_CloudOpenIdProvider.empty()) + { + Writer.AddString("openid-provider"sv, m_CloudOpenIdProvider); + } + if (!m_CloudAccessToken.empty()) + { + Writer.AddString("access-token"sv, m_CloudAccessToken); + } + if (!m_CloudAccessTokenEnv.empty()) + { + Writer.AddString("access-token-env"sv, m_CloudAccessTokenEnv); + } + if (m_DisableBlocks) + { + Writer.AddBool("disableblocks"sv, true); + } + if (m_CloudDisableTempBlocks) + { + Writer.AddBool("disabletempblocks"sv, true); + } + } + Writer.EndObject(); // "cloud" + TargetDescription = fmt::format("[cloud] '{}/{}/{}/{}'", m_CloudUrl, m_CloudNamespace, m_CloudBucket, m_CloudKey); + } + if (!m_ZenUrl.empty()) + { + Writer.BeginObject("zen"sv); + { + Writer.AddString("url"sv, m_ZenUrl); + Writer.AddString("project"sv, m_ZenProjectName); + Writer.AddString("oplog"sv, m_ZenOplogName); + } + Writer.EndObject(); // "zen" + + TargetDescription = fmt::format("[zen] '{}/{}/{}'", m_ZenUrl, m_ZenProjectName, m_ZenOplogName); + } + } + Writer.EndObject(); // "params" + + zen::BinaryWriter MemOut; + Writer.Save(MemOut); + Session.SetBody(cpr::Body{(const char*)MemOut.GetData(), MemOut.GetSize()}); + + ZEN_CONSOLE("Saving oplog '{}/{}' from '{}' to {}", m_ProjectName, m_OplogName, m_HostName, TargetDescription); + cpr::Response Response = Session.Post(); + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); +} + +//////////////////////////// + +ImportOplogCommand::ImportOplogCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>"); + m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>"); + m_Options.add_option("", "", "maxblocksize", "Max size for bundled attachments", cxxopts::value(m_MaxBlockSize), "<blocksize>"); + m_Options.add_option("", + "", + "maxchunkembedsize", + "Max size for attachment to be bundled", + cxxopts::value(m_MaxChunkEmbedSize), + "<chunksize>"); + m_Options.add_option("", "f", "force", "Force import of all attachments", cxxopts::value(m_Force), "<force>"); + + m_Options.add_option("", "", "cloud", "Cloud Storage URL", cxxopts::value(m_CloudUrl), "<url>"); + m_Options.add_option("cloud", "", "namespace", "Cloud Storage namespace", cxxopts::value(m_CloudNamespace), "<namespace>"); + m_Options.add_option("cloud", "", "bucket", "Cloud Storage bucket", cxxopts::value(m_CloudBucket), "<bucket>"); + m_Options.add_option("cloud", "", "key", "Cloud Storage key", cxxopts::value(m_CloudKey), "<key>"); + m_Options + .add_option("cloud", "", "openid-provider", "Cloud Storage openid provider", cxxopts::value(m_CloudOpenIdProvider), "<provider>"); + m_Options.add_option("cloud", "", "access-token", "Cloud Storage access token", cxxopts::value(m_CloudAccessToken), "<accesstoken>"); + m_Options.add_option("cloud", + "", + "access-token-env", + "Name of environment variable that holds the cloud Storage access token", + cxxopts::value(m_CloudAccessTokenEnv)->default_value(DefaultCloudAccessTokenEnvVariableName), + "<envvariable>"); + + m_Options.add_option("", "", "zen", "Zen service upload address", cxxopts::value(m_ZenUrl), "<url>"); + m_Options.add_option("zen", "", "source-project", "Zen source project name", cxxopts::value(m_ZenProjectName), "<sourceprojectid>"); + m_Options.add_option("zen", "", "source-oplog", "Zen source oplog name", cxxopts::value(m_ZenOplogName), "<sourceoplogid>"); + m_Options.add_option("zen", "", "clean", "Delete existing target Zen oplog", cxxopts::value(m_ZenClean), "<clean>"); + + m_Options.add_option("", "", "file", "Local folder path", cxxopts::value(m_FileDirectoryPath), "<path>"); + m_Options.add_option("file", "", "name", "Local file name", cxxopts::value(m_FileName), "<filename>"); + + m_Options.parse_positional({"project", "oplog"}); +} + +ImportOplogCommand::~ImportOplogCommand() = default; + +int +ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + using namespace std::literals; + + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + if (m_ProjectName.empty()) + { + ZEN_ERROR("Project name must be given"); + return 1; + } + + if (m_OplogName.empty()) + { + ZEN_ERROR("Oplog name must be given"); + return 1; + } + + size_t TargetCount = 0; + TargetCount += m_CloudUrl.empty() ? 0 : 1; + TargetCount += m_ZenUrl.empty() ? 0 : 1; + TargetCount += m_FileDirectoryPath.empty() ? 0 : 1; + if (TargetCount != 1) + { + ZEN_ERROR("Provide one source only"); + ZEN_CONSOLE("{}", m_Options.help({""}).c_str()); + return 1; + } + + cpr::Session Session; + + if (!m_CloudUrl.empty()) + { + if (m_CloudNamespace.empty() || m_CloudBucket.empty()) + { + ZEN_ERROR("Options for cloud source are missing"); + ZEN_CONSOLE("{}", m_Options.help({"cloud"}).c_str()); + return 1; + } + if (m_CloudKey.empty()) + { + std::string KeyString = fmt::format("{}/{}/{}/{}", m_ProjectName, m_OplogName, m_CloudNamespace, m_CloudBucket); + zen::IoHash Key = zen::IoHash::HashBuffer(KeyString.data(), KeyString.size()); + m_CloudKey = Key.ToHexString(); + ZEN_WARN("Using auto generated cloud key '{}'", m_CloudKey); + } + } + + if (!m_ZenUrl.empty()) + { + if (m_ZenProjectName.empty()) + { + m_ZenProjectName = m_ProjectName; + ZEN_WARN("Using default zen target project id '{}'", m_ZenProjectName); + } + if (m_ZenOplogName.empty()) + { + m_ZenOplogName = m_OplogName; + ZEN_WARN("Using default zen target oplog id '{}'", m_ZenOplogName); + } + } + + if (!m_FileDirectoryPath.empty()) + { + if (m_FileName.empty()) + { + m_FileName = m_OplogName; + ZEN_WARN("Using auto generated file name '{}'", m_FileName); + } + } + + const std::string TargetUrlBase = fmt::format("{}/prj", m_HostName); + Session.SetUrl({fmt::format("{}/{}/oplog/{}", TargetUrlBase, m_ProjectName, m_OplogName)}); + cpr::Response Response = Session.Get(); + if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound)) + { + ZEN_WARN("Automatically creating oplog '{}/{}'", m_ProjectName, m_OplogName) + Response = Session.Post(); + if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + } + else if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + else if (m_ZenClean) + { + ZEN_WARN("Cleaning oplog '{}/{}'", m_ProjectName, m_OplogName) + Response = Session.Delete(); + if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + Response = Session.Post(); + if (!zen::IsHttpSuccessCode(Response.status_code)) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + } + + std::string SourceDescription; + Session.SetUrl(fmt::format("{}/{}/oplog/{}/rpc", TargetUrlBase, m_ProjectName, m_OplogName)); + Session.SetHeader({{"Content-Type", std::string(zen::MapContentTypeToString(zen::HttpContentType::kCbObject))}}); + + zen::CbObjectWriter Writer; + Writer.AddString("method"sv, "import"sv); + Writer.BeginObject("params"sv); + { + if (m_Force) + { + Writer.AddBool("force"sv, true); + } + if (!m_FileDirectoryPath.empty()) + { + Writer.BeginObject("file"sv); + { + Writer.AddString("file"sv, m_FileDirectoryPath); + Writer.AddString("name"sv, m_FileName); + } + Writer.EndObject(); // "file" + SourceDescription = fmt::format("[file] '{}/{}'", m_FileDirectoryPath, m_FileName); + } + if (!m_CloudUrl.empty()) + { + Writer.BeginObject("cloud"sv); + { + Writer.AddString("url"sv, m_CloudUrl); + Writer.AddString("namespace"sv, m_CloudNamespace); + Writer.AddString("bucket"sv, m_CloudBucket); + Writer.AddString("key"sv, m_CloudKey); + if (!m_CloudOpenIdProvider.empty()) + { + Writer.AddString("openid-provider"sv, m_CloudOpenIdProvider); + } + if (!m_CloudAccessToken.empty()) + { + Writer.AddString("access-token"sv, m_CloudAccessToken); + } + if (!m_CloudAccessTokenEnv.empty()) + { + Writer.AddString("access-token-env"sv, m_CloudAccessTokenEnv); + } + } + Writer.EndObject(); // "cloud" + SourceDescription = fmt::format("[cloud] '{}/{}/{}/{}'", m_CloudUrl, m_CloudNamespace, m_CloudBucket, m_CloudKey); + } + if (!m_ZenUrl.empty()) + { + Writer.BeginObject("zen"sv); + { + Writer.AddString("url"sv, m_ZenUrl); + Writer.AddString("project"sv, m_ZenProjectName); + Writer.AddString("oplog"sv, m_ZenOplogName); + } + Writer.EndObject(); // "zen" + SourceDescription = fmt::format("[zen] '{}'", m_ZenUrl); + } + } + Writer.EndObject(); // "params" + + zen::BinaryWriter MemOut; + Writer.Save(MemOut); + Session.SetBody(cpr::Body{(const char*)MemOut.GetData(), MemOut.GetSize()}); + + ZEN_CONSOLE("Loading oplog '{}/{}' from '{}' to {}", m_ProjectName, m_OplogName, SourceDescription, m_HostName); + Response = Session.Post(); + + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); +} + +ProjectStatsCommand::ProjectStatsCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); +} + +ProjectStatsCommand::~ProjectStatsCommand() = default; + +int +ProjectStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetUrl({fmt::format("{}/stats/prj", m_HostName)}); + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + + cpr::Response Result = Session.Get(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("{}", Result.text); + + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Info failed: {}", Result.error.message); + } + + return 1; +} + +ProjectDetailsCommand::ProjectDetailsCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "c", "csv", "Output in CSV format (default is JSon)", cxxopts::value(m_CSV), "<csv>"); + m_Options.add_option("", "d", "details", "Detailed info on opslog", cxxopts::value(m_Details), "<details>"); + m_Options.add_option("", "o", "opdetails", "Details info on oplog body", cxxopts::value(m_OpDetails), "<opdetails>"); + m_Options.add_option("", "p", "project", "Project name to get info from", cxxopts::value(m_ProjectName), "<projectid>"); + m_Options.add_option("", "l", "oplog", "Oplog name to get info from", cxxopts::value(m_OplogName), "<oplogid>"); + m_Options.add_option("", "i", "opid", "Oid of a specific op info for", cxxopts::value(m_OpId), "<opid>"); + m_Options.add_option("", + "a", + "attachmentdetails", + "Get detailed information about attachments", + cxxopts::value(m_AttachmentDetails), + "<attachmentdetails>"); +} + +ProjectDetailsCommand::~ProjectDetailsCommand() = default; + +int +ProjectDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + cpr::Parameters Parameters; + if (m_OpDetails) + { + Parameters.Add({"opdetails", "true"}); + } + if (m_Details) + { + Parameters.Add({"details", "true"}); + } + if (m_AttachmentDetails) + { + Parameters.Add({"attachmentdetails", "true"}); + } + if (m_CSV) + { + Parameters.Add({"csv", "true"}); + } + else + { + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + } + + if (!m_OpId.empty()) + { + if (m_ProjectName.empty() || m_OplogName.empty()) + { + ZEN_ERROR("Provide project and oplog name"); + ZEN_CONSOLE("{}", m_Options.help({""}).c_str()); + return 1; + } + Session.SetUrl({fmt::format("{}/prj/details$/{}/{}/{}", m_HostName, m_ProjectName, m_OplogName, m_OpId)}); + } + else if (!m_OplogName.empty()) + { + if (m_ProjectName.empty()) + { + ZEN_ERROR("Provide project name"); + ZEN_CONSOLE("{}", m_Options.help({""}).c_str()); + return 1; + } + Session.SetUrl({fmt::format("{}/prj/details$/{}/{}", m_HostName, m_ProjectName, m_OplogName)}); + } + else if (!m_ProjectName.empty()) + { + Session.SetUrl({fmt::format("{}/prj/details$/{}", m_HostName, m_ProjectName)}); + } + else + { + Session.SetUrl({fmt::format("{}/prj/details$", m_HostName)}); + } + Session.SetParameters(Parameters); + + cpr::Response Result = Session.Get(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("{}", Result.text); + + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("Info failed: {}", Result.error.message); + } + + return 1; +} diff --git a/src/zen/cmds/projectstore.h b/src/zen/cmds/projectstore.h new file mode 100644 index 000000000..10927a546 --- /dev/null +++ b/src/zen/cmds/projectstore.h @@ -0,0 +1,180 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +class DropProjectCommand : public ZenCmdBase +{ +public: + DropProjectCommand(); + ~DropProjectCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"project-drop", "Drop project or project oplog"}; + std::string m_HostName; + std::string m_ProjectName; + std::string m_OplogName; +}; + +class ProjectInfoCommand : public ZenCmdBase +{ +public: + ProjectInfoCommand(); + ~ProjectInfoCommand(); + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"project-info", "Info on project or project oplog"}; + std::string m_HostName; + std::string m_ProjectName; + std::string m_OplogName; +}; + +class CreateProjectCommand : public ZenCmdBase +{ +public: + CreateProjectCommand(); + ~CreateProjectCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"project-create", "Create project"}; + std::string m_HostName; + std::string m_ProjectId; + std::string m_RootDir; + std::string m_EngineRootDir; + std::string m_ProjectRootDir; + std::string m_ProjectFile; +}; + +class CreateOplogCommand : public ZenCmdBase +{ +public: + CreateOplogCommand(); + ~CreateOplogCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"oplog-create", "Create oplog"}; + std::string m_HostName; + std::string m_ProjectId; + std::string m_OplogId; + std::string m_GcPath; +}; + +class ExportOplogCommand : public ZenCmdBase +{ +public: + ExportOplogCommand(); + ~ExportOplogCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"oplog-export", + "Export project store oplog to cloud (--cloud), file system (--file) or other Zen instance (--zen)"}; + std::string m_HostName; + std::string m_ProjectName; + std::string m_OplogName; + uint64_t m_MaxBlockSize = 0; + uint64_t m_MaxChunkEmbedSize = 0; + bool m_Force = false; + bool m_DisableBlocks = false; + + std::string m_CloudUrl; + std::string m_CloudNamespace; + std::string m_CloudBucket; + std::string m_CloudKey; + std::string m_CloudOpenIdProvider; + std::string m_CloudAccessToken; + std::string m_CloudAccessTokenEnv; + bool m_CloudDisableTempBlocks = false; + + std::string m_ZenUrl; + std::string m_ZenProjectName; + std::string m_ZenOplogName; + bool m_ZenClean; + + std::string m_FileDirectoryPath; + std::string m_FileName; + bool m_FileForceEnableTempBlocks = false; +}; + +class ImportOplogCommand : public ZenCmdBase +{ +public: + ImportOplogCommand(); + ~ImportOplogCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"oplog-import", + "Import project store oplog from cloud (--cloud), file system (--file) or other Zen instance (--zen)"}; + std::string m_HostName; + std::string m_ProjectName; + std::string m_OplogName; + size_t m_MaxBlockSize = 0; + size_t m_MaxChunkEmbedSize = 0; + bool m_Force = false; + + std::string m_CloudUrl; + std::string m_CloudNamespace; + std::string m_CloudBucket; + std::string m_CloudKey; + std::string m_CloudOpenIdProvider; + std::string m_CloudAccessToken; + std::string m_CloudAccessTokenEnv; + + std::string m_ZenUrl; + std::string m_ZenProjectName; + std::string m_ZenOplogName; + bool m_ZenClean; + + std::string m_FileDirectoryPath; + std::string m_FileName; +}; + +class ProjectStatsCommand : public ZenCmdBase +{ +public: + ProjectStatsCommand(); + ~ProjectStatsCommand(); + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"project-stats", "Stats info on project store"}; + std::string m_HostName; +}; + +class ProjectDetailsCommand : public ZenCmdBase +{ +public: + ProjectDetailsCommand(); + ~ProjectDetailsCommand(); + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"project-details", "Detail info on project store"}; + std::string m_HostName; + bool m_Details; + bool m_OpDetails; + bool m_AttachmentDetails; + bool m_CSV; + std::string m_ProjectName; + std::string m_OplogName; + std::string m_OpId; +}; diff --git a/src/zen/cmds/rpcreplay.cpp b/src/zen/cmds/rpcreplay.cpp new file mode 100644 index 000000000..9bc4b2c7b --- /dev/null +++ b/src/zen/cmds/rpcreplay.cpp @@ -0,0 +1,417 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "rpcreplay.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/workthreadpool.h> +#include <zenhttp/httpcommon.h> +#include <zenhttp/httpshared.h> +#include <zenutil/cache/rpcrecording.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <memory> + +namespace zen { + +using namespace std::literals; + +RpcStartRecordingCommand::RpcStartRecordingCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>"); + + m_Options.parse_positional("path"); +} + +RpcStartRecordingCommand::~RpcStartRecordingCommand() = default; + +int +RpcStartRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + if (!ParseOptions(argc, argv)) + { + return 0; + } + + if (m_RecordingPath.empty()) + { + throw cxxopts::OptionParseException("Rpc start recording command requires a path"); + } + + cpr::Session Session; + Session.SetUrl(fmt::format("{}/z$/exec$/start-recording"sv, m_HostName)); + Session.SetParameters({{"path", m_RecordingPath}}); + cpr::Response Response = Session.Post(); + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); +} + +//////////////////////////////////////////////////// + +RpcStopRecordingCommand::RpcStopRecordingCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); +} + +RpcStopRecordingCommand::~RpcStopRecordingCommand() = default; + +int +RpcStopRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetUrl(fmt::format("{}/z$/exec$/stop-recording"sv, m_HostName)); + cpr::Response Response = Session.Post(); + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); +} + +//////////////////////////////////////////////////// + +RpcReplayCommand::RpcReplayCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>"); + m_Options.add_option("", + "w", + "numthreads", + "Number of worker threads per process", + cxxopts::value(m_ThreadCount)->default_value(fmt::format("{}", std::thread::hardware_concurrency())), + "<count>"); + m_Options.add_option("", "", "onhost", "Replay on host, bypassing http/network layer", cxxopts::value(m_OnHost), "<onhost>"); + m_Options.add_option("", + "", + "showmethodstats", + "Show statistics of which RPC methods are used", + cxxopts::value(m_ShowMethodStats), + "<showmethodstats>"); + m_Options.add_option("", + "", + "offset", + "Offset into request recording to start replay", + cxxopts::value(m_Offset)->default_value("0"), + "<offset>"); + m_Options.add_option("", + "", + "stride", + "Stride for request recording when replaying requests", + cxxopts::value(m_Stride)->default_value("1"), + "<stride>"); + m_Options.add_option("", "", "numproc", "Number of worker processes", cxxopts::value(m_ProcessCount)->default_value("1"), "<count>"); + m_Options.add_option("", + "", + "forceallowlocalrefs", + "Force enable local refs in requests", + cxxopts::value(m_ForceAllowLocalRefs), + "<enable>"); + m_Options + .add_option("", "", "disablelocalrefs", "Force disable local refs in requests", cxxopts::value(m_DisableLocalRefs), "<enable>"); + m_Options.add_option("", + "", + "forceallowlocalhandlerefs", + "Force enable local refs as handles in requests", + cxxopts::value(m_ForceAllowLocalHandleRef), + "<enable>"); + m_Options.add_option("", + "", + "disablelocalhandlerefs", + "Force disable local refs as handles in requests", + cxxopts::value(m_DisableLocalHandleRefs), + "<enable>"); + m_Options.add_option("", + "", + "forceallowpartiallocalrefs", + "Force enable local refs for all sizes", + cxxopts::value(m_ForceAllowPartialLocalRefs), + "<enable>"); + m_Options.add_option("", + "", + "disablepartiallocalrefs", + "Force disable local refs for all sizes", + cxxopts::value(m_DisablePartialLocalRefs), + "<enable>"); + + m_Options.parse_positional("path"); +} + +RpcReplayCommand::~RpcReplayCommand() = default; + +int +RpcReplayCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + if (m_RecordingPath.empty()) + { + throw cxxopts::OptionParseException("Rpc replay command requires a path"); + } + + if (m_OnHost) + { + cpr::Session Session; + Session.SetUrl(fmt::format("{}/z$/exec$/replay-recording"sv, m_HostName)); + Session.SetParameters({{"path", m_RecordingPath}, {"thread-count", fmt::format("{}", m_ThreadCount)}}); + cpr::Response Response = Session.Post(); + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + return MapHttpToCommandReturnCode(Response); + } + + std::unique_ptr<cache::IRpcRequestReplayer> Replayer = cache::MakeDiskRequestReplayer(m_RecordingPath, true); + uint64_t EntryCount = Replayer->GetRequestCount(); + + std::atomic_uint64_t EntryOffset = m_Offset; + std::atomic_uint64_t BytesSent = 0; + std::atomic_uint64_t BytesReceived = 0; + + Stopwatch Timer; + + if (m_ProcessCount > 1) + { + std::vector<std::unique_ptr<ProcessHandle>> WorkerProcesses; + WorkerProcesses.resize(m_ProcessCount); + + ProcessMonitor Monitor; + for (int ProcessIndex = 0; ProcessIndex < m_ProcessCount; ++ProcessIndex) + { + std::string CommandLine = + fmt::format("{} rpc-record-replay --hosturl {} --path \"{}\" --offset {} --stride {} --numthreads {} --numproc {}"sv, + argv[0], + m_HostName, + m_RecordingPath, + m_Stride == 1 ? 0 : m_Offset + ProcessIndex, + m_Stride, + m_ThreadCount, + 1); + CreateProcResult Result(CreateProc(std::filesystem::path(std::string(argv[0])), CommandLine)); + WorkerProcesses[ProcessIndex] = std::make_unique<ProcessHandle>(); + WorkerProcesses[ProcessIndex]->Initialize(Result); + Monitor.AddPid(WorkerProcesses[ProcessIndex]->Pid()); + } + while (Monitor.IsRunning()) + { + ZEN_CONSOLE("Waiting for worker processes..."); + Sleep(1000); + } + return 0; + } + else + { + std::map<std::string, size_t> MethodTypes; + RwLock MethodTypesLock; + + WorkerThreadPool WorkerPool(m_ThreadCount); + + Latch WorkLatch(m_ThreadCount); + for (int WorkerIndex = 0; WorkerIndex < m_ThreadCount; ++WorkerIndex) + { + WorkerPool.ScheduleWork( + [this, &WorkLatch, EntryCount, &EntryOffset, &Replayer, &BytesSent, &BytesReceived, &MethodTypes, &MethodTypesLock]() { + auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); }); + + cpr::Session Session; + Session.SetUrl(fmt::format("{}/z$/$rpc"sv, m_HostName)); + + uint64_t EntryIndex = EntryOffset.fetch_add(m_Stride); + while (EntryIndex < EntryCount) + { + IoBuffer Payload; + std::pair<ZenContentType, ZenContentType> Types = Replayer->GetRequest(EntryIndex, Payload); + ZenContentType RequestContentType = Types.first; + ZenContentType AcceptContentType = Types.second; + + CbPackage RequestPackage; + CbObject Request; + switch (RequestContentType) + { + case ZenContentType::kCbPackage: + { + if (ParsePackageMessageWithLegacyFallback(Payload, RequestPackage)) + { + Request = RequestPackage.GetObject(); + } + } + break; + case ZenContentType::kCbObject: + { + Request = LoadCompactBinaryObject(Payload); + } + break; + } + + RpcAcceptOptions OriginalAcceptOptions = static_cast<RpcAcceptOptions>(Request["AcceptFlags"sv].AsUInt16(0u)); + int OriginalProcessPid = Request["Pid"sv].AsInt32(0); + + int AdjustedPid = 0; + RpcAcceptOptions AdjustedAcceptOptions = RpcAcceptOptions::kNone; + if (!m_DisableLocalRefs) + { + if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowLocalReferences) || m_ForceAllowLocalRefs) + { + AdjustedAcceptOptions |= RpcAcceptOptions::kAllowLocalReferences; + if (!m_DisablePartialLocalRefs) + { + if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowPartialLocalReferences) || + m_ForceAllowPartialLocalRefs) + { + AdjustedAcceptOptions |= RpcAcceptOptions::kAllowPartialLocalReferences; + } + } + if (!m_DisableLocalHandleRefs) + { + if (OriginalProcessPid != 0 || m_ForceAllowLocalHandleRef) + { + AdjustedPid = GetCurrentProcessId(); + } + } + } + } + + if (m_ShowMethodStats) + { + std::string MethodName = std::string(Request["Method"sv].AsString()); + RwLock::ExclusiveLockScope __(MethodTypesLock); + if (auto It = MethodTypes.find(MethodName); It != MethodTypes.end()) + { + It->second++; + } + else + { + MethodTypes[MethodName] = 1; + } + } + + if (OriginalAcceptOptions != AdjustedAcceptOptions || OriginalProcessPid != AdjustedPid) + { + CbObjectWriter RequestCopyWriter; + for (const CbFieldView& Field : Request) + { + if (!Field.HasName()) + { + RequestCopyWriter.AddField(Field); + continue; + } + std::string_view FieldName = Field.GetName(); + if (FieldName == "Pid"sv) + { + continue; + } + if (FieldName == "AcceptFlags"sv) + { + continue; + } + RequestCopyWriter.AddField(FieldName, Field); + } + if (AdjustedPid != 0) + { + RequestCopyWriter.AddInteger("Pid"sv, AdjustedPid); + } + if (AdjustedAcceptOptions != RpcAcceptOptions::kNone) + { + RequestCopyWriter.AddInteger("AcceptFlags"sv, static_cast<uint16_t>(AdjustedAcceptOptions)); + } + + if (RequestContentType == ZenContentType::kCbPackage) + { + RequestPackage.SetObject(RequestCopyWriter.Save()); + std::vector<IoBuffer> Buffers = FormatPackageMessage(RequestPackage); + std::vector<SharedBuffer> SharedBuffers(Buffers.begin(), Buffers.end()); + Payload = CompositeBuffer(std::move(SharedBuffers)).Flatten().AsIoBuffer(); + } + else + { + RequestCopyWriter.Finalize(); + Payload = IoBuffer(RequestCopyWriter.GetSaveSize()); + RequestCopyWriter.Save(Payload.GetMutableView()); + } + } + + Session.SetHeader({{"Content-Type", std::string(MapContentTypeToString(RequestContentType))}, + {"Accept", std::string(MapContentTypeToString(AcceptContentType))}}); + uint64_t Offset = 0; + auto ReadCallback = [&Payload, &Offset](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, Payload.GetSize() - Offset); + IoBuffer PayloadRange = IoBuffer(Payload, Offset, size); + MutableMemoryView Data(buffer, size); + Data.CopyFrom(PayloadRange.GetView()); + Offset += size; + return true; + }; + Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback)); + cpr::Response Response = Session.Post(); + BytesSent.fetch_add(Payload.GetSize()); + if (Response.error || !(IsHttpSuccessCode(Response.status_code) || + Response.status_code == gsl::narrow<long>(HttpResponseCode::NotFound))) + { + ZEN_CONSOLE("{}", FormatHttpResponse(Response)); + break; + } + BytesReceived.fetch_add(Response.downloaded_bytes); + EntryIndex = EntryOffset.fetch_add(m_Stride); + } + }); + } + + while (!WorkLatch.Wait(1000)) + { + ZEN_CONSOLE("Processing {} requests, {} remaining (sent {}, recevied {})...", + (EntryCount - m_Offset) / m_Stride, + (EntryCount - EntryOffset.load()) / m_Stride, + NiceBytes(BytesSent.load()), + NiceBytes(BytesReceived.load())); + } + if (m_ShowMethodStats) + { + for (const auto& It : MethodTypes) + { + ZEN_CONSOLE("{}: {}", It.first, It.second); + } + } + } + + const uint64_t RequestsSent = (EntryOffset.load() - m_Offset) / m_Stride; + const uint64_t ElapsedMS = Timer.GetElapsedTimeMs(); + const double ElapsedS = ElapsedMS / 1000.500; + const uint64_t Sent = BytesSent.load(); + const uint64_t Received = BytesReceived.load(); + const uint64_t RequestsPerS = static_cast<uint64_t>(RequestsSent / ElapsedS); + const uint64_t SentPerS = static_cast<uint64_t>(Sent / ElapsedS); + const uint64_t ReceivedPerS = static_cast<uint64_t>(Received / ElapsedS); + + ZEN_CONSOLE("Requests sent {} ({}/s), payloads sent {}B ({}B/s), payloads received {}B ({}B/s) in {}", + RequestsSent, + RequestsPerS, + NiceBytes(Sent), + NiceBytes(SentPerS), + NiceBytes(Received), + NiceBytes(ReceivedPerS), + NiceTimeSpanMs(ElapsedMS)); + + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/rpcreplay.h b/src/zen/cmds/rpcreplay.h new file mode 100644 index 000000000..742e5ec5b --- /dev/null +++ b/src/zen/cmds/rpcreplay.h @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +class RpcStartRecordingCommand : public ZenCmdBase +{ +public: + RpcStartRecordingCommand(); + ~RpcStartRecordingCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"rpc-record-start", "Starts recording of cache rpc requests on a host"}; + std::string m_HostName; + std::string m_RecordingPath; +}; + +class RpcStopRecordingCommand : public ZenCmdBase +{ +public: + RpcStopRecordingCommand(); + ~RpcStopRecordingCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"rpc-record-stop", "Stops recording of cache rpc requests on a host"}; + std::string m_HostName; +}; + +class RpcReplayCommand : public ZenCmdBase +{ +public: + RpcReplayCommand(); + ~RpcReplayCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"rpc-record-replay", "Replays a previously recorded session of cache rpc requests to a target host"}; + std::string m_HostName; + std::string m_RecordingPath; + bool m_OnHost = false; + bool m_ShowMethodStats = false; + int m_ProcessCount; + int m_ThreadCount; + uint64_t m_Offset; + uint64_t m_Stride; + bool m_ForceAllowLocalRefs; + bool m_DisableLocalRefs; + bool m_ForceAllowLocalHandleRef; + bool m_DisableLocalHandleRefs; + bool m_ForceAllowPartialLocalRefs; + bool m_DisablePartialLocalRefs; +}; + +} // namespace zen diff --git a/src/zen/cmds/scrub.cpp b/src/zen/cmds/scrub.cpp new file mode 100644 index 000000000..27ff5e0ac --- /dev/null +++ b/src/zen/cmds/scrub.cpp @@ -0,0 +1,154 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "scrub.h" +#include <zencore/logging.h> +#include <zenhttp/httpcommon.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +using namespace std::literals; + +namespace zen { + +ScrubCommand::ScrubCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); +} + +ScrubCommand::~ScrubCommand() = default; + +int +ScrubCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +GcCommand::GcCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); + m_Options.add_option("", + "s", + "smallobjects", + "Collect small objects", + cxxopts::value(m_SmallObjects)->default_value("false"), + "<smallobjects>"); + m_Options.add_option("", + "m", + "maxcacheduration", + "Max cache lifetime (in seconds)", + cxxopts::value(m_MaxCacheDuration)->default_value("0"), + "<maxcacheduration>"); + m_Options.add_option("", + "d", + "disksizesoftlimit", + "Max disk usage size (in bytes)", + cxxopts::value(m_DiskSizeSoftLimit)->default_value("0"), + "<disksizesoftlimit>"); +} + +GcCommand::~GcCommand() +{ +} + +int +GcCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Parameters Params; + if (m_SmallObjects) + { + Params.Add({"smallobjects", "true"}); + } + if (m_MaxCacheDuration != 0) + { + Params.Add({"maxcacheduration", fmt::format("{}", m_MaxCacheDuration)}); + } + if (m_DiskSizeSoftLimit != 0) + { + Params.Add({"disksizesoftlimit", fmt::format("{}", m_DiskSizeSoftLimit)}); + } + + cpr::Session Session; + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + Session.SetUrl({fmt::format("{}/admin/gc", m_HostName)}); + Session.SetParameters(Params); + + cpr::Response Result = Session.Post(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("OK: {}", Result.text); + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("GC start failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("GC start failed: {}", Result.error.message); + } + + return 1; +} + +GcStatusCommand::GcStatusCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>"); +} + +GcStatusCommand::~GcStatusCommand() +{ +} + +int +GcStatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + + cpr::Session Session; + Session.SetHeader(cpr::Header{{"Accept", "application/json"}}); + Session.SetUrl({fmt::format("{}/admin/gc", m_HostName)}); + + cpr::Response Result = Session.Get(); + + if (zen::IsHttpSuccessCode(Result.status_code)) + { + ZEN_CONSOLE("OK: {}", Result.text); + return 0; + } + + if (Result.status_code) + { + ZEN_ERROR("GC status failed: {}: {} ({})", Result.status_code, Result.reason, Result.text); + } + else + { + ZEN_ERROR("GC status failed: {}", Result.error.message); + } + + return 1; +} + +} // namespace zen diff --git a/src/zen/cmds/scrub.h b/src/zen/cmds/scrub.h new file mode 100644 index 000000000..ee8b4fdbb --- /dev/null +++ b/src/zen/cmds/scrub.h @@ -0,0 +1,58 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +/** Scrub storage + */ +class ScrubCommand : public ZenCmdBase +{ +public: + ScrubCommand(); + ~ScrubCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"scrub", "Scrub zen storage"}; + std::string m_HostName; +}; + +/** Garbage collect storage + */ +class GcCommand : public ZenCmdBase +{ +public: + GcCommand(); + ~GcCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"gc", "Garbage collect zen storage"}; + std::string m_HostName; + bool m_SmallObjects{false}; + uint64_t m_MaxCacheDuration{0}; + uint64_t m_DiskSizeSoftLimit{0}; +}; + +class GcStatusCommand : public ZenCmdBase +{ +public: + GcStatusCommand(); + ~GcStatusCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"gc-status", "Garbage collect zen storage status check"}; + std::string m_HostName; +}; + +} // namespace zen diff --git a/src/zen/cmds/status.cpp b/src/zen/cmds/status.cpp new file mode 100644 index 000000000..23c27f9f9 --- /dev/null +++ b/src/zen/cmds/status.cpp @@ -0,0 +1,41 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "status.h" + +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zencore/uid.h> +#include <zenutil/zenserverprocess.h> + +namespace zen { + +StatusCommand::StatusCommand() +{ +} + +StatusCommand::~StatusCommand() = default; + +int +StatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + ZenServerState State; + if (!State.InitializeReadOnly()) + { + ZEN_CONSOLE("no Zen state found"); + + return 0; + } + + ZEN_CONSOLE("{:>5} {:>6} {:>24}", "port", "pid", "session"); + State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { + StringBuilder<25> SessionStringBuilder; + Entry.GetSessionId().ToString(SessionStringBuilder); + ZEN_CONSOLE("{:>5} {:>6} {:>24}", Entry.EffectiveListenPort, Entry.Pid, SessionStringBuilder.ToString()); + }); + + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/status.h b/src/zen/cmds/status.h new file mode 100644 index 000000000..98f72e651 --- /dev/null +++ b/src/zen/cmds/status.h @@ -0,0 +1,22 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +class StatusCommand : public ZenCmdBase +{ +public: + StatusCommand(); + ~StatusCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"status", "Show zen status"}; +}; + +} // namespace zen diff --git a/src/zen/cmds/top.cpp b/src/zen/cmds/top.cpp new file mode 100644 index 000000000..4fe8c9cdf --- /dev/null +++ b/src/zen/cmds/top.cpp @@ -0,0 +1,89 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "top.h" + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/uid.h> +#include <zenutil/zenserverprocess.h> + +#include <memory> + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +TopCommand::TopCommand() +{ +} + +TopCommand::~TopCommand() = default; + +int +TopCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + ZenServerState State; + if (!State.InitializeReadOnly()) + { + ZEN_CONSOLE("no Zen state found"); + + return 0; + } + + int n = 0; + const int HeaderPeriod = 20; + + for (;;) + { + if ((n++ % HeaderPeriod) == 0) + { + ZEN_CONSOLE("{:>5} {:>6} {:>24}", "port", "pid", "session"); + } + + State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) { + StringBuilder<25> SessionStringBuilder; + Entry.GetSessionId().ToString(SessionStringBuilder); + ZEN_CONSOLE("{:>5} {:>6} {:>24}", Entry.EffectiveListenPort, Entry.Pid, SessionStringBuilder.ToString()); + }); + + zen::Sleep(1000); + + if (!State.IsReadOnly()) + { + State.Sweep(); + } + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +PsCommand::PsCommand() +{ +} + +PsCommand::~PsCommand() = default; + +int +PsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + ZenServerState State; + if (!State.InitializeReadOnly()) + { + ZEN_CONSOLE("no Zen state found"); + + return 0; + } + + State.Snapshot( + [&](const ZenServerState::ZenServerEntry& Entry) { ZEN_CONSOLE("Port {} : pid {}", Entry.EffectiveListenPort, Entry.Pid); }); + + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/top.h b/src/zen/cmds/top.h new file mode 100644 index 000000000..83410587b --- /dev/null +++ b/src/zen/cmds/top.h @@ -0,0 +1,35 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +class TopCommand : public ZenCmdBase +{ +public: + TopCommand(); + ~TopCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"top", "Show dev UI"}; +}; + +class PsCommand : public ZenCmdBase +{ +public: + PsCommand(); + ~PsCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"ps", "Enumerate running Zen server instances"}; +}; + +} // namespace zen diff --git a/src/zen/cmds/up.cpp b/src/zen/cmds/up.cpp new file mode 100644 index 000000000..69bcbe829 --- /dev/null +++ b/src/zen/cmds/up.cpp @@ -0,0 +1,108 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "up.h" + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zenutil/zenserverprocess.h> + +#include <memory> + +namespace zen { + +UpCommand::UpCommand() +{ +} + +UpCommand::~UpCommand() = default; + +int +UpCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions, argc, argv); + + std::filesystem::path ExePath = zen::GetRunningExecutablePath(); + + ZenServerEnvironment ServerEnvironment; + ServerEnvironment.Initialize(ExePath.parent_path()); + ZenServerInstance Server(ServerEnvironment); + Server.SpawnServer(); + + int Timeout = 10000; + + if (!Server.WaitUntilReady(Timeout)) + { + ZEN_ERROR("zen server launch failed (timed out)"); + } + else + { + ZEN_CONSOLE("zen server up"); + } + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +DownCommand::DownCommand() +{ + m_Options.add_option("", "p", "port", "Host port", cxxopts::value(m_Port)->default_value("1337"), "<hostport>"); +} + +DownCommand::~DownCommand() = default; + +int +DownCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + + if (!ParseOptions(argc, argv)) + { + return 0; + } + // Discover executing instances + + ZenServerState Instance; + Instance.Initialize(); + ZenServerState::ZenServerEntry* Entry = Instance.Lookup(m_Port); + + if (!Entry) + { + ZEN_WARN("no zen server to bring down"); + + return 0; + } + + try + { + std::filesystem::path ExePath = zen::GetRunningExecutablePath(); + + ZenServerEnvironment ServerEnvironment; + ServerEnvironment.Initialize(ExePath.parent_path()); + ZenServerInstance Server(ServerEnvironment); + Server.AttachToRunningServer(m_Port); + + ZEN_CONSOLE("attached to server on port {}, requesting shutdown", m_Port); + + Server.Shutdown(); + + ZEN_CONSOLE("shutdown complete"); + + return 0; + } + catch (std::exception& Ex) + { + ZEN_DEBUG("Exception caught when requesting shutdown: {}", Ex.what()); + } + + // Since we cannot obtain a handle to the process we are unable to block on the process + // handle to determine when the server has shut down. Thus we signal that we would like + // a shutdown via the shutdown flag and then + + ZEN_CONSOLE("requesting shutdown of server on port {}", m_Port); + Entry->SignalShutdownRequest(); + + return 0; +} + +} // namespace zen diff --git a/src/zen/cmds/up.h b/src/zen/cmds/up.h new file mode 100644 index 000000000..5af05541a --- /dev/null +++ b/src/zen/cmds/up.h @@ -0,0 +1,36 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +class UpCommand : public ZenCmdBase +{ +public: + UpCommand(); + ~UpCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"up", "Bring up zen service"}; +}; + +class DownCommand : public ZenCmdBase +{ +public: + DownCommand(); + ~DownCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"down", "Bring down zen service"}; + uint16_t m_Port; +}; + +} // namespace zen diff --git a/src/zen/cmds/version.cpp b/src/zen/cmds/version.cpp new file mode 100644 index 000000000..ba83b527d --- /dev/null +++ b/src/zen/cmds/version.cpp @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "version.h" + +#include <zencore/config.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zenhttp/httpcommon.h> +#include <zenutil/zenserverprocess.h> + +#include <memory> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +VersionCommand::VersionCommand() +{ + m_Options.add_options()("h,help", "Print help"); + m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName), "[hosturl]"); + m_Options.add_option("", "d", "detailed", "Detailed Version", cxxopts::value(m_DetailedVersion), "[detailedversion]"); + m_Options.parse_positional({"hosturl"}); +} + +VersionCommand::~VersionCommand() = default; + +int +VersionCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) +{ + ZEN_UNUSED(GlobalOptions); + if (!ParseOptions(argc, argv)) + { + return 0; + } + + std::string Version; + + if (m_HostName.empty()) + { + if (m_DetailedVersion) + { + Version = ZEN_CFG_VERSION_BUILD_STRING_FULL; + } + else + { + Version = ZEN_CFG_VERSION; + } + } + else + { + const std::string UrlBase = fmt::format("{}/health", m_HostName); + cpr::Session Session; + std::string VersionRequest = fmt::format("{}/version{}", UrlBase, m_DetailedVersion ? "?detailed=true" : ""); + Session.SetUrl(VersionRequest); + cpr::Response Response = Session.Get(); + if (!zen::IsHttpSuccessCode(Response.status_code)) + { + if (Response.status_code) + { + ZEN_ERROR("{} failed: {}: {} ({})", VersionRequest, Response.status_code, Response.reason, Response.text); + } + else + { + ZEN_ERROR("{} failed: {}", VersionRequest, Response.error.message); + } + + return 1; + } + Version = Response.text; + } + + zen::ConsoleLog().info("{}", Version); + + return 0; +} +} // namespace zen diff --git a/src/zen/cmds/version.h b/src/zen/cmds/version.h new file mode 100644 index 000000000..0e37e91a0 --- /dev/null +++ b/src/zen/cmds/version.h @@ -0,0 +1,24 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "../zen.h" + +namespace zen { + +class VersionCommand : public ZenCmdBase +{ +public: + VersionCommand(); + ~VersionCommand(); + + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override; + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"version", "Get zen service version"}; + std::string m_HostName; + bool m_DetailedVersion; +}; + +} // namespace zen diff --git a/src/zen/internalfile.cpp b/src/zen/internalfile.cpp new file mode 100644 index 000000000..2ade86e29 --- /dev/null +++ b/src/zen/internalfile.cpp @@ -0,0 +1,299 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "internalfile.h" + +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/memory.h> + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +# include <fcntl.h> +# include <sys/file.h> +# include <sys/mman.h> +# include <sys/stat.h> +#endif + +#include <gsl/gsl-lite.hpp> + +#define ZEN_USE_SLIST ZEN_PLATFORM_WINDOWS + +#if ZEN_USE_SLIST == 0 +struct FileBufferManager::Impl +{ + zen::RwLock m_Lock; + std::list<zen::IoBuffer> m_FreeBuffers; + + uint64_t m_BufferSize; + uint64_t m_MaxBufferCount; + + Impl(uint64_t BufferSize, uint64_t MaxBuffers) : m_BufferSize(BufferSize), m_MaxBufferCount(MaxBuffers) {} + + zen::IoBuffer AllocBuffer() + { + zen::RwLock::ExclusiveLockScope _(m_Lock); + + if (m_FreeBuffers.empty()) + { + return zen::IoBuffer{m_BufferSize, 64 * 1024}; + } + else + { + zen::IoBuffer Buffer = std::move(m_FreeBuffers.front()); + m_FreeBuffers.pop_front(); + return Buffer; + } + } + + void ReturnBuffer(zen::IoBuffer Buffer) + { + zen::RwLock::ExclusiveLockScope _(m_Lock); + + m_FreeBuffers.push_front(std::move(Buffer)); + } +}; +#else +struct FileBufferManager::Impl +{ + struct BufferItem + { + SLIST_ENTRY ItemEntry; + zen::IoBuffer Buffer; + }; + + SLIST_HEADER m_FreeList; + uint64_t m_BufferSize; + uint64_t m_MaxBufferCount; + + Impl(uint64_t BufferSize, uint64_t MaxBuffers) : m_BufferSize(BufferSize), m_MaxBufferCount(MaxBuffers) + { + InitializeSListHead(&m_FreeList); + } + + ~Impl() + { + while (SLIST_ENTRY* Entry = InterlockedPopEntrySList(&m_FreeList)) + { + BufferItem* Item = reinterpret_cast<BufferItem*>(Entry); + delete Item; + } + } + + zen::IoBuffer AllocBuffer() + { + SLIST_ENTRY* Entry = InterlockedPopEntrySList(&m_FreeList); + + if (Entry == nullptr) + { + return zen::IoBuffer{m_BufferSize, 64 * 1024}; + } + else + { + BufferItem* Item = reinterpret_cast<BufferItem*>(Entry); + zen::IoBuffer Buffer = std::move(Item->Buffer); + delete Item; // Todo: could keep this around in another list + + return Buffer; + } + } + + void ReturnBuffer(zen::IoBuffer Buffer) + { + BufferItem* Item = new BufferItem{nullptr, std::move(Buffer)}; + + InterlockedPushEntrySList(&m_FreeList, &Item->ItemEntry); + } +}; +#endif + +FileBufferManager::FileBufferManager(uint64_t BufferSize, uint64_t MaxBuffers) +{ + m_Impl = new Impl{BufferSize, MaxBuffers}; +} + +FileBufferManager::~FileBufferManager() +{ + delete m_Impl; +} + +zen::IoBuffer +FileBufferManager::AllocBuffer() +{ + return m_Impl->AllocBuffer(); +} + +void +FileBufferManager::ReturnBuffer(zen::IoBuffer Buffer) +{ + return m_Impl->ReturnBuffer(Buffer); +} + +////////////////////////////////////////////////////////////////////////// + +InternalFile::InternalFile() +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +: m_File(nullptr) +, m_Mmap(nullptr) +#endif +{ +} + +InternalFile::~InternalFile() +{ + if (m_Memory) + zen::Memory::Free(m_Memory); + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + if (m_Mmap) + munmap(m_Mmap, GetFileSize()); + if (m_File) + close(int(intptr_t(m_File))); +#endif +} + +size_t +InternalFile::GetFileSize() +{ +#if ZEN_PLATFORM_WINDOWS + ULONGLONG sz; + m_File.GetSize(sz); + return size_t(sz); +#else + int Fd = int(intptr_t(m_File)); + static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); + struct stat Stat; + fstat(Fd, &Stat); + return size_t(Stat.st_size); +#endif +} + +void +InternalFile::OpenWrite(std::filesystem::path FileName, bool IsCreate) +{ + bool Success = false; + +#if ZEN_PLATFORM_WINDOWS + const DWORD dwCreationDisposition = IsCreate ? CREATE_ALWAYS : OPEN_EXISTING; + + HRESULT hRes = m_File.Create(FileName.c_str(), GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, dwCreationDisposition); + Success = SUCCEEDED(hRes); +#else + int OpenFlags = O_RDWR | O_CLOEXEC; + OpenFlags |= IsCreate ? O_CREAT | O_TRUNC : 0; + + int Fd = open(FileName.c_str(), OpenFlags, 0666); + if (Fd >= 0) + { + if (IsCreate) + { + fchmod(Fd, 0666); + } + Success = true; + m_File = (void*)(intptr_t(Fd)); + } +#endif // ZEN_PLATFORM_WINDOWS + + if (Success) + { + zen::ThrowLastError(fmt::format("Failed to open file for writing: '{}'", FileName)); + } +} + +void +InternalFile::OpenRead(std::filesystem::path FileName) +{ + bool Success = false; + +#if ZEN_PLATFORM_WINDOWS + const DWORD dwCreationDisposition = OPEN_EXISTING; + + HRESULT hRes = m_File.Create(FileName.c_str(), GENERIC_READ, FILE_SHARE_READ, dwCreationDisposition); + Success = SUCCEEDED(hRes); +#else + int Fd = open(FileName.c_str(), O_RDONLY); + if (Fd >= 0) + { + Success = true; + m_File = (void*)(intptr_t(Fd)); + } +#endif + + if (Success) + { + zen::ThrowLastError(fmt::format("Failed to open file for reading: '{}'", FileName)); + } +} + +const void* +InternalFile::MemoryMapFile() +{ + auto FileSize = GetFileSize(); + + if (FileSize <= 100 * 1024 * 1024) + { + m_Memory = zen::Memory::Alloc(FileSize, 64); + Read(m_Memory, FileSize, 0); + + return m_Memory; + } + +#if ZEN_PLATFORM_WINDOWS + m_Mmap.MapFile(m_File); + return m_Mmap.GetData(); +#else + int Fd = int(intptr_t(m_File)); + m_Mmap = mmap(nullptr, FileSize, PROT_READ, MAP_PRIVATE, Fd, 0); + return m_Mmap; +#endif +} + +void +InternalFile::Read(void* Data, uint64_t Size, uint64_t Offset) +{ + bool Success; + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED ovl{}; + + ovl.Offset = DWORD(Offset & 0xffff'ffffu); + ovl.OffsetHigh = DWORD(Offset >> 32); + + HRESULT hRes = m_File.Read(Data, gsl::narrow<DWORD>(Size), &ovl); + Success = SUCCEEDED(hRes); +#else + int Fd = int(intptr_t(m_File)); + int BytesRead = pread(Fd, Data, Size, Offset); + Success = (BytesRead > 0); +#endif + + if (Success) + { + zen::ThrowLastError(fmt::format("Failed to read from file '{}'", "")); // zen::PathFromHandle(m_File))); + } +} + +void +InternalFile::Write(const void* Data, uint64_t Size, uint64_t Offset) +{ + bool Success; + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(Offset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(Offset >> 32); + + HRESULT hRes = m_File.Write(Data, gsl::narrow<DWORD>(Size), &Ovl); + Success = SUCCEEDED(hRes); +#else + int Fd = int(intptr_t(m_File)); + int BytesWritten = pwrite(Fd, Data, Size, Offset); + Success = (BytesWritten > 0); +#endif + + if (Success) + { + zen::ThrowLastError(fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_File))); + } +} diff --git a/src/zen/internalfile.h b/src/zen/internalfile.h new file mode 100644 index 000000000..8acb600ff --- /dev/null +++ b/src/zen/internalfile.h @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/iobuffer.h> +#include <zencore/refcount.h> +#include <zencore/thread.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +#if ZEN_PLATFORM_WINDOWS +# include <atlfile.h> +#endif + +#include <filesystem> +#include <list> + +////////////////////////////////////////////////////////////////////////// + +class FileBufferManager : public zen::RefCounted +{ +public: + FileBufferManager(uint64_t BufferSize, uint64_t MaxBufferCount); + ~FileBufferManager(); + + zen::IoBuffer AllocBuffer(); + void ReturnBuffer(zen::IoBuffer Buffer); + +private: + struct Impl; + + Impl* m_Impl; +}; + +class InternalFile : public zen::RefCounted +{ +public: + InternalFile(); + ~InternalFile(); + + void OpenRead(std::filesystem::path FileName); + void Read(void* Data, uint64_t Size, uint64_t Offset); + + void OpenWrite(std::filesystem::path FileName, bool isCreate); + void Write(const void* Data, uint64_t Size, uint64_t Offset); + + const void* MemoryMapFile(); + size_t GetFileSize(); + +private: +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + using CAtlFile = void*; + using CAtlFileMappingBase = void*; +#endif + CAtlFile m_File; + CAtlFileMappingBase m_Mmap; + void* m_Memory = nullptr; +}; diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua new file mode 100644 index 000000000..b83999efc --- /dev/null +++ b/src/zen/xmake.lua @@ -0,0 +1,31 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zen") + set_kind("binary") + add_headerfiles("**.h") + add_files("**.cpp") + add_files("zen.cpp", {unity_ignored = true }) + add_deps("zencore", "zenhttp", "zenutil") + add_includedirs(".") + set_symbols("debug") + + if is_mode("release") then + set_optimize("fastest") + end + + if is_plat("windows") then + add_files("zen.rc") + add_ldflags("/subsystem:console,5.02") + add_ldflags("/LTCG") + add_ldflags("crypt32.lib", "wldap32.lib", "Ws2_32.lib") + end + + if is_plat("macosx") then + add_ldflags("-framework CoreFoundation") + add_ldflags("-framework Security") + add_ldflags("-framework SystemConfiguration") + add_syslinks("bsm") + end + + add_packages("vcpkg::zstd") + add_packages("vcpkg::cxxopts", "vcpkg::mimalloc") diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp new file mode 100644 index 000000000..9754f4434 --- /dev/null +++ b/src/zen/zen.cpp @@ -0,0 +1,421 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +// Zen command line client utility +// + +#include "zen.h" + +#include "chunk/chunk.h" +#include "cmds/cache.h" +#include "cmds/copy.h" +#include "cmds/dedup.h" +#include "cmds/hash.h" +#include "cmds/print.h" +#include "cmds/projectstore.h" +#include "cmds/rpcreplay.h" +#include "cmds/scrub.h" +#include "cmds/status.h" +#include "cmds/top.h" +#include "cmds/up.h" +#include "cmds/version.h" + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> + +#include <zenhttp/httpcommon.h> + +#if ZEN_WITH_TESTS +# define ZEN_TEST_WITH_RUNNER 1 +# include <zencore/testing.h> +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_USE_MIMALLOC +# include <mimalloc-new-delete.h> +#endif + +////////////////////////////////////////////////////////////////////////// + +bool +ZenCmdBase::ParseOptions(int argc, char** argv) +{ + cxxopts::Options& CmdOptions = Options(); + cxxopts::ParseResult Result = CmdOptions.parse(argc, argv); + + if (Result.count("help")) + { + printf("%s\n", CmdOptions.help().c_str()); + return false; + } + + if (!Result.unmatched().empty()) + { + zen::ExtendableStringBuilder<64> StringBuilder; + for (bool First = true; const auto& Param : Result.unmatched()) + { + if (!First) + { + StringBuilder.Append(", "); + } + StringBuilder.Append('"'); + StringBuilder.Append(Param); + StringBuilder.Append('"'); + First = false; + } + + throw cxxopts::OptionParseException(fmt::format("Invalid arguments: {}", StringBuilder.ToView())); + } + + return true; +} + +std::string +ZenCmdBase::FormatHttpResponse(const cpr::Response& Response) +{ + if (Response.error.code != cpr::ErrorCode::OK) + { + if (Response.error.message.empty()) + { + return fmt::format("Request '{}' failed, error code {}", Response.url.str(), static_cast<int>(Response.error.code)); + } + return fmt::format("Request '{}' failed. Reason: '{}' ({})", + Response.url.str(), + Response.error.message, + static_cast<int>(Response.error.code)); + } + + std::string Content; + if (auto It = Response.header.find("Content-Type"); It != Response.header.end()) + { + zen::HttpContentType ContentType = zen::ParseContentType(It->second); + if (ContentType == zen::HttpContentType::kText) + { + Content = fmt::format("'{}'", Response.text); + } + else if (ContentType == zen::HttpContentType::kJSON) + { + Content = fmt::format("\n{}", Response.text); + } + else if (!Response.text.empty()) + { + Content = fmt::format("[{}]", MapContentTypeToString(ContentType)); + } + } + + std::string_view ResponseString = zen::ReasonStringForHttpResultCode( + Response.status_code == static_cast<long>(zen::HttpResponseCode::NoContent) ? static_cast<long>(zen::HttpResponseCode::OK) + : Response.status_code); + if (Content.empty()) + { + return std::string(ResponseString); + } + + return fmt::format("{}: {}", ResponseString, Content); +} + +int +ZenCmdBase::MapHttpToCommandReturnCode(const cpr::Response& Response) +{ + if (zen::IsHttpSuccessCode(Response.status_code)) + { + return 0; + } + if (Response.error.code != cpr::ErrorCode::OK) + { + return static_cast<int>(Response.error.code); + } + return 1; +} + +#if ZEN_WITH_TESTS + +class RunTestsCommand : public ZenCmdBase +{ +public: + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override + { + ZEN_UNUSED(GlobalOptions); + + // Set output mode to handle virtual terminal sequences +# if ZEN_PLATFORM_WINDOWS + HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); + if (hOut == INVALID_HANDLE_VALUE) + return GetLastError(); + + DWORD dwMode = 0; + if (!GetConsoleMode(hOut, &dwMode)) + return GetLastError(); + + dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if (!SetConsoleMode(hOut, dwMode)) + return GetLastError(); +# endif // ZEN_PLATFORM_WINDOWS + + return ZEN_RUN_TESTS(argc, argv); + } + + virtual cxxopts::Options& Options() override { return m_Options; } + +private: + cxxopts::Options m_Options{"runtests", "Run tests"}; +}; + +#endif + +////////////////////////////////////////////////////////////////////////// +// TODO: should make this Unicode-aware so we can pass anything in on the +// command line. + +int +main(int argc, char** argv) +{ + using namespace zen; + +#if ZEN_USE_MIMALLOC + mi_version(); +#endif + + zen::logging::InitializeLogging(); + zen::MaximizeOpenFileCount(); + + ////////////////////////////////////////////////////////////////////////// + + auto _ = zen::MakeGuard([] { spdlog::shutdown(); }); + + CacheInfoCommand CacheInfoCmd; + CopyCommand CopyCmd; + CreateOplogCommand CreateOplogCmd; + CreateProjectCommand CreateProjectCmd; + DedupCommand DedupCmd; + DownCommand DownCmd; + DropCommand DropCmd; + DropProjectCommand ProjectDropCmd; + ExportOplogCommand ExportOplogCmd; + GcCommand GcCmd; + GcStatusCommand GcStatusCmd; + HashCommand HashCmd; + ImportOplogCommand ImportOplogCmd; + PrintCommand PrintCmd; + PrintPackageCommand PrintPkgCmd; + ProjectInfoCommand ProjectInfoCmd; + PsCommand PsCmd; + RpcReplayCommand RpcReplayCmd; + RpcStartRecordingCommand RpcStartRecordingCmd; + RpcStopRecordingCommand RpcStopRecordingCmd; + StatusCommand StatusCmd; + TopCommand TopCmd; + UpCommand UpCmd; + VersionCommand VersionCmd; + CacheStatsCommand CacheStatsCmd; + CacheDetailsCommand CacheDetailsCmd; + ProjectStatsCommand ProjectStatsCmd; + ProjectDetailsCommand ProjectDetailsCmd; +#if ZEN_WITH_TESTS + RunTestsCommand RunTestsCmd; +#endif + + const struct CommandInfo + { + const char* CmdName; + ZenCmdBase* Cmd; + const char* CmdSummary; + } Commands[] = { + // clang-format off +// {"chunk", &ChunkCmd, "Perform chunking"}, + {"cache-info", &CacheInfoCmd, "Info on cache, namespace or bucket"}, + {"copy", &CopyCmd, "Copy file(s)"}, + {"dedup", &DedupCmd, "Dedup files"}, + {"down", &DownCmd, "Bring zen server down"}, + {"drop", &DropCmd, "Drop cache namespace or bucket"}, + {"gc-status", &GcStatusCmd, "Garbage collect zen storage status check"}, + {"gc", &GcCmd, "Garbage collect zen storage"}, + {"hash", &HashCmd, "Compute file hashes"}, + {"oplog-create", &CreateOplogCmd, "Create a project oplog"}, + {"oplog-export", &ExportOplogCmd, "Export project store oplog"}, + {"oplog-import", &ImportOplogCmd, "Import project store oplog"}, + {"print", &PrintCmd, "Print compact binary object"}, + {"printpackage", &PrintPkgCmd, "Print compact binary package"}, + {"project-create", &CreateProjectCmd, "Create a project"}, + {"project-drop", &ProjectDropCmd, "Drop project or project oplog"}, + {"project-info", &ProjectInfoCmd, "Info on project or project oplog"}, + {"ps", &PsCmd, "Enumerate running zen server instances"}, + {"rpc-record-replay", &RpcReplayCmd, "Stops recording of cache rpc requests on a host"}, + {"rpc-record-start", &RpcStartRecordingCmd, "Replays a previously recorded session of rpc requests"}, + {"rpc-record-stop", &RpcStopRecordingCmd, "Starts recording of cache rpc requests on a host"}, + {"status", &StatusCmd, "Show zen status"}, + {"top", &TopCmd, "Monitor zen server activity"}, + {"up", &UpCmd, "Bring zen server up"}, + {"version", &VersionCmd, "Get zen server version"}, + {"cache-stats", &CacheStatsCmd, "Stats on cache"}, + {"cache-details", &CacheDetailsCmd, "Details on cache"}, + {"project-stats", &ProjectStatsCmd, "Stats on project store"}, + {"project-details", &ProjectDetailsCmd, "Details on project store"}, +#if ZEN_WITH_TESTS + {"runtests", &RunTestsCmd, "Run zen tests"}, +#endif + // clang-format on + }; + + // Build set containing available commands + + std::unordered_set<std::string> CommandSet; + + for (const auto& Cmd : Commands) + CommandSet.insert(Cmd.CmdName); + + // Split command line into options, commands and any pass-through arguments + + std::string Passthrough; + std::vector<std::string> PassthroughV; + + for (int i = 1; i < argc; ++i) + { + if (strcmp(argv[i], "--") == 0) + { + bool IsFirst = true; + zen::ExtendableStringBuilder<256> Line; + + for (int j = i + 1; j < argc; ++j) + { + if (!IsFirst) + { + Line.AppendAscii(" "); + } + + std::string_view ThisArg(argv[j]); + PassthroughV.push_back(std::string(ThisArg)); + + const bool NeedsQuotes = (ThisArg.find(' ') != std::string_view::npos); + + if (NeedsQuotes) + { + Line.AppendAscii("\""); + } + + Line.Append(ThisArg); + + if (NeedsQuotes) + { + Line.AppendAscii("\""); + } + + IsFirst = false; + } + + Passthrough = Line.c_str(); + + // This will "truncate" the arg vector and terminate the loop + argc = i - 1; + } + } + + // Split command line into global vs command options. We do this by simply + // scanning argv for a string we recognise as a command and split it there + + std::vector<char*> CommandArgVec; + CommandArgVec.push_back(argv[0]); + + for (int i = 1; i < argc; ++i) + { + if (CommandSet.find(argv[i]) != CommandSet.end()) + { + int commandArgCount = /* exec name */ 1 + argc - (i + 1); + CommandArgVec.resize(commandArgCount); + std::copy(argv + i + 1, argv + argc, CommandArgVec.begin() + 1); + + argc = i + 1; + + break; + } + } + + // Parse global CLI arguments + + ZenCliOptions GlobalOptions; + + GlobalOptions.PassthroughArgs = Passthrough; + GlobalOptions.PassthroughV = PassthroughV; + + std::string SubCommand = "<None>"; + + cxxopts::Options Options("zen", "Zen management tool"); + + Options.add_options()("d, debug", "Enable debugging", cxxopts::value<bool>(GlobalOptions.IsDebug)); + Options.add_options()("v, verbose", "Enable verbose logging", cxxopts::value<bool>(GlobalOptions.IsVerbose)); + Options.add_options()("help", "Show command line help"); + Options.add_options()("c, command", "Sub command", cxxopts::value<std::string>(SubCommand)); + + Options.parse_positional({"command"}); + + const bool IsNullInvoke = (argc == 1); // If no arguments are passed we want to print usage information + + try + { + auto ParseResult = Options.parse(argc, argv); + + if (ParseResult.count("help") || IsNullInvoke == 1) + { + std::string Help = Options.help(); + + printf("%s\n", Help.c_str()); + + printf("available commands:\n"); + + for (const auto& CmdInfo : Commands) + { + printf(" %-20s %s\n", CmdInfo.CmdName, CmdInfo.CmdSummary); + } + + exit(0); + } + + if (GlobalOptions.IsDebug) + { + spdlog::set_level(spdlog::level::debug); + } + + for (const CommandInfo& CmdInfo : Commands) + { + if (StrCaseCompare(SubCommand.c_str(), CmdInfo.CmdName) == 0) + { + cxxopts::Options& VerbOptions = CmdInfo.Cmd->Options(); + try + { + return CmdInfo.Cmd->Run(GlobalOptions, (int)CommandArgVec.size(), CommandArgVec.data()); + } + catch (cxxopts::OptionParseException& Ex) + { + std::string help = VerbOptions.help(); + + printf("Error parsing arguments for command '%s': %s\n\n%s", SubCommand.c_str(), Ex.what(), help.c_str()); + + exit(11); + } + } + } + + printf("Unknown command specified: '%s', exiting\n", SubCommand.c_str()); + } + catch (cxxopts::OptionParseException& Ex) + { + std::string HelpMessage = Options.help(); + + printf("Error parsing program arguments: %s\n\n%s", Ex.what(), HelpMessage.c_str()); + + return 9; + } + catch (std::exception& Ex) + { + printf("Exception caught from 'main': %s\n", Ex.what()); + + return 10; + } + + return 0; +} diff --git a/src/zen/zen.h b/src/zen/zen.h new file mode 100644 index 000000000..b55e7a16c --- /dev/null +++ b/src/zen/zen.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cxxopts.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +namespace cpr { +class Response; +} + +struct ZenCliOptions +{ + bool IsDebug = false; + bool IsVerbose = false; + + // Arguments after " -- " on command line are passed through and not parsed + std::string PassthroughArgs; + std::vector<std::string> PassthroughV; +}; + +class ZenCmdBase +{ +public: + virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) = 0; + virtual cxxopts::Options& Options() = 0; + + bool ParseOptions(int argc, char** argv); + static std::string FormatHttpResponse(const cpr::Response& Response); + static int MapHttpToCommandReturnCode(const cpr::Response& Response); +}; diff --git a/src/zen/zen.rc b/src/zen/zen.rc new file mode 100644 index 000000000..14a9afb70 --- /dev/null +++ b/src/zen/zen.rc @@ -0,0 +1,33 @@ +#include "zencore/config.h" + +#define APSTUDIO_READONLY_SYMBOLS +#include "winres.h" +#undef APSTUDIO_READONLY_SYMBOLS + +LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US +#pragma code_page(1252) + +101 ICON "..\\UnrealEngine.ico" + +VS_VERSION_INFO VERSIONINFO +FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 +PRODUCTVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 +{ + BLOCK "StringFileInfo" + { + BLOCK "040904b0" + { + VALUE "CompanyName", "Epic Games Inc\0" + VALUE "FileDescription", "CLI utility for Zen Storage Service\0" + VALUE "FileVersion", ZEN_CFG_VERSION "\0" + VALUE "LegalCopyright", "Copyright Epic Games Inc. All Rights Reserved\0" + VALUE "OriginalFilename", "zen.exe\0" + VALUE "ProductName", "Zen Storage Server\0" + VALUE "ProductVersion", ZEN_CFG_VERSION_BUILD_STRING_FULL "\0" + } + } + BLOCK "VarFileInfo" + { + VALUE "Translation", 0x409, 1200 + } +} diff --git a/src/zencore-test/targetver.h b/src/zencore-test/targetver.h new file mode 100644 index 000000000..d432d6993 --- /dev/null +++ b/src/zencore-test/targetver.h @@ -0,0 +1,10 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +// Including SDKDDKVer.h defines the highest available Windows platform. + +// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and +// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. + +#include <SDKDDKVer.h> diff --git a/src/zencore-test/xmake.lua b/src/zencore-test/xmake.lua new file mode 100644 index 000000000..74c7e74a7 --- /dev/null +++ b/src/zencore-test/xmake.lua @@ -0,0 +1,8 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zencore-test") + set_kind("binary") + add_headerfiles("**.h") + add_files("*.cpp") + add_deps("zencore") + add_packages("vcpkg::doctest") diff --git a/src/zencore-test/zencore-test.cpp b/src/zencore-test/zencore-test.cpp new file mode 100644 index 000000000..53413fb25 --- /dev/null +++ b/src/zencore-test/zencore-test.cpp @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +// zencore-test.cpp : Defines the entry point for the console application. +// + +#include <zencore/logging.h> +#include <zencore/zencore.h> + +#if ZEN_WITH_TESTS +# define ZEN_TEST_WITH_RUNNER 1 +# include <zencore/testing.h> +#endif + +int +main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) +{ +#if ZEN_WITH_TESTS + zen::zencore_forcelinktests(); + + zen::logging::InitializeLogging(); + + return ZEN_RUN_TESTS(argc, argv); +#else + return 0; +#endif +} diff --git a/src/zencore/.gitignore b/src/zencore/.gitignore new file mode 100644 index 000000000..77d39c17e --- /dev/null +++ b/src/zencore/.gitignore @@ -0,0 +1 @@ +include/zencore/config.h diff --git a/src/zencore/base64.cpp b/src/zencore/base64.cpp new file mode 100644 index 000000000..b97dfebbf --- /dev/null +++ b/src/zencore/base64.cpp @@ -0,0 +1,107 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/base64.h> + +namespace zen { + +/** The table used to encode a 6 bit value as an ascii character */ +static const uint8_t EncodingAlphabet[64] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; + +/** The table used to convert an ascii character into a 6 bit value */ +#if 0 +static const uint8_t DecodingAlphabet[256] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00-0x0f + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10-0x1f + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F, // 0x20-0x2f + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x30-0x3f + 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, // 0x40-0x4f + 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x50-0x5f + 0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, // 0x60-0x6f + 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x70-0x7f + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x80-0x8f + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x90-0x9f + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xa0-0xaf + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xb0-0xbf + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xc0-0xcf + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xd0-0xdf + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xe0-0xef + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // 0xf0-0xff +}; +#endif // 0 + +template<typename CharType> +uint32_t +Base64::Encode(const uint8_t* Source, uint32_t Length, CharType* Dest) +{ + CharType* EncodedBytes = Dest; + + // Loop through the buffer converting 3 bytes of binary data at a time + while (Length >= 3) + { + uint8_t A = *Source++; + uint8_t B = *Source++; + uint8_t C = *Source++; + Length -= 3; + + // The algorithm takes 24 bits of data (3 bytes) and breaks it into 4 6bit chunks represented as ascii + uint32_t ByteTriplet = A << 16 | B << 8 | C; + + // Use the 6bit block to find the representation ascii character for it + EncodedBytes[3] = EncodingAlphabet[ByteTriplet & 0x3F]; + ByteTriplet >>= 6; + EncodedBytes[2] = EncodingAlphabet[ByteTriplet & 0x3F]; + ByteTriplet >>= 6; + EncodedBytes[1] = EncodingAlphabet[ByteTriplet & 0x3F]; + ByteTriplet >>= 6; + EncodedBytes[0] = EncodingAlphabet[ByteTriplet & 0x3F]; + + // Now we can append this buffer to our destination string + EncodedBytes += 4; + } + + // Since this algorithm operates on blocks, we may need to pad the last chunks + if (Length > 0) + { + uint8_t A = *Source++; + uint8_t B = 0; + uint8_t C = 0; + // Grab the second character if it is a 2 uint8_t finish + if (Length == 2) + { + B = *Source; + } + uint32_t ByteTriplet = A << 16 | B << 8 | C; + // Pad with = to make a 4 uint8_t chunk + EncodedBytes[3] = '='; + ByteTriplet >>= 6; + // If there's only one 1 uint8_t left in the source, then you need 2 pad chars + if (Length == 1) + { + EncodedBytes[2] = '='; + } + else + { + EncodedBytes[2] = EncodingAlphabet[ByteTriplet & 0x3F]; + } + // Now encode the remaining bits the same way + ByteTriplet >>= 6; + EncodedBytes[1] = EncodingAlphabet[ByteTriplet & 0x3F]; + ByteTriplet >>= 6; + EncodedBytes[0] = EncodingAlphabet[ByteTriplet & 0x3F]; + + EncodedBytes += 4; + } + + // Add a null terminator + *EncodedBytes = 0; + + return uint32_t(EncodedBytes - Dest); +} + +template ZENCORE_API uint32_t Base64::Encode<char>(const uint8_t* Source, uint32_t Length, char* Dest); +template ZENCORE_API uint32_t Base64::Encode<wchar_t>(const uint8_t* Source, uint32_t Length, wchar_t* Dest); + +} // namespace zen diff --git a/src/zencore/blake3.cpp b/src/zencore/blake3.cpp new file mode 100644 index 000000000..89826ae5d --- /dev/null +++ b/src/zencore/blake3.cpp @@ -0,0 +1,175 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/blake3.h> + +#include <zencore/compositebuffer.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/zencore.h> + +#include <string.h> + +#include "blake3.h" + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +void +blake3_forcelink() +{ +} + +BLAKE3 BLAKE3::Zero; // Initialized to all zeroes + +BLAKE3 +BLAKE3::HashMemory(const void* data, size_t byteCount) +{ + BLAKE3 b3; + + blake3_hasher b3h; + blake3_hasher_init(&b3h); + blake3_hasher_update(&b3h, data, byteCount); + blake3_hasher_finalize(&b3h, b3.Hash, sizeof b3.Hash); + + return b3; +} + +BLAKE3 +BLAKE3::HashBuffer(const CompositeBuffer& Buffer) +{ + BLAKE3 Hash; + + blake3_hasher Hasher; + blake3_hasher_init(&Hasher); + + for (const SharedBuffer& Segment : Buffer.GetSegments()) + { + blake3_hasher_update(&Hasher, Segment.GetData(), Segment.GetSize()); + } + + blake3_hasher_finalize(&Hasher, Hash.Hash, sizeof Hash.Hash); + + return Hash; +} + +BLAKE3 +BLAKE3::FromHexString(const char* string) +{ + BLAKE3 b3; + + ParseHexBytes(string, 2 * sizeof b3.Hash, b3.Hash); + + return b3; +} + +const char* +BLAKE3::ToHexString(char* outString /* 40 characters + NUL terminator */) const +{ + ToHexBytes(Hash, sizeof(BLAKE3), outString); + outString[2 * sizeof(BLAKE3)] = '\0'; + + return outString; +} + +StringBuilderBase& +BLAKE3::ToHexString(StringBuilderBase& outBuilder) const +{ + char str[65]; + ToHexString(str); + + outBuilder.AppendRange(str, &str[65]); + + return outBuilder; +} + +BLAKE3Stream::BLAKE3Stream() +{ + blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState); + static_assert(sizeof(blake3_hasher) <= sizeof(m_HashState)); + blake3_hasher_init(b3h); +} + +void +BLAKE3Stream::Reset() +{ + blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState); + blake3_hasher_init(b3h); +} + +BLAKE3Stream& +BLAKE3Stream::Append(const void* data, size_t byteCount) +{ + blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState); + blake3_hasher_update(b3h, data, byteCount); + + return *this; +} + +BLAKE3 +BLAKE3Stream::GetHash() +{ + BLAKE3 b3; + + blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState); + blake3_hasher_finalize(b3h, b3.Hash, sizeof b3.Hash); + + return b3; +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +// doctest::String +// toString(const BLAKE3& value) +// { +// char text[2 * sizeof(BLAKE3) + 1]; +// value.ToHexString(text); + +// return text; +// } + +TEST_CASE("BLAKE3") +{ + SUBCASE("Basics") + { + BLAKE3 b3 = BLAKE3::HashMemory(nullptr, 0); + CHECK(BLAKE3::FromHexString("af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262") == b3); + + BLAKE3::String_t b3s; + std::string b3ss = b3.ToHexString(b3s); + CHECK(b3ss == "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262"); + } + + SUBCASE("hashes") + { + CHECK(BLAKE3::FromHexString("00307ced6a8b278d5e3a9f77b138d0e9d2209717c9d45b205f427a73565cc5fb") == BLAKE3::HashMemory("abc123", 6)); + CHECK(BLAKE3::FromHexString("a7142c8c3905cd11b1e35105c7ac588b75d6798822f71e1145187ad46f3e8df4") == + BLAKE3::HashMemory("1234567890123456789012345678901234567890", 40)); + CHECK(BLAKE3::FromHexString("70e708532559265c4662d0285e5e0a4be8bd972bd1f255a93ddf342243adc427") == + BLAKE3::HashMemory("The HttpSendHttpResponse function sends an HTTP response to the specified HTTP request.", 87)); + } + + SUBCASE("streamHashes") + { + auto streamHash = [](const void* data, size_t dataBytes) -> BLAKE3 { + BLAKE3Stream b3s; + b3s.Append(data, dataBytes); + return b3s.GetHash(); + }; + + CHECK(BLAKE3::FromHexString("00307ced6a8b278d5e3a9f77b138d0e9d2209717c9d45b205f427a73565cc5fb") == streamHash("abc123", 6)); + CHECK(BLAKE3::FromHexString("a7142c8c3905cd11b1e35105c7ac588b75d6798822f71e1145187ad46f3e8df4") == + streamHash("1234567890123456789012345678901234567890", 40)); + CHECK(BLAKE3::FromHexString("70e708532559265c4662d0285e5e0a4be8bd972bd1f255a93ddf342243adc427") == + streamHash("The HttpSendHttpResponse function sends an HTTP response to the specified HTTP request.", 87)); + } +} + +#endif + +} // namespace zen diff --git a/src/zencore/compactbinary.cpp b/src/zencore/compactbinary.cpp new file mode 100644 index 000000000..0db9f02ea --- /dev/null +++ b/src/zencore/compactbinary.cpp @@ -0,0 +1,2299 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/compactbinary.h" + +#include <zencore/base64.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/compactbinaryvalue.h> +#include <zencore/compress.h> +#include <zencore/endian.h> +#include <zencore/fmtutils.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/uid.h> + +#include <fmt/format.h> +#include <string_view> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <time.h> +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +const int DaysToMonth[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; + +double +GetJulianDay(uint64_t Ticks) +{ + return (double)(1721425.5 + Ticks / TimeSpan::TicksPerDay); +} + +bool +IsLeapYear(int Year) +{ + if ((Year % 4) == 0) + { + return (((Year % 100) != 0) || ((Year % 400) == 0)); + } + + return false; +} + +static constexpr uint64_t +GetPlatformToDateTimeBiasInSeconds() +{ +#if ZEN_PLATFORM_WINDOWS + const uint64_t PlatformEpochYear = 1601; +#else + const uint64_t PlatformEpochYear = 1970; +#endif + const uint64_t DateTimeEpochYear = 1; + return uint64_t(double(PlatformEpochYear - DateTimeEpochYear) * 365.2425) * 86400; +} + +uint64_t +DateTime::NowTicks() +{ + static constexpr uint64_t EpochBias = GetPlatformToDateTimeBiasInSeconds(); + +#if ZEN_PLATFORM_WINDOWS + FILETIME SysTime; + GetSystemTimePreciseAsFileTime(&SysTime); + return (EpochBias * TimeSpan::TicksPerSecond) + ((uint64_t(SysTime.dwHighDateTime) << 32) | SysTime.dwLowDateTime); +#else + int64_t SecondsSinceUnixEpoch = time(nullptr); + return (EpochBias + SecondsSinceUnixEpoch) * TimeSpan::TicksPerSecond; +#endif +} + +DateTime +DateTime::Now() +{ + return DateTime{NowTicks()}; +} + +void +DateTime::Set(int Year, int Month, int Day, int Hour, int Minute, int Second, int MilliSecond) +{ + int TotalDays = 0; + + if ((Month > 2) && IsLeapYear(Year)) + { + ++TotalDays; + } + + --Year; // the current year is not a full year yet + --Month; // the current month is not a full month yet + + TotalDays += Year * 365; + TotalDays += Year / 4; // leap year day every four years... + TotalDays -= Year / 100; // ...except every 100 years... + TotalDays += Year / 400; // ...but also every 400 years + TotalDays += DaysToMonth[Month]; // days in this year up to last month + TotalDays += Day - 1; // days in this month minus today + + Ticks = TotalDays * TimeSpan::TicksPerDay + Hour * TimeSpan::TicksPerHour + Minute * TimeSpan::TicksPerMinute + + Second * TimeSpan::TicksPerSecond + MilliSecond * TimeSpan::TicksPerMillisecond; +} + +int +DateTime::GetYear() const +{ + int Year, Month, Day; + GetDate(Year, Month, Day); + + return Year; +} + +int +DateTime::GetMonth() const +{ + int Year, Month, Day; + GetDate(Year, Month, Day); + + return Month; +} + +int +DateTime::GetDay() const +{ + int Year, Month, Day; + GetDate(Year, Month, Day); + + return Day; +} + +int +DateTime::GetHour() const +{ + return (int)((Ticks / TimeSpan::TicksPerHour) % 24); +} + +int +DateTime::GetHour12() const +{ + int Hour = GetHour(); + + if (Hour < 1) + { + return 12; + } + + if (Hour > 12) + { + return (Hour - 12); + } + + return Hour; +} + +int +DateTime::GetMinute() const +{ + return (int)((Ticks / TimeSpan::TicksPerMinute) % 60); +} + +int +DateTime::GetSecond() const +{ + return (int)((Ticks / TimeSpan::TicksPerSecond) % 60); +} + +int +DateTime::GetMillisecond() const +{ + return (int)((Ticks / TimeSpan::TicksPerMillisecond) % 1000); +} + +void +DateTime::GetDate(int& Year, int& Month, int& Day) const +{ + // Based on FORTRAN code in: + // Fliegel, H. F. and van Flandern, T. C., + // Communications of the ACM, Vol. 11, No. 10 (October 1968). + + int i, j, k, l, n; + + l = int(GetJulianDay(Ticks) + 0.5) + 68569; + n = 4 * l / 146097; + l = l - (146097 * n + 3) / 4; + i = 4000 * (l + 1) / 1461001; + l = l - 1461 * i / 4 + 31; + j = 80 * l / 2447; + k = l - 2447 * j / 80; + l = j / 11; + j = j + 2 - 12 * l; + i = 100 * (n - 49) + i + l; + + Year = i; + Month = j; + Day = k; +} + +std::string +DateTime::ToString(const char* Format) const +{ + ExtendableStringBuilder<32> Result; + int Year, Month, Day; + + GetDate(Year, Month, Day); + + if (Format != nullptr) + { + while (*Format != '\0') + { + if ((*Format == '%') && (*(++Format) != '\0')) + { + switch (*Format) + { + // case 'a': Result.Append(IsMorning() ? TEXT("am") : TEXT("pm")); break; + // case 'A': Result.Append(IsMorning() ? TEXT("AM") : TEXT("PM")); break; + case 'd': + Result.Append(fmt::format("{:02}", Day)); + break; + // case 'D': Result.Appendf(TEXT("%03i"), GetDayOfYear()); break; + case 'm': + Result.Append(fmt::format("{:02}", Month)); + break; + case 'y': + Result.Append(fmt::format("{:02}", Year % 100)); + break; + case 'Y': + Result.Append(fmt::format("{:04}", Year)); + break; + case 'h': + Result.Append(fmt::format("{:02}", GetHour12())); + break; + case 'H': + Result.Append(fmt::format("{:02}", GetHour())); + break; + case 'M': + Result.Append(fmt::format("{:02}", GetMinute())); + break; + case 'S': + Result.Append(fmt::format("{:02}", GetSecond())); + break; + case 's': + Result.Append(fmt::format("{:03}", GetMillisecond())); + break; + default: + Result.Append(*Format); + } + } + else + { + Result.Append(*Format); + } + + // move to the next one + Format++; + } + } + + return Result.ToString(); +} + +std::string +DateTime::ToIso8601() const +{ + return ToString("%Y-%m-%dT%H:%M:%S.%sZ"); +} + +void +TimeSpan::Set(int Days, int Hours, int Minutes, int Seconds, int FractionNano) +{ + int64_t TotalTicks = 0; + + TotalTicks += Days * TicksPerDay; + TotalTicks += Hours * TicksPerHour; + TotalTicks += Minutes * TicksPerMinute; + TotalTicks += Seconds * TicksPerSecond; + TotalTicks += FractionNano / NanosecondsPerTick; + + Ticks = TotalTicks; +} + +std::string +TimeSpan::ToString(const char* Format) const +{ + StringBuilder<128> Result; + + Result.Append((int64_t(Ticks) < 0) ? '-' : '+'); + + while (*Format != '\0') + { + if ((*Format == '%') && (*++Format != '\0')) + { + switch (*Format) + { + case 'd': + Result.Append(fmt::format("{}", GetDays())); + break; + case 'D': + Result.Append(fmt::format("{:08}", GetDays())); + break; + case 'h': + Result.Append(fmt::format("{:02}", GetHours())); + break; + case 'm': + Result.Append(fmt::format("{:02}", GetMinutes())); + break; + case 's': + Result.Append(fmt::format("{:02}", GetSeconds())); + break; + case 'f': + Result.Append(fmt::format("{:03}", GetFractionMilli())); + break; + case 'u': + Result.Append(fmt::format("{:06}", GetFractionMicro())); + break; + case 't': + Result.Append(fmt::format("{:07}", GetFractionTicks())); + break; + case 'n': + Result.Append(fmt::format("{:09}", GetFractionNano())); + break; + default: + Result.Append(*Format); + } + } + else + { + Result.Append(*Format); + } + + ++Format; + } + + return Result.ToString(); +} + +std::string +TimeSpan::ToString() const +{ + if (GetDays() == 0) + { + return ToString("%h:%m:%s.%f"); + } + + return ToString("%d.%h:%m:%s.%f"); +} + +StringBuilderBase& +Guid::ToString(StringBuilderBase& Sb) const +{ + char Buf[128]; + snprintf(Buf, sizeof Buf, "%08x-%04x-%04x-%04x-%04x%08x", A, B >> 16, B & 0xFFFF, C >> 16, C & 0xFFFF, D); + Sb << Buf; + + return Sb; +} + +////////////////////////////////////////////////////////////////////////// + +namespace CompactBinaryPrivate { + static constexpr const uint8_t GEmptyObjectPayload[] = {uint8_t(CbFieldType::Object), 0x00}; + static constexpr const uint8_t GEmptyArrayPayload[] = {uint8_t(CbFieldType::Array), 0x01, 0x00}; +} // namespace CompactBinaryPrivate + +////////////////////////////////////////////////////////////////////////// + +CbFieldView::CbFieldView(const void* DataPointer, CbFieldType FieldType) +{ + const uint8_t* Bytes = static_cast<const uint8_t*>(DataPointer); + const CbFieldType LocalType = CbFieldTypeOps::HasFieldType(FieldType) ? (CbFieldType(*Bytes++) | CbFieldType::HasFieldType) : FieldType; + + uint32_t NameLenByteCount = 0; + const uint64_t NameLen64 = CbFieldTypeOps::HasFieldName(LocalType) ? ReadVarUInt(Bytes, NameLenByteCount) : 0; + Bytes += NameLen64 + NameLenByteCount; + + Type = LocalType; + NameLen = uint32_t(std::clamp<uint64_t>(NameLen64, 0, ~uint32_t(0))); + Payload = Bytes; +} + +void +CbFieldView::IterateAttachments(std::function<void(CbFieldView)> Visitor) const +{ + switch (CbFieldTypeOps::GetType(Type)) + { + case CbFieldType::Object: + case CbFieldType::UniformObject: + return CbObjectView::FromFieldView(*this).IterateAttachments(Visitor); + case CbFieldType::Array: + case CbFieldType::UniformArray: + return CbArrayView::FromFieldView(*this).IterateAttachments(Visitor); + case CbFieldType::ObjectAttachment: + case CbFieldType::BinaryAttachment: + return Visitor(*this); + default: + return; + } +} + +CbObjectView +CbFieldView::AsObjectView() +{ + if (CbFieldTypeOps::IsObject(Type)) + { + Error = CbFieldError::None; + return CbObjectView::FromFieldView(*this); + } + else + { + Error = CbFieldError::TypeError; + return CbObjectView(); + } +} + +CbArrayView +CbFieldView::AsArrayView() +{ + if (CbFieldTypeOps::IsArray(Type)) + { + Error = CbFieldError::None; + return CbArrayView::FromFieldView(*this); + } + else + { + Error = CbFieldError::TypeError; + return CbArrayView(); + } +} + +MemoryView +CbFieldView::AsBinaryView(const MemoryView Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsBinary(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsBinary(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +std::string_view +CbFieldView::AsString(const std::string_view Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsString(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsString(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +std::u8string_view +CbFieldView::AsU8String(const std::u8string_view Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsString(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsU8String(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +uint64_t +CbFieldView::AsInteger(const uint64_t Default, const CompactBinaryPrivate::IntegerParams Params) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsInteger(Accessor.GetType())) + { + return Accessor.AsInteger(Params, &Error, Default); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +float +CbFieldView::AsFloat(const float Default) +{ + switch (CbValue Accessor = GetValue(); Accessor.GetType()) + { + case CbFieldType::IntegerPositive: + case CbFieldType::IntegerNegative: + { + const uint64_t IsNegative = uint8_t(Accessor.GetType()) & 1; + constexpr uint64_t OutOfRangeMask = ~((uint64_t(1) << /*FLT_MANT_DIG*/ 24) - 1); + + uint32_t MagnitudeByteCount; + const int64_t Magnitude = ReadVarUInt(Accessor.GetData(), MagnitudeByteCount) + IsNegative; + const uint64_t IsInRange = !(Magnitude & OutOfRangeMask); + Error = IsInRange ? CbFieldError::None : CbFieldError::RangeError; + return IsInRange ? float(IsNegative ? -Magnitude : Magnitude) : Default; + } + case CbFieldType::Float32: + { + Error = CbFieldError::None; + return Accessor.AsFloat32(); + } + case CbFieldType::Float64: + Error = CbFieldError::RangeError; + return Default; + default: + Error = CbFieldError::TypeError; + return Default; + } +} + +double +CbFieldView::AsDouble(const double Default) +{ + switch (CbValue Accessor = GetValue(); Accessor.GetType()) + { + case CbFieldType::IntegerPositive: + case CbFieldType::IntegerNegative: + { + const uint64_t IsNegative = uint8_t(Accessor.GetType()) & 1; + constexpr uint64_t OutOfRangeMask = ~((uint64_t(1) << /*DBL_MANT_DIG*/ 53) - 1); + + uint32_t MagnitudeByteCount; + const int64_t Magnitude = ReadVarUInt(Accessor.GetData(), MagnitudeByteCount) + IsNegative; + const uint64_t IsInRange = !(Magnitude & OutOfRangeMask); + Error = IsInRange ? CbFieldError::None : CbFieldError::RangeError; + return IsInRange ? double(IsNegative ? -Magnitude : Magnitude) : Default; + } + case CbFieldType::Float32: + { + Error = CbFieldError::None; + return Accessor.AsFloat32(); + } + case CbFieldType::Float64: + { + Error = CbFieldError::None; + return Accessor.AsFloat64(); + } + default: + Error = CbFieldError::TypeError; + return Default; + } +} + +bool +CbFieldView::AsBool(const bool bDefault) +{ + CbValue Accessor = GetValue(); + const bool IsBool = CbFieldTypeOps::IsBool(Accessor.GetType()); + Error = IsBool ? CbFieldError::None : CbFieldError::TypeError; + return (uint8_t(IsBool) & Accessor.AsBool()) | ((!IsBool) & bDefault); +} + +IoHash +CbFieldView::AsObjectAttachment(const IoHash& Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsObjectAttachment(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsObjectAttachment(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +IoHash +CbFieldView::AsBinaryAttachment(const IoHash& Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsBinaryAttachment(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsBinaryAttachment(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +IoHash +CbFieldView::AsAttachment(const IoHash& Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsAttachment(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsAttachment(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +IoHash +CbFieldView::AsHash(const IoHash& Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsHash(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsHash(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +Guid +CbFieldView::AsUuid() +{ + return AsUuid(Guid()); +} + +Guid +CbFieldView::AsUuid(const Guid& Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsUuid(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsUuid(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +Oid +CbFieldView::AsObjectId() +{ + return AsObjectId(Oid()); +} + +Oid +CbFieldView::AsObjectId(const Oid& Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsObjectId(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsObjectId(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +CbCustomById +CbFieldView::AsCustomById(CbCustomById Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsCustomById(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsCustomById(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +CbCustomByName +CbFieldView::AsCustomByName(CbCustomByName Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsCustomByName(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsCustomByName(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +int64_t +CbFieldView::AsDateTimeTicks(const int64_t Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsDateTime(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsDateTimeTicks(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +DateTime +CbFieldView::AsDateTime() +{ + return DateTime(AsDateTimeTicks(0)); +} + +DateTime +CbFieldView::AsDateTime(DateTime Default) +{ + return DateTime(AsDateTimeTicks(Default.GetTicks())); +} + +int64_t +CbFieldView::AsTimeSpanTicks(const int64_t Default) +{ + if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsTimeSpan(Accessor.GetType())) + { + Error = CbFieldError::None; + return Accessor.AsTimeSpanTicks(); + } + else + { + Error = CbFieldError::TypeError; + return Default; + } +} + +TimeSpan +CbFieldView::AsTimeSpan() +{ + return TimeSpan(AsTimeSpanTicks(0)); +} + +TimeSpan +CbFieldView::AsTimeSpan(TimeSpan Default) +{ + return TimeSpan(AsTimeSpanTicks(Default.GetTicks())); +} + +uint64_t +CbFieldView::GetSize() const +{ + return sizeof(CbFieldType) + GetViewNoType().GetSize(); +} + +uint64_t +CbFieldView::GetPayloadSize() const +{ + switch (CbFieldTypeOps::GetType(Type)) + { + case CbFieldType::None: + case CbFieldType::Null: + return 0; + case CbFieldType::Object: + case CbFieldType::UniformObject: + case CbFieldType::Array: + case CbFieldType::UniformArray: + case CbFieldType::Binary: + case CbFieldType::String: + { + uint32_t PayloadSizeByteCount; + const uint64_t PayloadSize = ReadVarUInt(Payload, PayloadSizeByteCount); + return PayloadSize + PayloadSizeByteCount; + } + case CbFieldType::IntegerPositive: + case CbFieldType::IntegerNegative: + return MeasureVarUInt(Payload); + case CbFieldType::Float32: + return 4; + case CbFieldType::Float64: + return 8; + case CbFieldType::BoolFalse: + case CbFieldType::BoolTrue: + return 0; + case CbFieldType::ObjectAttachment: + case CbFieldType::BinaryAttachment: + case CbFieldType::Hash: + return 20; + case CbFieldType::Uuid: + return 16; + case CbFieldType::ObjectId: + return 12; + case CbFieldType::DateTime: + case CbFieldType::TimeSpan: + return 8; + default: + return 0; + } +} + +IoHash +CbFieldView::GetHash() const +{ + IoHashStream HashStream; + GetHash(HashStream); + return HashStream.GetHash(); +} + +void +CbFieldView::GetHash(IoHashStream& Hash) const +{ + const CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(Type); + Hash.Append(&SerializedType, sizeof(SerializedType)); + auto View = GetViewNoType(); + Hash.Append(View.GetData(), View.GetSize()); +} + +bool +CbFieldView::Equals(const CbFieldView& Other) const +{ + return CbFieldTypeOps::GetSerializedType(Type) == CbFieldTypeOps::GetSerializedType(Other.Type) && + GetViewNoType().EqualBytes(Other.GetViewNoType()); +} + +void +CbFieldView::CopyTo(MutableMemoryView Buffer) const +{ + const MemoryView Source = GetViewNoType(); + ZEN_ASSERT(Buffer.GetSize() == sizeof(CbFieldType) + Source.GetSize()); + // TEXT("A buffer of %" UINT64_FMT " bytes was provided when %" UINT64_FMT " bytes are required"), + // Buffer.GetSize(), + // sizeof(CbFieldType) + Source.GetSize()); + *static_cast<CbFieldType*>(Buffer.GetData()) = CbFieldTypeOps::GetSerializedType(Type); + Buffer.RightChopInline(sizeof(CbFieldType)); + memcpy(Buffer.GetData(), Source.GetData(), Source.GetSize()); +} + +void +CbFieldView::CopyTo(BinaryWriter& Ar) const +{ + const MemoryView SourceView = GetViewNoType(); + CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(Type); + const MemoryView TypeView(reinterpret_cast<const uint8_t*>(&SerializedType), sizeof(SerializedType)); + Ar.Write({TypeView, SourceView}); +} + +MemoryView +CbFieldView::GetView() const +{ + const uint32_t TypeSize = CbFieldTypeOps::HasFieldType(Type) ? sizeof(CbFieldType) : 0; + const uint32_t NameSize = CbFieldTypeOps::HasFieldName(Type) ? NameLen + MeasureVarUInt(NameLen) : 0; + const uint64_t PayloadSize = GetPayloadSize(); + return MemoryView(static_cast<const uint8_t*>(Payload) - TypeSize - NameSize, TypeSize + NameSize + PayloadSize); +} + +MemoryView +CbFieldView::GetViewNoType() const +{ + const uint32_t NameSize = CbFieldTypeOps::HasFieldName(Type) ? NameLen + MeasureVarUInt(NameLen) : 0; + const uint64_t PayloadSize = GetPayloadSize(); + return MemoryView(static_cast<const uint8_t*>(Payload) - NameSize, NameSize + PayloadSize); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +CbArrayView::CbArrayView() : CbFieldView(CompactBinaryPrivate::GEmptyArrayPayload) +{ +} + +uint64_t +CbArrayView::Num() const +{ + const uint8_t* PayloadBytes = static_cast<const uint8_t*>(GetPayload()); + PayloadBytes += MeasureVarUInt(PayloadBytes); + uint32_t NumByteCount; + return ReadVarUInt(PayloadBytes, NumByteCount); +} + +CbFieldViewIterator +CbArrayView::CreateViewIterator() const +{ + const uint8_t* PayloadBytes = static_cast<const uint8_t*>(GetPayload()); + uint32_t PayloadSizeByteCount; + const uint64_t PayloadSize = ReadVarUInt(PayloadBytes, PayloadSizeByteCount); + PayloadBytes += PayloadSizeByteCount; + const uint64_t NumByteCount = MeasureVarUInt(PayloadBytes); + if (PayloadSize > NumByteCount) + { + const void* const PayloadEnd = PayloadBytes + PayloadSize; + PayloadBytes += NumByteCount; + const CbFieldType UniformType = + CbFieldTypeOps::GetType(GetType()) == CbFieldType::UniformArray ? CbFieldType(*PayloadBytes++) : CbFieldType::HasFieldType; + return CbFieldViewIterator::MakeRange(MemoryView(PayloadBytes, PayloadEnd), UniformType); + } + return CbFieldViewIterator(); +} + +void +CbArrayView::VisitFields(ICbVisitor&) +{ +} + +uint64_t +CbArrayView::GetSize() const +{ + return sizeof(CbFieldType) + GetPayloadSize(); +} + +IoHash +CbArrayView::GetHash() const +{ + IoHashStream Hash; + GetHash(Hash); + return Hash.GetHash(); +} + +void +CbArrayView::GetHash(IoHashStream& HashStream) const +{ + const CbFieldType SerializedType = CbFieldTypeOps::GetType(GetType()); + HashStream.Append(&SerializedType, sizeof(SerializedType)); + auto _ = GetPayloadView(); + HashStream.Append(_.GetData(), _.GetSize()); +} + +bool +CbArrayView::Equals(const CbArrayView& Other) const +{ + return CbFieldTypeOps::GetType(GetType()) == CbFieldTypeOps::GetType(Other.GetType()) && + GetPayloadView().EqualBytes(Other.GetPayloadView()); +} + +void +CbArrayView::CopyTo(MutableMemoryView Buffer) const +{ + const MemoryView Source = GetPayloadView(); + ZEN_ASSERT(Buffer.GetSize() == sizeof(CbFieldType) + Source.GetSize()); + // TEXT("Buffer is %" UINT64_FMT " bytes but %" UINT64_FMT " is required."), + // Buffer.GetSize(), + // sizeof(CbFieldType) + Source.GetSize()); + + *static_cast<CbFieldType*>(Buffer.GetData()) = CbFieldTypeOps::GetType(GetType()); + Buffer.RightChopInline(sizeof(CbFieldType)); + memcpy(Buffer.GetData(), Source.GetData(), Source.GetSize()); +} + +void +CbArrayView::CopyTo(BinaryWriter& Ar) const +{ + const MemoryView SourceView = GetPayloadView(); + CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(GetType()); + const MemoryView TypeView(reinterpret_cast<const uint8_t*>(&SerializedType), sizeof(SerializedType)); + Ar.Write({TypeView, SourceView}); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +CbObjectView::CbObjectView() : CbFieldView(CompactBinaryPrivate::GEmptyObjectPayload) +{ +} + +CbFieldViewIterator +CbObjectView::CreateViewIterator() const +{ + const uint8_t* PayloadBytes = static_cast<const uint8_t*>(GetPayload()); + uint32_t PayloadSizeByteCount; + const uint64_t PayloadSize = ReadVarUInt(PayloadBytes, PayloadSizeByteCount); + + PayloadBytes += PayloadSizeByteCount; + + if (PayloadSize) + { + const void* const PayloadEnd = PayloadBytes + PayloadSize; + const CbFieldType UniformType = + CbFieldTypeOps::GetType(GetType()) == CbFieldType::UniformObject ? CbFieldType(*PayloadBytes++) : CbFieldType::HasFieldType; + return CbFieldViewIterator::MakeRange(MemoryView(PayloadBytes, PayloadEnd), UniformType); + } + + return CbFieldViewIterator(); +} + +void +CbObjectView::VisitFields(ICbVisitor&) +{ +} + +CbFieldView +CbObjectView::FindView(const std::string_view Name) const +{ + for (const CbFieldView& Field : *this) + { + if (Name == Field.GetName()) + { + return Field; + } + } + return CbFieldView(); +} + +CbFieldView +CbObjectView::FindViewIgnoreCase(const std::string_view Name) const +{ + for (const CbFieldView& Field : *this) + { + if (Name == Field.GetName()) + { + return Field; + } + } + return CbFieldView(); +} + +CbObjectView::operator bool() const +{ + return GetSize() > sizeof(CompactBinaryPrivate::GEmptyObjectPayload); +} + +uint64_t +CbObjectView::GetSize() const +{ + return sizeof(CbFieldType) + GetPayloadSize(); +} + +IoHash +CbObjectView::GetHash() const +{ + IoHashStream Hash; + GetHash(Hash); + return Hash.GetHash(); +} + +void +CbObjectView::GetHash(IoHashStream& HashStream) const +{ + const CbFieldType SerializedType = CbFieldTypeOps::GetType(GetType()); + HashStream.Append(&SerializedType, sizeof(SerializedType)); + HashStream.Append(GetPayloadView()); +} + +bool +CbObjectView::Equals(const CbObjectView& Other) const +{ + return CbFieldTypeOps::GetType(GetType()) == CbFieldTypeOps::GetType(Other.GetType()) && + GetPayloadView().EqualBytes(Other.GetPayloadView()); +} + +void +CbObjectView::CopyTo(MutableMemoryView Buffer) const +{ + const MemoryView Source = GetPayloadView(); + ZEN_ASSERT(Buffer.GetSize() == (sizeof(CbFieldType) + Source.GetSize())); + // TEXT("Buffer is %" UINT64_FMT " bytes but %" UINT64_FMT " is required."), + // Buffer.GetSize(), + // sizeof(CbFieldType) + Source.GetSize()); + *static_cast<CbFieldType*>(Buffer.GetData()) = CbFieldTypeOps::GetType(GetType()); + Buffer.RightChopInline(sizeof(CbFieldType)); + memcpy(Buffer.GetData(), Source.GetData(), Source.GetSize()); +} + +void +CbObjectView::CopyTo(BinaryWriter& Ar) const +{ + const MemoryView SourceView = GetPayloadView(); + CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(GetType()); + const MemoryView TypeView(reinterpret_cast<const uint8_t*>(&SerializedType), sizeof(SerializedType)); + Ar.Write({TypeView, SourceView}); +} + +////////////////////////////////////////////////////////////////////////// + +template<typename FieldType> +uint64_t +TCbFieldIterator<FieldType>::GetRangeSize() const +{ + MemoryView View; + if (TryGetSerializedRangeView(View)) + { + return View.GetSize(); + } + else + { + uint64_t Size = 0; + for (CbFieldViewIterator It(*this); It; ++It) + { + Size += It.GetSize(); + } + return Size; + } +} + +template<typename FieldType> +IoHash +TCbFieldIterator<FieldType>::GetRangeHash() const +{ + IoHashStream Hash; + GetRangeHash(Hash); + return IoHash(Hash.GetHash()); +} + +template<typename FieldType> +void +TCbFieldIterator<FieldType>::GetRangeHash(IoHashStream& Hash) const +{ + MemoryView View; + if (TryGetSerializedRangeView(View)) + { + Hash.Append(View.GetData(), View.GetSize()); + } + else + { + for (CbFieldViewIterator It(*this); It; ++It) + { + It.GetHash(Hash); + } + } +} + +template<typename FieldType> +void +TCbFieldIterator<FieldType>::CopyRangeTo(MutableMemoryView InBuffer) const +{ + MemoryView Source; + if (TryGetSerializedRangeView(Source)) + { + ZEN_ASSERT(InBuffer.GetSize() == Source.GetSize()); + // TEXT("Buffer is %" UINT64_FMT " bytes but %" UINT64_FMT " is required."), + // InBuffer.GetSize(), + // Source.GetSize()); + memcpy(InBuffer.GetData(), Source.GetData(), Source.GetSize()); + } + else + { + for (CbFieldViewIterator It(*this); It; ++It) + { + const uint64_t Size = It.GetSize(); + It.CopyTo(InBuffer.Left(Size)); + InBuffer.RightChopInline(Size); + } + } +} + +template class TCbFieldIterator<CbFieldView>; +template class TCbFieldIterator<CbField>; + +template<typename FieldType> +void +TCbFieldIterator<FieldType>::IterateRangeAttachments(std::function<void(CbFieldView)> Visitor) const +{ + if (CbFieldTypeOps::HasFieldType(FieldType::GetType())) + { + // Always iterate over non-uniform ranges because we do not know if they contain an attachment. + for (CbFieldViewIterator It(*this); It; ++It) + { + if (CbFieldTypeOps::MayContainAttachments(It.GetType())) + { + It.IterateAttachments(Visitor); + } + } + } + else + { + // Only iterate over uniform ranges if the uniform type may contain an attachment. + if (CbFieldTypeOps::MayContainAttachments(FieldType::GetType())) + { + for (CbFieldViewIterator It(*this); It; ++It) + { + It.IterateAttachments(Visitor); + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +CbFieldIterator +CbFieldIterator::CloneRange(const CbFieldViewIterator& It) +{ + MemoryView View; + if (It.TryGetSerializedRangeView(View)) + { + return MakeRange(SharedBuffer::Clone(View)); + } + else + { + UniqueBuffer Buffer = UniqueBuffer::Alloc(It.GetRangeSize()); + It.CopyRangeTo(MutableMemoryView(Buffer.GetData(), Buffer.GetSize())); + return MakeRange(SharedBuffer(std::move(Buffer))); + } +} + +SharedBuffer +CbFieldIterator::GetRangeBuffer() const +{ + const MemoryView RangeView = GetRangeView(); + const SharedBuffer& OuterBuffer = GetOuterBuffer(); + return OuterBuffer.GetView() == RangeView ? OuterBuffer : SharedBuffer::MakeView(RangeView, OuterBuffer); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +uint64_t +MeasureCompactBinary(MemoryView View, CbFieldType Type) +{ + uint64_t Size; + return TryMeasureCompactBinary(View, Type, Size, Type) ? Size : 0; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +bool +TryMeasureCompactBinary(MemoryView View, CbFieldType& OutType, uint64_t& OutSize, CbFieldType Type) +{ + uint64_t Size = 0; + + if (CbFieldTypeOps::HasFieldType(Type)) + { + if (View.GetSize() == 0) + { + OutType = CbFieldType::None; + OutSize = 1; + return false; + } + + Type = *static_cast<const CbFieldType*>(View.GetData()); + View.RightChopInline(1); + Size += 1; + } + + bool bDynamicSize = false; + uint64_t FixedSize = 0; + switch (CbFieldTypeOps::GetType(Type)) + { + case CbFieldType::Null: + break; + case CbFieldType::Object: + case CbFieldType::UniformObject: + case CbFieldType::Array: + case CbFieldType::UniformArray: + case CbFieldType::Binary: + case CbFieldType::String: + case CbFieldType::IntegerPositive: + case CbFieldType::IntegerNegative: + bDynamicSize = true; + break; + case CbFieldType::Float32: + FixedSize = 4; + break; + case CbFieldType::Float64: + FixedSize = 8; + break; + case CbFieldType::BoolFalse: + case CbFieldType::BoolTrue: + break; + case CbFieldType::ObjectAttachment: + case CbFieldType::BinaryAttachment: + case CbFieldType::Hash: + FixedSize = 20; + break; + case CbFieldType::Uuid: + FixedSize = 16; + break; + case CbFieldType::ObjectId: + FixedSize = 12; + break; + case CbFieldType::DateTime: + case CbFieldType::TimeSpan: + FixedSize = 8; + break; + case CbFieldType::None: + default: + OutType = CbFieldType::None; + OutSize = 0; + return false; + } + + OutType = Type; + + if (CbFieldTypeOps::HasFieldName(Type)) + { + if (View.GetSize() == 0) + { + OutSize = Size + 1; + return false; + } + + uint32_t NameLenByteCount = MeasureVarUInt(View.GetData()); + if (View.GetSize() < NameLenByteCount) + { + OutSize = Size + NameLenByteCount; + return false; + } + + const uint64_t NameLen = ReadVarUInt(View.GetData(), NameLenByteCount); + const uint64_t NameSize = NameLen + NameLenByteCount; + + if (bDynamicSize && View.GetSize() < NameSize) + { + OutSize = Size + NameSize; + return false; + } + + View.RightChopInline(NameSize); + Size += NameSize; + } + + switch (CbFieldTypeOps::GetType(Type)) + { + case CbFieldType::Object: + case CbFieldType::UniformObject: + case CbFieldType::Array: + case CbFieldType::UniformArray: + case CbFieldType::Binary: + case CbFieldType::String: + if (View.GetSize() == 0) + { + OutSize = Size + 1; + return false; + } + else + { + uint32_t PayloadSizeByteCount = MeasureVarUInt(View.GetData()); + if (View.GetSize() < PayloadSizeByteCount) + { + OutSize = Size + PayloadSizeByteCount; + return false; + } + const uint64_t PayloadSize = ReadVarUInt(View.GetData(), PayloadSizeByteCount); + OutSize = Size + PayloadSize + PayloadSizeByteCount; + } + return true; + + case CbFieldType::IntegerPositive: + case CbFieldType::IntegerNegative: + if (View.GetSize() == 0) + { + OutSize = Size + 1; + return false; + } + OutSize = Size + MeasureVarUInt(View.GetData()); + return true; + + default: + OutSize = Size + FixedSize; + return true; + } +} + +////////////////////////////////////////////////////////////////////////// + +CbField +LoadCompactBinary(BinaryReader& Ar, BufferAllocator Allocator) +{ + std::vector<uint8_t> HeaderBytes; + CbFieldType FieldType; + uint64_t FieldSize = 1; + + for (const int64_t StartPos = Ar.CurrentOffset(); FieldSize > 0;) + { + // Read in small increments until the total field size is known, to avoid reading too far. + const int32_t ReadSize = int32_t(FieldSize - HeaderBytes.size()); + if (Ar.CurrentOffset() + ReadSize > Ar.GetSize()) + { + break; + } + + const size_t ReadOffset = HeaderBytes.size(); + HeaderBytes.resize(ReadOffset + ReadSize); + + Ar.Read(HeaderBytes.data() + ReadOffset, ReadSize); + if (TryMeasureCompactBinary(MakeMemoryView(HeaderBytes), FieldType, FieldSize)) + { + if (FieldSize <= uint64_t(Ar.Size() - StartPos)) + { + UniqueBuffer Buffer = Allocator(FieldSize); + ZEN_ASSERT(Buffer.GetSize() == FieldSize); + MutableMemoryView View = Buffer.GetMutableView(); + memcpy(View.GetData(), HeaderBytes.data(), HeaderBytes.size()); + View.RightChopInline(HeaderBytes.size()); + if (!View.IsEmpty()) + { + // Read the remainder of the field. + Ar.Read(View.GetData(), View.GetSize()); + } + if (ValidateCompactBinary(Buffer, CbValidateMode::Default) == CbValidateError::None) + { + return CbField(SharedBuffer(std::move(Buffer))); + } + } + break; + } + } + return CbField(); +} + +CbObject +LoadCompactBinaryObject(IoBuffer&& Payload) +{ + return CbObject{SharedBuffer(std::move(Payload))}; +} + +CbObject +LoadCompactBinaryObject(const IoBuffer& Payload) +{ + return CbObject{SharedBuffer(Payload)}; +} + +CbObject +LoadCompactBinaryObject(CompressedBuffer&& Payload) +{ + return CbObject{SharedBuffer(Payload.DecompressToComposite().Flatten())}; +} + +CbObject +LoadCompactBinaryObject(const CompressedBuffer& Payload) +{ + return CbObject{SharedBuffer(Payload.DecompressToComposite().Flatten())}; +} + +////////////////////////////////////////////////////////////////////////// + +void +SaveCompactBinary(BinaryWriter& Ar, const CbFieldView& Field) +{ + Field.CopyTo(Ar); +} + +void +SaveCompactBinary(BinaryWriter& Ar, const CbArrayView& Array) +{ + Array.CopyTo(Ar); +} + +void +SaveCompactBinary(BinaryWriter& Ar, const CbObjectView& Object) +{ + Object.CopyTo(Ar); +} + +////////////////////////////////////////////////////////////////////////// + +class CbJsonWriter +{ +public: + explicit CbJsonWriter(StringBuilderBase& InBuilder) : Builder(InBuilder) { NewLineAndIndent << LINE_TERMINATOR_ANSI; } + + void WriteField(CbFieldView Field) + { + using namespace std::literals; + + WriteOptionalComma(); + WriteOptionalNewLine(); + + if (std::u8string_view Name = Field.GetU8Name(); !Name.empty()) + { + AppendQuotedString(Name); + Builder << ": "sv; + } + + switch (CbValue Accessor = Field.GetValue(); Accessor.GetType()) + { + case CbFieldType::Null: + Builder << "null"sv; + break; + case CbFieldType::Object: + case CbFieldType::UniformObject: + { + Builder << '{'; + NewLineAndIndent << '\t'; + NeedsNewLine = true; + for (CbFieldView It : Field) + { + WriteField(It); + } + NewLineAndIndent.RemoveSuffix(1); + if (NeedsComma) + { + WriteOptionalNewLine(); + } + Builder << '}'; + } + break; + case CbFieldType::Array: + case CbFieldType::UniformArray: + { + Builder << '['; + NewLineAndIndent << '\t'; + NeedsNewLine = true; + for (CbFieldView It : Field) + { + WriteField(It); + } + NewLineAndIndent.RemoveSuffix(1); + if (NeedsComma) + { + WriteOptionalNewLine(); + } + Builder << ']'; + } + break; + case CbFieldType::Binary: + AppendBase64String(Accessor.AsBinary()); + break; + case CbFieldType::String: + AppendQuotedString(Accessor.AsU8String()); + break; + case CbFieldType::IntegerPositive: + Builder << Accessor.AsIntegerPositive(); + break; + case CbFieldType::IntegerNegative: + Builder << Accessor.AsIntegerNegative(); + break; + case CbFieldType::Float32: + { + const float Value = Accessor.AsFloat32(); + if (std::isfinite(Value)) + { + Builder.Append(fmt::format("{:.9g}", Value)); + } + else + { + Builder << "null"sv; + } + } + break; + case CbFieldType::Float64: + { + const double Value = Accessor.AsFloat64(); + if (std::isfinite(Value)) + { + Builder.Append(fmt::format("{:.17g}", Value)); + } + else + { + Builder << "null"sv; + } + } + break; + case CbFieldType::BoolFalse: + Builder << "false"sv; + break; + case CbFieldType::BoolTrue: + Builder << "true"sv; + break; + case CbFieldType::ObjectAttachment: + case CbFieldType::BinaryAttachment: + { + Builder << '"'; + Accessor.AsAttachment().ToHexString(Builder); + Builder << '"'; + } + break; + case CbFieldType::Hash: + { + Builder << '"'; + Accessor.AsHash().ToHexString(Builder); + Builder << '"'; + } + break; + case CbFieldType::Uuid: + { + Builder << '"'; + Accessor.AsUuid().ToString(Builder); + Builder << '"'; + } + break; + case CbFieldType::DateTime: + Builder << '"' << DateTime(Accessor.AsDateTimeTicks()).ToIso8601() << '"'; + break; + case CbFieldType::TimeSpan: + { + const TimeSpan Span(Accessor.AsTimeSpanTicks()); + if (Span.GetDays() == 0) + { + Builder << '"' << Span.ToString("%h:%m:%s.%n") << '"'; + } + else + { + Builder << '"' << Span.ToString("%d.%h:%m:%s.%n") << '"'; + } + break; + } + case CbFieldType::ObjectId: + Builder << '"'; + Accessor.AsObjectId().ToString(Builder); + Builder << '"'; + break; + case CbFieldType::CustomById: + { + CbCustomById Custom = Accessor.AsCustomById(); + Builder << "{ \"Id\": "; + Builder << Custom.Id; + Builder << ", \"Data\": "; + AppendBase64String(Custom.Data); + Builder << " }"; + break; + } + case CbFieldType::CustomByName: + { + CbCustomByName Custom = Accessor.AsCustomByName(); + Builder << "{ \"Name\": "; + AppendQuotedString(Custom.Name); + Builder << ", \"Data\": "; + AppendBase64String(Custom.Data); + Builder << " }"; + break; + } + default: + ZEN_ASSERT(false); + break; + } + + NeedsComma = true; + NeedsNewLine = true; + } + +private: + void WriteOptionalComma() + { + if (NeedsComma) + { + NeedsComma = false; + Builder << ','; + } + } + + void WriteOptionalNewLine() + { + if (NeedsNewLine) + { + NeedsNewLine = false; + Builder << NewLineAndIndent; + } + } + + void AppendQuotedString(std::u8string_view Value) + { + using namespace std::literals; + + const AsciiSet EscapeSet( + "\\\"\b\f\n\r\t" + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f" + "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"); + + Builder << '\"'; + while (!Value.empty()) + { + std::u8string_view Verbatim = AsciiSet::FindPrefixWithout(Value, EscapeSet); + Builder << Verbatim; + + Value = Value.substr(Verbatim.size()); + + std::u8string_view Escape = AsciiSet::FindPrefixWith(Value, EscapeSet); + for (char Char : Escape) + { + switch (Char) + { + case '\\': + Builder << "\\\\"sv; + break; + case '\"': + Builder << "\\\""sv; + break; + case '\b': + Builder << "\\b"sv; + break; + case '\f': + Builder << "\\f"sv; + break; + case '\n': + Builder << "\\n"sv; + break; + case '\r': + Builder << "\\r"sv; + break; + case '\t': + Builder << "\\t"sv; + break; + default: + Builder << Char; + break; + } + } + Value = Value.substr(Escape.size()); + } + Builder << '\"'; + } + + void AppendBase64String(MemoryView Value) + { + Builder << '"'; + ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024); + const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize())); + const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize)); + Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex); + } + +private: + StringBuilderBase& Builder; + ExtendableStringBuilder<32> NewLineAndIndent; + bool NeedsComma{false}; + bool NeedsNewLine{false}; +}; + +void +CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder) +{ + CbJsonWriter Writer(Builder); + Writer.WriteField(Object.AsFieldView()); +} + +void +CompactBinaryToJson(const CbArrayView& Array, StringBuilderBase& Builder) +{ + CbJsonWriter Writer(Builder); + Writer.WriteField(Array.AsFieldView()); +} + +////////////////////////////////////////////////////////////////////////// + +class CbJsonReader +{ +public: + static CbFieldIterator Read(std::string_view JsonText, std::string& Error) + { + using namespace json11; + + const Json Json = Json::parse(std::string(JsonText), Error); + + if (Error.empty()) + { + CbWriter Writer; + if (ReadField(Writer, Json, std::string_view(), Error)) + { + return Writer.Save(); + } + } + + return CbFieldIterator(); + } + +private: + static bool ReadField(CbWriter& Writer, const json11::Json& Json, const std::string_view FieldName, std::string& Error) + { + using namespace json11; + + switch (Json.type()) + { + case Json::Type::OBJECT: + { + if (FieldName.empty()) + { + Writer.BeginObject(); + } + else + { + Writer.BeginObject(FieldName); + } + + for (const auto& Kv : Json.object_items()) + { + const std::string& Name = Kv.first; + const json11::Json& Item = Kv.second; + + if (ReadField(Writer, Item, Name, Error) == false) + { + return false; + } + } + + Writer.EndObject(); + } + break; + case Json::Type::ARRAY: + { + if (FieldName.empty()) + { + Writer.BeginArray(); + } + else + { + Writer.BeginArray(FieldName); + } + + for (const json11::Json& Item : Json.array_items()) + { + if (ReadField(Writer, Item, std::string_view(), Error) == false) + { + return false; + } + } + + Writer.EndArray(); + } + break; + case Json::Type::NUL: + { + if (FieldName.empty()) + { + Writer.AddNull(); + } + else + { + Writer.AddNull(FieldName); + } + } + break; + case Json::Type::BOOL: + { + if (FieldName.empty()) + { + Writer.AddBool(Json.bool_value()); + } + else + { + Writer.AddBool(FieldName, Json.bool_value()); + } + } + break; + case Json::Type::NUMBER: + { + if (FieldName.empty()) + { + Writer.AddFloat(Json.number_value()); + } + else + { + Writer.AddFloat(FieldName, Json.number_value()); + } + } + break; + case Json::Type::STRING: + { + Oid Id; + if (TryParseObjectId(Json.string_value(), Id)) + { + if (FieldName.empty()) + { + Writer.AddObjectId(Id); + } + else + { + Writer.AddObjectId(FieldName, Id); + } + + return true; + } + + IoHash Hash; + if (TryParseIoHash(Json.string_value(), Hash)) + { + if (FieldName.empty()) + { + Writer.AddHash(Hash); + } + else + { + Writer.AddHash(FieldName, Hash); + } + + return true; + } + + if (FieldName.empty()) + { + Writer.AddString(Json.string_value()); + } + else + { + Writer.AddString(FieldName, Json.string_value()); + } + } + break; + default: + break; + } + + return true; + } + + static constexpr AsciiSet HexCharSet = AsciiSet("0123456789abcdefABCDEF"); + + static bool TryParseObjectId(std::string_view Str, Oid& Id) + { + using namespace std::literals; + + if (Str.size() == Oid::StringLength && AsciiSet::HasOnly(Str, HexCharSet)) + { + Id = Oid::FromHexString(Str); + return true; + } + + if (Str.starts_with("0x"sv)) + { + return TryParseObjectId(Str.substr(2), Id); + } + + return false; + } + + static bool TryParseIoHash(std::string_view Str, IoHash& Hash) + { + using namespace std::literals; + + if (Str.size() == IoHash::StringLength && AsciiSet::HasOnly(Str, HexCharSet)) + { + Hash = IoHash::FromHexString(Str); + return true; + } + + if (Str.starts_with("0x"sv)) + { + return TryParseIoHash(Str.substr(2), Hash); + } + + return false; + } +}; + +CbFieldIterator +LoadCompactBinaryFromJson(std::string_view Json, std::string& Error) +{ + if (Json.empty() == false) + { + return CbJsonReader::Read(Json, Error); + } + + return CbFieldIterator(); +} + +CbFieldIterator +LoadCompactBinaryFromJson(std::string_view Json) +{ + std::string Error; + return LoadCompactBinaryFromJson(Json, Error); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS +void +uson_forcelink() +{ +} + +TEST_CASE("uson") +{ + using namespace std::literals; + + SUBCASE("CbField") + { + constexpr CbFieldView DefaultField; + static_assert(!DefaultField.HasName(), "Error in HasName()"); + static_assert(!DefaultField.HasValue(), "Error in HasValue()"); + static_assert(!DefaultField.HasError(), "Error in HasError()"); + static_assert(DefaultField.GetError() == CbFieldError::None, "Error in GetError()"); + + CHECK(DefaultField.GetSize() == 1); + CHECK(DefaultField.GetName().size() == 0); + CHECK(DefaultField.HasName() == false); + CHECK(DefaultField.HasValue() == false); + CHECK(DefaultField.HasError() == false); + CHECK(DefaultField.GetError() == CbFieldError::None); + + const uint8_t Type = (uint8_t)CbFieldType::None; + CHECK(DefaultField.GetHash() == IoHash::HashBuffer(&Type, sizeof Type)); + + CHECK(DefaultField.GetView() == MemoryView{}); + MemoryView SerializedView; + CHECK(DefaultField.TryGetSerializedView(SerializedView) == false); + } + + SUBCASE("CbField(None)") + { + CbFieldView NoneField(nullptr, CbFieldType::None); + CHECK(NoneField.GetSize() == 1); + CHECK(NoneField.GetName().size() == 0); + CHECK(NoneField.HasName() == false); + CHECK(NoneField.HasValue() == false); + CHECK(NoneField.HasError() == false); + CHECK(NoneField.GetError() == CbFieldError::None); + CHECK(NoneField.GetHash() == CbFieldView().GetHash()); + CHECK(NoneField.GetView() == MemoryView()); + MemoryView SerializedView; + CHECK(NoneField.TryGetSerializedView(SerializedView) == false); + } + + SUBCASE("CbField(None|Type|Name)") + { + constexpr CbFieldType FieldType = CbFieldType::None | CbFieldType::HasFieldName; + const char NoneBytes[] = {char(FieldType), 4, 'N', 'a', 'm', 'e'}; + CbFieldView NoneField(NoneBytes); + + CHECK(NoneField.GetSize() == sizeof(NoneBytes)); + CHECK(NoneField.GetName().compare("Name"sv) == 0); + CHECK(NoneField.HasName() == true); + CHECK(NoneField.HasValue() == false); + CHECK(NoneField.GetHash() == IoHash::HashBuffer(NoneBytes, sizeof NoneBytes)); + CHECK(NoneField.GetView() == MemoryView(NoneBytes, sizeof NoneBytes)); + MemoryView SerializedView; + CHECK(NoneField.TryGetSerializedView(SerializedView) == true); + CHECK(SerializedView == MemoryView(NoneBytes, sizeof NoneBytes)); + + uint8_t CopyBytes[sizeof(NoneBytes)]; + NoneField.CopyTo(MutableMemoryView(CopyBytes, sizeof CopyBytes)); + CHECK(MemoryView(NoneBytes, sizeof NoneBytes).EqualBytes(MemoryView(CopyBytes, sizeof CopyBytes))); + } + + SUBCASE("CbField(None|Type)") + { + constexpr CbFieldType FieldType = CbFieldType::None; + const char NoneBytes[] = {char(FieldType)}; + CbFieldView NoneField(NoneBytes); + + CHECK(NoneField.GetSize() == sizeof NoneBytes); + CHECK(NoneField.GetName().size() == 0); + CHECK(NoneField.HasName() == false); + CHECK(NoneField.HasValue() == false); + CHECK(NoneField.GetHash() == CbFieldView().GetHash()); + CHECK(NoneField.GetView() == MemoryView(NoneBytes, sizeof NoneBytes)); + MemoryView SerializedView; + CHECK(NoneField.TryGetSerializedView(SerializedView) == true); + CHECK(SerializedView == MemoryView(NoneBytes, sizeof NoneBytes)); + } + + SUBCASE("CbField(None|Name)") + { + constexpr CbFieldType FieldType = CbFieldType::None | CbFieldType::HasFieldName; + const char NoneBytes[] = {char(FieldType), 4, 'N', 'a', 'm', 'e'}; + CbFieldView NoneField(NoneBytes + 1, FieldType); + CHECK(NoneField.GetSize() == uint64_t(sizeof NoneBytes)); + CHECK(NoneField.GetName().compare("Name") == 0); + CHECK(NoneField.HasName() == true); + CHECK(NoneField.HasValue() == false); + CHECK(NoneField.GetHash() == IoHash::HashBuffer(NoneBytes, sizeof NoneBytes)); + CHECK(NoneField.GetView() == MemoryView(NoneBytes + 1, sizeof NoneBytes - 1)); + MemoryView SerializedView; + CHECK(NoneField.TryGetSerializedView(SerializedView) == false); + + uint8_t CopyBytes[sizeof(NoneBytes)]; + NoneField.CopyTo(MutableMemoryView(CopyBytes, sizeof CopyBytes)); + CHECK(MemoryView(NoneBytes, sizeof NoneBytes).EqualBytes(MemoryView(CopyBytes, sizeof CopyBytes))); + } + + SUBCASE("CbField(None|EmptyName)") + { + constexpr CbFieldType FieldType = CbFieldType::None | CbFieldType::HasFieldName; + const uint8_t NoneBytes[] = {uint8_t(FieldType), 0}; + CbFieldView NoneField(NoneBytes + 1, FieldType); + CHECK(NoneField.GetSize() == sizeof NoneBytes); + CHECK(NoneField.GetName().empty() == true); + CHECK(NoneField.HasName() == true); + CHECK(NoneField.HasValue() == false); + CHECK(NoneField.GetHash() == IoHash::HashBuffer(NoneBytes, sizeof NoneBytes)); + CHECK(NoneField.GetView() == MemoryView(NoneBytes + 1, sizeof NoneBytes - 1)); + MemoryView SerializedView; + CHECK(NoneField.TryGetSerializedView(SerializedView) == false); + } + + static_assert(!std::is_constructible<CbFieldView, const CbObjectView&>::value, "Invalid constructor for CbField"); + static_assert(!std::is_assignable<CbFieldView, const CbObjectView&>::value, "Invalid assignment for CbField"); + static_assert(!std::is_convertible<CbFieldView, CbObjectView>::value, "Invalid conversion to CbObject"); + static_assert(!std::is_assignable<CbObjectView, const CbFieldView&>::value, "Invalid assignment for CbObject"); + + static_assert(std::is_constructible<CbField>::value, "Missing constructor for CbField"); + static_assert(std::is_constructible<CbField, const CbField&>::value, "Missing constructor for CbField"); + static_assert(std::is_constructible<CbField, CbField&&>::value, "Missing constructor for CbField"); +} + +TEST_CASE("uson.null") +{ + using namespace std::literals; + + SUBCASE("CbField(Null)") + { + CbFieldView NullField(nullptr, CbFieldType::Null); + CHECK(NullField.GetSize() == 1); + CHECK(NullField.IsNull() == true); + CHECK(NullField.HasValue() == true); + CHECK(NullField.HasError() == false); + CHECK(NullField.GetError() == CbFieldError::None); + const uint8_t Null[]{uint8_t(CbFieldType::Null)}; + CHECK(NullField.GetHash() == IoHash::HashBuffer(Null, sizeof Null)); + } + + SUBCASE("CbField(None)") + { + CbFieldView Field; + CHECK(Field.IsNull() == false); + } +} + +TEST_CASE("uson.json") +{ + SUBCASE("string") + { + CbObjectWriter Writer; + Writer << "KeyOne" + << "ValueOne"; + Writer << "KeyTwo" + << "ValueTwo"; + CbObject Obj = Writer.Save(); + + StringBuilder<128> Sb; + const char* JsonText = Obj.ToJson(Sb).Data(); + + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + + const std::string ValueOne = Json["KeyOne"].string_value(); + const std::string ValueTwo = Json["KeyTwo"].string_value(); + + CHECK(JsonError.empty()); + CHECK(ValueOne == "ValueOne"); + CHECK(ValueTwo == "ValueTwo"); + } + + SUBCASE("number") + { + const float ExpectedFloatValue = 21.21f; + const double ExpectedDoubleValue = 42.42; + + CbObjectWriter Writer; + Writer << "Float" << ExpectedFloatValue; + Writer << "Double" << ExpectedDoubleValue; + + CbObject Obj = Writer.Save(); + + StringBuilder<128> Sb; + const char* JsonText = Obj.ToJson(Sb).Data(); + + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + + const float FloatValue = float(Json["Float"].number_value()); + const double DoubleValue = Json["Double"].number_value(); + + CHECK(JsonError.empty()); + CHECK(FloatValue == Approx(ExpectedFloatValue)); + CHECK(DoubleValue == Approx(ExpectedDoubleValue)); + } + + SUBCASE("number.nan") + { + const float FloatNan = std::numeric_limits<float>::quiet_NaN(); + const double DoubleNan = std::numeric_limits<double>::quiet_NaN(); + + CbObjectWriter Writer; + Writer << "FloatNan" << FloatNan; + Writer << "DoubleNan" << DoubleNan; + + CbObject Obj = Writer.Save(); + + StringBuilder<128> Sb; + const char* JsonText = Obj.ToJson(Sb).Data(); + + std::string JsonError; + json11::Json Json = json11::Json::parse(JsonText, JsonError); + + const double FloatValue = Json["FloatNan"].number_value(); + const double DoubleValue = Json["DoubleNan"].number_value(); + + CHECK(JsonError.empty()); + CHECK(FloatValue == 0); + CHECK(DoubleValue == 0); + } +} + +TEST_CASE("uson.datetime") +{ + using namespace std::literals; + + { + DateTime D1600(1601, 1, 1); + CHECK_EQ(D1600.GetYear(), 1601); + CHECK_EQ(D1600.GetMonth(), 1); + CHECK_EQ(D1600.GetDay(), 1); + CHECK_EQ(D1600.GetHour(), 0); + CHECK_EQ(D1600.GetMinute(), 0); + CHECK_EQ(D1600.GetSecond(), 0); + + CHECK_EQ(D1600.ToIso8601(), "1601-01-01T00:00:00.000Z"sv); + } + + { + DateTime D72(1972, 2, 23, 17, 30, 10); + CHECK_EQ(D72.GetYear(), 1972); + CHECK_EQ(D72.GetMonth(), 2); + CHECK_EQ(D72.GetDay(), 23); + CHECK_EQ(D72.GetHour(), 17); + CHECK_EQ(D72.GetMinute(), 30); + CHECK_EQ(D72.GetSecond(), 10); + } +} + +TEST_CASE("json.uson") +{ + using namespace std::literals; + using namespace json11; + + SUBCASE("empty") + { + CbFieldIterator It = LoadCompactBinaryFromJson(""sv); + CHECK(It.HasValue() == false); + } + + SUBCASE("object") + { + const Json JsonObject = Json::object{{"Null", nullptr}, + {"String", "Value1"}, + {"Bool", true}, + {"Number", 46.2}, + {"Array", Json::array{1, 2, 3}}, + {"Object", + Json::object{ + {"String", "Value2"}, + }}}; + + CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["Null"].IsNull()); + CHECK(Cb["String"].AsString() == "Value1"sv); + CHECK(Cb["Bool"].AsBool()); + CHECK(Cb["Number"].AsDouble() == 46.2); + CHECK(Cb["Object"].IsObject()); + CbObjectView Object = Cb["Object"].AsObjectView(); + CHECK(Object["String"].AsString() == "Value2"sv); + } + + SUBCASE("array") + { + const Json JsonArray = Json::array{42, 43, 44}; + CbArray Cb = LoadCompactBinaryFromJson(JsonArray.dump()).AsArray(); + + auto It = Cb.CreateIterator(); + CHECK((*It).AsDouble() == 42); + It++; + CHECK((*It).AsDouble() == 43); + It++; + CHECK((*It).AsDouble() == 44); + } + + SUBCASE("objectid") + { + const Oid& Id = Oid::NewOid(); + + StringBuilder<64> Sb; + Id.ToString(Sb); + + Json JsonObject = Json::object{{"value", Sb.ToString()}}; + CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsObjectId()); + CHECK(Cb["value"sv].AsObjectId() == Id); + + Sb.Reset(); + Sb << "0x"; + Id.ToString(Sb); + + JsonObject = Json::object{{"value", Sb.ToString()}}; + Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsObjectId()); + CHECK(Cb["value"sv].AsObjectId() == Id); + } + + SUBCASE("iohash") + { + const uint8_t Data[] = { + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + }; + + const IoHash Hash = IoHash::HashBuffer(Data, sizeof(Data)); + + Json JsonObject = Json::object{{"value", Hash.ToHexString()}}; + CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsHash()); + CHECK(Cb["value"sv].AsHash() == Hash); + + JsonObject = Json::object{{"value", "0x" + Hash.ToHexString()}}; + Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject(); + + CHECK(Cb["value"sv].IsHash()); + CHECK(Cb["value"sv].AsHash() == Hash); + } +} + +#endif + +} // namespace zen diff --git a/src/zencore/compactbinarybuilder.cpp b/src/zencore/compactbinarybuilder.cpp new file mode 100644 index 000000000..d4ccd434d --- /dev/null +++ b/src/zencore/compactbinarybuilder.cpp @@ -0,0 +1,1545 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/compactbinarybuilder.h" + +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/endian.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#define _USE_MATH_DEFINES +#include <math.h> + +namespace zen { + +template<typename T> +uint64_t +AddUninitialized(std::vector<T>& Vector, uint64_t Count) +{ + const uint64_t Offset = Vector.size(); + Vector.resize(Offset + Count); + return Offset; +} + +template<typename T> +uint64_t +Append(std::vector<T>& Vector, const T* Data, uint64_t Count) +{ + const uint64_t Offset = Vector.size(); + Vector.resize(Offset + Count); + + memcpy(Vector.data() + Offset, Data, sizeof(T) * Count); + + return Offset; +} + +////////////////////////////////////////////////////////////////////////// + +enum class CbWriter::StateFlags : uint8_t +{ + None = 0, + /** Whether a name has been written for the current field. */ + Name = 1 << 0, + /** Whether this state is in the process of writing a field. */ + Field = 1 << 1, + /** Whether this state is for array fields. */ + Array = 1 << 2, + /** Whether this state is for object fields. */ + Object = 1 << 3, +}; + +ENUM_CLASS_FLAGS(CbWriter::StateFlags); + +/** Whether the field type can be used in a uniform array or uniform object. */ +static constexpr bool +IsUniformType(const CbFieldType Type) +{ + if (CbFieldTypeOps::HasFieldName(Type)) + { + return true; + } + + switch (Type) + { + case CbFieldType::None: + case CbFieldType::Null: + case CbFieldType::BoolFalse: + case CbFieldType::BoolTrue: + return false; + default: + return true; + } +} + +/** Append the payload from the compact binary value to the array and return its type. */ +static inline CbFieldType +AppendCompactBinary(const CbFieldView& Value, std::vector<uint8_t>& OutData) +{ + struct FCopy : public CbFieldView + { + using CbFieldView::GetPayloadView; + using CbFieldView::GetType; + }; + const FCopy& ValueCopy = static_cast<const FCopy&>(Value); + const MemoryView SourceView = ValueCopy.GetPayloadView(); + const uint64_t TargetOffset = OutData.size(); + OutData.resize(TargetOffset + SourceView.GetSize()); + memcpy(OutData.data() + TargetOffset, SourceView.GetData(), SourceView.GetSize()); + return CbFieldTypeOps::GetType(ValueCopy.GetType()); +} + +CbWriter::CbWriter() +{ + States.emplace_back(); +} + +CbWriter::CbWriter(const int64_t InitialSize) : CbWriter() +{ + Data.reserve(InitialSize); +} + +CbWriter::~CbWriter() +{ +} + +void +CbWriter::Reset() +{ + Data.resize(0); + States.resize(0); + States.emplace_back(); +} + +CbFieldIterator +CbWriter::Save() +{ + const uint64_t Size = GetSaveSize(); + UniqueBuffer Buffer = UniqueBuffer::Alloc(Size); + const CbFieldViewIterator Output = Save(Buffer); + + SharedBuffer SharedBuf(std::move(Buffer)); + SharedBuf.MakeImmutable(); + + return CbFieldIterator::MakeRangeView(Output, SharedBuf); +} + +CbFieldViewIterator +CbWriter::Save(const MutableMemoryView Buffer) +{ + ZEN_ASSERT(States.size() == 1 && States.back().Flags == StateFlags::None); + // TEXT("It is invalid to save while there are incomplete write operations.")); + ZEN_ASSERT(Data.size() > 0); // TEXT("It is invalid to save when nothing has been written.")); + ZEN_ASSERT(Buffer.GetSize() == Data.size()); + // TEXT("Buffer is %" UINT64_FMT " bytes but %" INT64_FMT " is required."), + // Buffer.GetSize(), + // Data.Num()); + memcpy(Buffer.GetData(), Data.data(), Data.size()); + return CbFieldViewIterator::MakeRange(Buffer); +} + +void +CbWriter::Save(BinaryWriter& Writer) +{ + ZEN_ASSERT(States.size() == 1 && States.back().Flags == StateFlags::None); + // TEXT("It is invalid to save while there are incomplete write operations.")); + ZEN_ASSERT(Data.size() > 0); // TEXT("It is invalid to save when nothing has been written.")); + Writer.Write(Data.data(), Data.size()); +} + +uint64_t +CbWriter::GetSaveSize() const +{ + return Data.size(); +} + +void +CbWriter::BeginField() +{ + WriterState& State = States.back(); + if ((State.Flags & StateFlags::Field) == StateFlags::None) + { + State.Flags |= StateFlags::Field; + State.Offset = Data.size(); + Data.push_back(0); + } + else + { + ZEN_ASSERT((State.Flags & StateFlags::Name) == StateFlags::Name); + // TEXT("A new field cannot be written until the previous field '%.*hs' is finished."), + // GetActiveName().Len(), + // GetActiveName().GetData()); + } +} + +void +CbWriter::EndField(CbFieldType Type) +{ + WriterState& State = States.back(); + + if ((State.Flags & StateFlags::Name) == StateFlags::Name) + { + Type |= CbFieldType::HasFieldName; + } + else + { + ZEN_ASSERT((State.Flags & StateFlags::Object) == StateFlags::None); + // TEXT("It is invalid to write an object field without a unique non-empty name.")); + } + + if (State.Count == 0) + { + State.UniformType = Type; + } + else if (State.UniformType != Type) + { + State.UniformType = CbFieldType::None; + } + + State.Flags &= ~(StateFlags::Name | StateFlags::Field); + ++State.Count; + Data[State.Offset] = uint8_t(Type); +} + +ZEN_NOINLINE +CbWriter& +CbWriter::SetName(const std::string_view Name) +{ + WriterState& State = States.back(); + ZEN_ASSERT((State.Flags & StateFlags::Array) != StateFlags::Array); + // TEXT("It is invalid to write a name for an array field. Name '%.*hs'"), + // Name.Len(), + // Name.GetData()); + ZEN_ASSERT(!Name.empty()); + // TEXT("%s"), + //(State.Flags & EStateFlags::Object) == EStateFlags::Object + // ? TEXT("It is invalid to write an empty name for an object field. Specify a unique non-empty name.") + // : TEXT("It is invalid to write an empty name for a top-level field. Specify a name or avoid this call.")); + ZEN_ASSERT((State.Flags & (StateFlags::Name | StateFlags::Field)) == StateFlags::None); + // TEXT("A new field '%.*hs' cannot be written until the previous field '%.*hs' is finished."), + // Name.Len(), + // Name.GetData(), + // GetActiveName().Len(), + // GetActiveName().GetData()); + + BeginField(); + State.Flags |= StateFlags::Name; + const uint32_t NameLenByteCount = MeasureVarUInt(uint32_t(Name.size())); + const int64_t NameLenOffset = Data.size(); + Data.resize(NameLenOffset + NameLenByteCount); + + WriteVarUInt(uint64_t(Name.size()), Data.data() + NameLenOffset); + + const uint8_t* NamePtr = reinterpret_cast<const uint8_t*>(Name.data()); + Data.insert(Data.end(), NamePtr, NamePtr + Name.size()); + return *this; +} + +void +CbWriter::SetNameOrAddString(const std::string_view NameOrValue) +{ + // A name is only written if it would begin a new field inside of an object. + if ((States.back().Flags & (StateFlags::Name | StateFlags::Field | StateFlags::Object)) == StateFlags::Object) + { + SetName(NameOrValue); + } + else + { + AddString(NameOrValue); + } +} + +std::string_view +CbWriter::GetActiveName() const +{ + const WriterState& State = States.back(); + if ((State.Flags & StateFlags::Name) == StateFlags::Name) + { + const uint8_t* const EncodedName = Data.data() + State.Offset + sizeof(CbFieldType); + uint32_t NameLenByteCount; + const uint64_t NameLen = ReadVarUInt(EncodedName, NameLenByteCount); + const size_t ClampedNameLen = std::clamp<uint64_t>(NameLen, 0, ~uint64_t(0)); + return std::string_view(reinterpret_cast<const char*>(EncodedName + NameLenByteCount), ClampedNameLen); + } + return std::string_view(); +} + +void +CbWriter::MakeFieldsUniform(const int64_t FieldBeginOffset, const int64_t FieldEndOffset) +{ + MutableMemoryView SourceView(Data.data() + FieldBeginOffset, uint64_t(FieldEndOffset - FieldBeginOffset)); + MutableMemoryView TargetView = SourceView; + TargetView.RightChopInline(sizeof(CbFieldType)); + + while (!SourceView.IsEmpty()) + { + const uint64_t FieldSize = MeasureCompactBinary(SourceView) - sizeof(CbFieldType); + SourceView.RightChopInline(sizeof(CbFieldType)); + if (TargetView.GetData() != SourceView.GetData()) + { + memmove(TargetView.GetData(), SourceView.GetData(), FieldSize); + } + SourceView.RightChopInline(FieldSize); + TargetView.RightChopInline(FieldSize); + } + + if (!TargetView.IsEmpty()) + { + const auto EraseBegin = Data.begin() + (FieldEndOffset - TargetView.GetSize()); + const auto EraseEnd = EraseBegin + TargetView.GetSize(); + + Data.erase(EraseBegin, EraseEnd); + } +} + +void +CbWriter::AddField(const CbFieldView& Value) +{ + ZEN_ASSERT(Value.HasValue()); // , TEXT("It is invalid to write a field with no value.")); + BeginField(); + EndField(AppendCompactBinary(Value, Data)); +} + +void +CbWriter::AddField(const CbField& Value) +{ + AddField(CbFieldView(Value)); +} + +void +CbWriter::BeginObject() +{ + BeginField(); + States.push_back(WriterState()); + States.back().Flags |= StateFlags::Object; +} + +void +CbWriter::EndObject() +{ + ZEN_ASSERT(States.size() > 1 && (States.back().Flags & StateFlags::Object) == StateFlags::Object); + + // TEXT("It is invalid to end an object when an object is not at the top of the stack.")); + ZEN_ASSERT((States.back().Flags & StateFlags::Field) == StateFlags::None); + // TEXT("It is invalid to end an object until the previous field is finished.")); + + const bool bUniform = IsUniformType(States.back().UniformType); + const uint64_t Count = States.back().Count; + States.pop_back(); + + // Calculate the offset of the payload. + const WriterState& State = States.back(); + int64_t PayloadOffset = State.Offset + 1; + if ((State.Flags & StateFlags::Name) == StateFlags::Name) + { + uint32_t NameLenByteCount; + const uint64_t NameLen = ReadVarUInt(Data.data() + PayloadOffset, NameLenByteCount); + PayloadOffset += NameLen + NameLenByteCount; + } + + // Remove redundant field types for uniform objects. + if (bUniform && Count > 1) + { + MakeFieldsUniform(PayloadOffset, Data.size()); + } + + // Insert the object size. + const uint64_t Size = uint64_t(Data.size() - PayloadOffset); + const uint32_t SizeByteCount = MeasureVarUInt(Size); + Data.insert(Data.begin() + PayloadOffset, SizeByteCount, 0); + WriteVarUInt(Size, Data.data() + PayloadOffset); + + EndField(bUniform ? CbFieldType::UniformObject : CbFieldType::Object); +} + +void +CbWriter::AddObject(const CbObjectView& Value) +{ + BeginField(); + EndField(AppendCompactBinary(Value.AsFieldView(), Data)); +} + +void +CbWriter::AddObject(const CbObject& Value) +{ + AddObject(CbObjectView(Value)); +} + +ZEN_NOINLINE +void +CbWriter::BeginArray() +{ + BeginField(); + States.push_back(WriterState()); + States.back().Flags |= StateFlags::Array; +} + +void +CbWriter::EndArray() +{ + ZEN_ASSERT(States.size() > 1 && (States.back().Flags & StateFlags::Array) == StateFlags::Array); + // TEXT("Invalid attempt to end an array when an array is not at the top of the stack.")); + ZEN_ASSERT((States.back().Flags & StateFlags::Field) == StateFlags::None); + // TEXT("It is invalid to end an array until the previous field is finished.")); + const bool bUniform = IsUniformType(States.back().UniformType); + const uint64_t Count = States.back().Count; + States.pop_back(); + + // Calculate the offset of the payload. + const WriterState& State = States.back(); + int64_t PayloadOffset = State.Offset + 1; + if ((State.Flags & StateFlags::Name) == StateFlags::Name) + { + uint32_t NameLenByteCount; + const uint64_t NameLen = ReadVarUInt(Data.data() + PayloadOffset, NameLenByteCount); + PayloadOffset += NameLen + NameLenByteCount; + } + + // Remove redundant field types for uniform arrays. + if (bUniform && Count > 1) + { + MakeFieldsUniform(PayloadOffset, Data.size()); + } + + // Insert the array size and field count. + const uint32_t CountByteCount = MeasureVarUInt(Count); + const uint64_t Size = uint64_t(Data.size() - PayloadOffset) + CountByteCount; + const uint32_t SizeByteCount = MeasureVarUInt(Size); + Data.insert(Data.begin() + PayloadOffset, SizeByteCount + CountByteCount, 0); + WriteVarUInt(Size, Data.data() + PayloadOffset); + WriteVarUInt(Count, Data.data() + PayloadOffset + SizeByteCount); + + EndField(bUniform ? CbFieldType::UniformArray : CbFieldType::Array); +} + +void +CbWriter::AddArray(const CbArrayView& Value) +{ + BeginField(); + EndField(AppendCompactBinary(Value.AsFieldView(), Data)); +} + +void +CbWriter::AddArray(const CbArray& Value) +{ + AddArray(CbArrayView(Value)); +} + +void +CbWriter::AddNull() +{ + BeginField(); + EndField(CbFieldType::Null); +} + +void +CbWriter::AddBinary(const void* const Value, const uint64_t Size) +{ + const size_t SizeByteCount = MeasureVarUInt(Size); + Data.reserve(Data.size() + 1 + SizeByteCount + Size); + BeginField(); + const size_t SizeOffset = Data.size(); + Data.resize(Data.size() + SizeByteCount); + WriteVarUInt(Size, Data.data() + SizeOffset); + Data.insert(Data.end(), static_cast<const uint8_t*>(Value), static_cast<const uint8_t*>(Value) + Size); + EndField(CbFieldType::Binary); +} + +void +CbWriter::AddBinary(IoBuffer Buffer) +{ + AddBinary(Buffer.Data(), Buffer.Size()); +} + +void +CbWriter::AddBinary(SharedBuffer Buffer) +{ + AddBinary(Buffer.GetData(), Buffer.GetSize()); +} + +void +CbWriter::AddBinary(const CompositeBuffer& Buffer) +{ + AddBinary(Buffer.Flatten()); +} + +void +CbWriter::AddString(const std::string_view Value) +{ + BeginField(); + const uint64_t Size = uint64_t(Value.size()); + const uint32_t SizeByteCount = MeasureVarUInt(Size); + const int64_t Offset = Data.size(); + + Data.resize(Offset + SizeByteCount + Size); + + uint8_t* StringData = Data.data() + Offset; + WriteVarUInt(Size, StringData); + StringData += SizeByteCount; + if (Size > 0) + { + memcpy(StringData, Value.data(), Value.size() * sizeof(char)); + } + EndField(CbFieldType::String); +} + +void +CbWriter::AddString(const std::wstring_view Value) +{ + BeginField(); + ExtendableStringBuilder<128> Utf8; + WideToUtf8(Value, Utf8); + + const uint32_t Size = uint32_t(Utf8.Size()); + const uint32_t SizeByteCount = MeasureVarUInt(Size); + const int64_t Offset = Data.size(); + Data.resize(Offset + SizeByteCount + Size); + uint8_t* StringData = Data.data() + Offset; + WriteVarUInt(Size, StringData); + StringData += SizeByteCount; + if (Size > 0) + { + memcpy(reinterpret_cast<char*>(StringData), Utf8.Data(), Utf8.Size()); + } + EndField(CbFieldType::String); +} + +ZEN_NOINLINE +void +CbWriter::AddInteger(const int32_t Value) +{ + if (Value >= 0) + { + return AddInteger(uint32_t(Value)); + } + BeginField(); + const uint32_t Magnitude = ~uint32_t(Value); + const uint32_t MagnitudeByteCount = MeasureVarUInt(Magnitude); + const int64_t Offset = Data.size(); + Data.resize(Offset + MagnitudeByteCount); + WriteVarUInt(Magnitude, Data.data() + Offset); + EndField(CbFieldType::IntegerNegative); +} + +void +CbWriter::AddInteger(const int64_t Value) +{ + if (Value >= 0) + { + return AddInteger(uint64_t(Value)); + } + BeginField(); + const uint64_t Magnitude = ~uint64_t(Value); + const uint32_t MagnitudeByteCount = MeasureVarUInt(Magnitude); + const uint64_t Offset = AddUninitialized(Data, MagnitudeByteCount); + WriteVarUInt(Magnitude, Data.data() + Offset); + EndField(CbFieldType::IntegerNegative); +} + +ZEN_NOINLINE +void +CbWriter::AddInteger(const uint32_t Value) +{ + BeginField(); + const uint32_t ValueByteCount = MeasureVarUInt(Value); + const uint64_t Offset = AddUninitialized(Data, ValueByteCount); + WriteVarUInt(Value, Data.data() + Offset); + EndField(CbFieldType::IntegerPositive); +} + +ZEN_NOINLINE +void +CbWriter::AddInteger(const uint64_t Value) +{ + BeginField(); + const uint32_t ValueByteCount = MeasureVarUInt(Value); + const uint64_t Offset = AddUninitialized(Data, ValueByteCount); + WriteVarUInt(Value, Data.data() + Offset); + EndField(CbFieldType::IntegerPositive); +} + +ZEN_NOINLINE +void +CbWriter::AddFloat(const float Value) +{ + BeginField(); + const uint32_t RawValue = FromNetworkOrder(reinterpret_cast<const uint32_t&>(Value)); + Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint32_t)); + EndField(CbFieldType::Float32); +} + +ZEN_NOINLINE +void +CbWriter::AddFloat(const double Value) +{ + const float Value32 = float(Value); + if (Value == double(Value32)) + { + return AddFloat(Value32); + } + BeginField(); + const uint64_t RawValue = FromNetworkOrder(reinterpret_cast<const uint64_t&>(Value)); + Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint64_t)); + EndField(CbFieldType::Float64); +} + +ZEN_NOINLINE +void +CbWriter::AddBool(const bool bValue) +{ + BeginField(); + EndField(bValue ? CbFieldType::BoolTrue : CbFieldType::BoolFalse); +} + +ZEN_NOINLINE +void +CbWriter::AddObjectAttachment(const IoHash& Value) +{ + BeginField(); + Append(Data, Value.Hash, sizeof Value.Hash); + EndField(CbFieldType::ObjectAttachment); +} + +ZEN_NOINLINE +void +CbWriter::AddBinaryAttachment(const IoHash& Value) +{ + BeginField(); + Append(Data, Value.Hash, sizeof Value.Hash); + EndField(CbFieldType::BinaryAttachment); +} + +ZEN_NOINLINE +void +CbWriter::AddAttachment(const CbAttachment& Attachment) +{ + BeginField(); + const IoHash& Value = Attachment.GetHash(); + Append(Data, Value.Hash, sizeof Value.Hash); + EndField(CbFieldType::BinaryAttachment); +} + +ZEN_NOINLINE +void +CbWriter::AddHash(const IoHash& Value) +{ + BeginField(); + Append(Data, Value.Hash, sizeof Value.Hash); + EndField(CbFieldType::Hash); +} + +void +CbWriter::AddUuid(const Guid& Value) +{ + const auto AppendSwappedBytes = [this](uint32_t In) { + In = FromNetworkOrder(In); + Append(Data, reinterpret_cast<const uint8_t*>(&In), sizeof In); + }; + BeginField(); + AppendSwappedBytes(Value.A); + AppendSwappedBytes(Value.B); + AppendSwappedBytes(Value.C); + AppendSwappedBytes(Value.D); + EndField(CbFieldType::Uuid); +} + +void +CbWriter::AddObjectId(const Oid& Value) +{ + BeginField(); + Append(Data, reinterpret_cast<const uint8_t*>(&Value.OidBits), sizeof Value.OidBits); + EndField(CbFieldType::ObjectId); +} + +void +CbWriter::AddDateTimeTicks(const int64_t Ticks) +{ + BeginField(); + const uint64_t RawValue = FromNetworkOrder(uint64_t(Ticks)); + Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint64_t)); + EndField(CbFieldType::DateTime); +} + +void +CbWriter::AddDateTime(const DateTime Value) +{ + AddDateTimeTicks(Value.GetTicks()); +} + +void +CbWriter::AddTimeSpanTicks(const int64_t Ticks) +{ + BeginField(); + const uint64_t RawValue = FromNetworkOrder(uint64_t(Ticks)); + Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint64_t)); + EndField(CbFieldType::TimeSpan); +} + +void +CbWriter::AddTimeSpan(const TimeSpan Value) +{ + AddTimeSpanTicks(Value.GetTicks()); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +CbWriter& +operator<<(CbWriter& Writer, const DateTime Value) +{ + Writer.AddDateTime(Value); + return Writer; +} + +CbWriter& +operator<<(CbWriter& Writer, const TimeSpan Value) +{ + Writer.AddTimeSpan(Value); + return Writer; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS +void +usonbuilder_forcelink() +{ +} + +// doctest::String +// toString(const DateTime&) +// { +// // TODO:implement +// return ""; +// } + +// doctest::String +// toString(const TimeSpan&) +// { +// // TODO:implement +// return ""; +// } + +TEST_CASE("usonbuilder.object") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("EmptyObject") + { + Writer.BeginObject(); + Writer.EndObject(); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsObject() == true); + CHECK(Field.AsObjectView().CreateViewIterator().HasValue() == false); + } + + SUBCASE("NamedEmptyObject") + { + Writer.SetName("Object"sv); + Writer.BeginObject(); + Writer.EndObject(); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsObject() == true); + CHECK(Field.AsObjectView().CreateViewIterator().HasValue() == false); + } + + SUBCASE("BasicObject") + { + Writer.BeginObject(); + Writer.SetName("Integer"sv).AddInteger(0); + Writer.SetName("Float"sv).AddFloat(0.0f); + Writer.EndObject(); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsObject() == true); + + CbObjectView Object = Field.AsObjectView(); + CHECK(Object["Integer"sv].IsInteger() == true); + CHECK(Object["Float"sv].IsFloat() == true); + } + + SUBCASE("UniformObject") + { + Writer.BeginObject(); + Writer.SetName("Field1"sv).AddInteger(0); + Writer.SetName("Field2"sv).AddInteger(1); + Writer.EndObject(); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsObject() == true); + + CbObjectView Object = Field.AsObjectView(); + CHECK(Object["Field1"sv].IsInteger() == true); + CHECK(Object["Field2"sv].IsInteger() == true); + } +} + +TEST_CASE("usonbuilder.array") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("EmptyArray") + { + Writer.BeginArray(); + Writer.EndArray(); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsArray() == true); + CHECK(Field.AsArrayView().Num() == 0); + } + + SUBCASE("NamedEmptyArray") + { + Writer.SetName("Array"sv); + Writer.BeginArray(); + Writer.EndArray(); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsArray() == true); + CHECK(Field.AsArrayView().Num() == 0); + } + + SUBCASE("BasicArray") + { + Writer.BeginArray(); + Writer.AddInteger(0); + Writer.AddFloat(0.0f); + Writer.EndArray(); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsArray() == true); + CbFieldViewIterator Iterator = Field.AsArrayView().CreateViewIterator(); + CHECK(Iterator.IsInteger() == true); + ++Iterator; + CHECK(Iterator.IsFloat() == true); + ++Iterator; + CHECK(Iterator.HasValue() == false); + } + + SUBCASE("UniformArray") + { + Writer.BeginArray(); + Writer.AddInteger(0); + Writer.AddInteger(1); + Writer.EndArray(); + + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.IsArray() == true); + CbFieldViewIterator Iterator = Field.AsArrayView().CreateViewIterator(); + CHECK(Iterator.IsInteger() == true); + ++Iterator; + CHECK(Iterator.IsInteger() == true); + ++Iterator; + CHECK(Iterator.HasValue() == false); + } +} + +TEST_CASE("usonbuilder.null") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("Null") + { + Writer.AddNull(); + CbField Field = Writer.Save(); + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.HasName() == false); + CHECK(Field.IsNull() == true); + } + + SUBCASE("NullWithName") + { + Writer.SetName("Null"sv); + Writer.AddNull(); + CbField Field = Writer.Save(); + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + CHECK(Field.HasName() == true); + CHECK(Field.GetName().compare("Null"sv) == 0); + CHECK(Field.IsNull() == true); + } + + SUBCASE("Null Array/Object Uniformity") + { + Writer.BeginArray(); + Writer.AddNull(); + Writer.AddNull(); + Writer.AddNull(); + Writer.EndArray(); + + Writer.BeginObject(); + Writer.SetName("N1"sv).AddNull(); + Writer.SetName("N2"sv).AddNull(); + Writer.SetName("N3"sv).AddNull(); + Writer.EndObject(); + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + } + + SUBCASE("Null with Save(Buffer)") + { + constexpr int NullCount = 3; + for (int Index = 0; Index < NullCount; ++Index) + { + Writer.AddNull(); + } + uint8_t Buffer[NullCount]{}; + MutableMemoryView BufferView(Buffer, sizeof Buffer); + CbFieldViewIterator Fields = Writer.Save(BufferView); + + CHECK(ValidateCompactBinaryRange(BufferView, CbValidateMode::All) == CbValidateError::None); + + for (int Index = 0; Index < NullCount; ++Index) + { + CHECK(Fields.IsNull() == true); + ++Fields; + } + CHECK(Fields.HasValue() == false); + } +} + +TEST_CASE("usonbuilder.binary") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; +} + +TEST_CASE("usonbuilder.string") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("Empty Strings") + { + Writer.AddString(std::string_view()); + Writer.AddString(std::wstring_view()); + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + for (CbFieldView Field : Fields) + { + CHECK(Field.HasName() == false); + CHECK(Field.IsString() == true); + CHECK(Field.AsString().empty() == true); + } + } + + SUBCASE("Test Basic Strings") + { + Writer.SetName("String"sv).AddString("Value"sv); + Writer.SetName("String"sv).AddString(L"Value"sv); + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + for (CbFieldView Field : Fields) + { + CHECK(Field.GetName().compare("String"sv) == 0); + CHECK(Field.HasName() == true); + CHECK(Field.IsString() == true); + CHECK(Field.AsString().compare("Value"sv) == 0); + } + } + + SUBCASE("Long Strings") + { + constexpr int DotCount = 256; + StringBuilder<DotCount + 1> Dots; + for (int Index = 0; Index < DotCount; ++Index) + { + Dots.Append('.'); + } + Writer.AddString(Dots); + Writer.AddString(std::wstring().append(256, L'.')); + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + for (CbFieldView Field : Fields) + { + CHECK((Field.AsString() == std::string_view(Dots))); + } + } + + SUBCASE("Non-ASCII String") + { +# if ZEN_SIZEOF_WCHAR_T == 2 + wchar_t Value[2] = {0xd83d, 0xde00}; +# else + wchar_t Value[1] = {0x1f600}; +# endif + + Writer.AddString("\xf0\x9f\x98\x80"sv); + Writer.AddString(std::wstring_view(Value, ZEN_ARRAY_COUNT(Value))); + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + for (CbFieldView Field : Fields) + { + CHECK((Field.AsString() == "\xf0\x9f\x98\x80"sv)); + } + } +} + +TEST_CASE("usonbuilder.integer") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + auto TestInt32 = [&Writer](int32_t Value) { + Writer.Reset(); + Writer.AddInteger(Value); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + CHECK(Field.AsInt32() == Value); + CHECK(Field.HasError() == false); + }; + + auto TestUInt32 = [&Writer](uint32_t Value) { + Writer.Reset(); + Writer.AddInteger(Value); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + CHECK(Field.AsUInt32() == Value); + CHECK(Field.HasError() == false); + }; + + auto TestInt64 = [&Writer](int64_t Value) { + Writer.Reset(); + Writer.AddInteger(Value); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + CHECK(Field.AsInt64() == Value); + CHECK(Field.HasError() == false); + }; + + auto TestUInt64 = [&Writer](uint64_t Value) { + Writer.Reset(); + Writer.AddInteger(Value); + CbField Field = Writer.Save(); + + CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + CHECK(Field.AsUInt64() == Value); + CHECK(Field.HasError() == false); + }; + + TestUInt32(uint32_t(0x00)); + TestUInt32(uint32_t(0x7f)); + TestUInt32(uint32_t(0x80)); + TestUInt32(uint32_t(0xff)); + TestUInt32(uint32_t(0x0100)); + TestUInt32(uint32_t(0x7fff)); + TestUInt32(uint32_t(0x8000)); + TestUInt32(uint32_t(0xffff)); + TestUInt32(uint32_t(0x0001'0000)); + TestUInt32(uint32_t(0x7fff'ffff)); + TestUInt32(uint32_t(0x8000'0000)); + TestUInt32(uint32_t(0xffff'ffff)); + + TestUInt64(uint64_t(0x0000'0001'0000'0000)); + TestUInt64(uint64_t(0x7fff'ffff'ffff'ffff)); + TestUInt64(uint64_t(0x8000'0000'0000'0000)); + TestUInt64(uint64_t(0xffff'ffff'ffff'ffff)); + + TestInt32(int32_t(0x01)); + TestInt32(int32_t(0x80)); + TestInt32(int32_t(0x81)); + TestInt32(int32_t(0x8000)); + TestInt32(int32_t(0x8001)); + TestInt32(int32_t(0x7fff'ffff)); + TestInt32(int32_t(0x8000'0000)); + TestInt32(int32_t(0x8000'0001)); + + TestInt64(int64_t(0x0000'0001'0000'0000)); + TestInt64(int64_t(0x8000'0000'0000'0000)); + TestInt64(int64_t(0x7fff'ffff'ffff'ffff)); + TestInt64(int64_t(0x8000'0000'0000'0001)); + TestInt64(int64_t(0xffff'ffff'ffff'ffff)); +} + +TEST_CASE("usonbuilder.float") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("Float32") + { + constexpr float Values[] = { + 0.0f, + 1.0f, + -1.0f, + 3.14159265358979323846f, // PI + 3.402823466e+38f, // FLT_MAX + 1.175494351e-38f // FLT_MIN + }; + + for (float Value : Values) + { + Writer.AddFloat(Value); + } + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + const float* CheckValue = Values; + for (CbFieldView Field : Fields) + { + CHECK(Field.AsFloat() == *CheckValue++); + CHECK(Field.HasError() == false); + } + } + + SUBCASE("Float64") + { + constexpr double Values[] = { + 0.0f, + 1.0f, + -1.0f, + 3.14159265358979323846, // PI + 1.9999998807907104, + 1.9999999403953552, + 3.4028234663852886e38, + 6.8056469327705771e38, + 2.2250738585072014e-308, // DBL_MIN + 1.7976931348623158e+308 // DBL_MAX + }; + + for (double Value : Values) + { + Writer.AddFloat(Value); + } + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + const double* CheckValue = Values; + for (CbFieldView Field : Fields) + { + CHECK(Field.AsDouble() == *CheckValue++); + CHECK(Field.HasError() == false); + } + } +} + +TEST_CASE("usonbuilder.bool") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("Bool") + { + Writer.AddBool(true); + Writer.AddBool(false); + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + CHECK(Fields.AsBool() == true); + CHECK(Fields.HasError() == false); + ++Fields; + CHECK(Fields.AsBool() == false); + CHECK(Fields.HasError() == false); + ++Fields; + CHECK(Fields.HasValue() == false); + } + + SUBCASE("Bool Array/Object Uniformity") + { + Writer.BeginArray(); + Writer.AddBool(false); + Writer.AddBool(false); + Writer.AddBool(false); + Writer.EndArray(); + + Writer.BeginObject(); + Writer.SetName("B1"sv).AddBool(false); + Writer.SetName("B2"sv).AddBool(false); + Writer.SetName("B3"sv).AddBool(false); + Writer.EndObject(); + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + } +} + +TEST_CASE("usonbuilder.usonattachment") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; +} + +TEST_CASE("usonbuilder.binaryattachment") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; +} + +TEST_CASE("usonbuilder.hash") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; +} + +TEST_CASE("usonbuilder.uuid") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; +} + +TEST_CASE("usonbuilder.datetime") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + const DateTime Values[] = {DateTime(0), DateTime(2020, 5, 13, 15, 10)}; + for (DateTime Value : Values) + { + Writer.AddDateTime(Value); + } + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + const DateTime* CheckValue = Values; + for (CbFieldView Field : Fields) + { + CHECK(Field.AsDateTime() == *CheckValue++); + CHECK(Field.HasError() == false); + } +} + +TEST_CASE("usonbuilder.timespan") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + const TimeSpan Values[] = {TimeSpan(0), TimeSpan(1, 2, 4, 8)}; + for (TimeSpan Value : Values) + { + Writer.AddTimeSpan(Value); + } + + CbFieldIterator Fields = Writer.Save(); + + CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + + const TimeSpan* CheckValue = Values; + for (CbFieldView Field : Fields) + { + CHECK(Field.AsTimeSpan() == *CheckValue++); + CHECK(Field.HasError() == false); + } +} + +TEST_CASE("usonbuilder.complex") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("complex") + { + CbObject Object; + + { + Writer.BeginObject(); + + const uint8_t LocalField[] = {uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName), 1, 'I', 42}; + Writer.AddField("FieldCopy"sv, CbFieldView(LocalField)); + Writer.AddField("FieldRefCopy"sv, CbField(SharedBuffer::Clone(MakeMemoryView(LocalField)))); + + const uint8_t LocalObject[] = {uint8_t(CbFieldType::Object | CbFieldType::HasFieldName), + 1, + 'O', + 7, + uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName), + 1, + 'I', + 42, + uint8_t(CbFieldType::Null | CbFieldType::HasFieldName), + 1, + 'N'}; + Writer.AddObject("ObjectCopy"sv, CbObjectView(LocalObject)); + Writer.AddObject("ObjectRefCopy"sv, CbObject(SharedBuffer::Clone(MakeMemoryView(LocalObject)))); + + const uint8_t LocalArray[] = {uint8_t(CbFieldType::UniformArray | CbFieldType::HasFieldName), + 1, + 'A', + 4, + 2, + uint8_t(CbFieldType::IntegerPositive), + 42, + 21}; + Writer.AddArray("ArrayCopy"sv, CbArrayView(LocalArray)); + Writer.AddArray("ArrayRefCopy"sv, CbArray(SharedBuffer::Clone(MakeMemoryView(LocalArray)))); + + Writer.AddNull("Null"sv); + + Writer.BeginObject("Binary"sv); + { + Writer.AddBinary("Empty"sv, MemoryView()); + Writer.AddBinary("Value"sv, MakeMemoryView("BinaryValue")); + Writer.AddBinary("LargeValue"sv, MakeMemoryView(std::wstring().append(256, L'.'))); + Writer.AddBinary("LargeRefValue"sv, SharedBuffer::Clone(MakeMemoryView(std::wstring().append(256, L'!')))); + } + Writer.EndObject(); + + Writer.BeginObject("Strings"sv); + { + Writer.AddString("AnsiString"sv, "AnsiValue"sv); + Writer.AddString("WideString"sv, std::wstring().append(256, L'.')); + Writer.AddString("EmptyAnsiString"sv, std::string_view()); + Writer.AddString("EmptyWideString"sv, std::wstring_view()); + Writer.AddString("AnsiStringLiteral", "AnsiValue"); + Writer.AddString("WideStringLiteral", L"AnsiValue"); + } + Writer.EndObject(); + + Writer.BeginArray("Integers"sv); + { + Writer.AddInteger(int32_t(-1)); + Writer.AddInteger(int64_t(-1)); + Writer.AddInteger(uint32_t(1)); + Writer.AddInteger(uint64_t(1)); + Writer.AddInteger(std::numeric_limits<int32_t>::min()); + Writer.AddInteger(std::numeric_limits<int32_t>::max()); + Writer.AddInteger(std::numeric_limits<uint32_t>::max()); + Writer.AddInteger(std::numeric_limits<int64_t>::min()); + Writer.AddInteger(std::numeric_limits<int64_t>::max()); + Writer.AddInteger(std::numeric_limits<uint64_t>::max()); + } + Writer.EndArray(); + + Writer.BeginArray("UniformIntegers"sv); + { + Writer.AddInteger(0); + Writer.AddInteger(std::numeric_limits<int32_t>::max()); + Writer.AddInteger(std::numeric_limits<uint32_t>::max()); + Writer.AddInteger(std::numeric_limits<int64_t>::max()); + Writer.AddInteger(std::numeric_limits<uint64_t>::max()); + } + Writer.EndArray(); + + Writer.AddFloat("Float32"sv, 1.0f); + Writer.AddFloat("Float64as32"sv, 2.0); + Writer.AddFloat("Float64"sv, 3.0e100); + + Writer.AddBool("False"sv, false); + Writer.AddBool("True"sv, true); + + Writer.AddObjectAttachment("ObjectAttachment"sv, IoHash()); + Writer.AddBinaryAttachment("BinaryAttachment"sv, IoHash()); + Writer.AddAttachment("Attachment"sv, CbAttachment()); + + Writer.AddHash("Hash"sv, IoHash()); + Writer.AddUuid("Uuid"sv, Guid()); + + Writer.AddDateTimeTicks("DateTimeZero"sv, 0); + Writer.AddDateTime("DateTime2020"sv, DateTime(2020, 5, 13, 15, 10)); + + Writer.AddTimeSpanTicks("TimeSpanZero"sv, 0); + Writer.AddTimeSpan("TimeSpan"sv, TimeSpan(1, 2, 4, 8)); + + Writer.BeginObject("NestedObjects"sv); + { + Writer.BeginObject("Empty"sv); + Writer.EndObject(); + + Writer.BeginObject("Null"sv); + Writer.AddNull("Null"sv); + Writer.EndObject(); + } + Writer.EndObject(); + + Writer.BeginArray("NestedArrays"sv); + { + Writer.BeginArray(); + Writer.EndArray(); + + Writer.BeginArray(); + Writer.AddNull(); + Writer.AddNull(); + Writer.AddNull(); + Writer.EndArray(); + + Writer.BeginArray(); + Writer.AddBool(false); + Writer.AddBool(false); + Writer.AddBool(false); + Writer.EndArray(); + + Writer.BeginArray(); + Writer.AddBool(true); + Writer.AddBool(true); + Writer.AddBool(true); + Writer.EndArray(); + } + Writer.EndArray(); + + Writer.BeginArray("ArrayOfObjects"sv); + { + Writer.BeginObject(); + Writer.EndObject(); + + Writer.BeginObject(); + Writer.AddNull("Null"sv); + Writer.EndObject(); + } + Writer.EndArray(); + + Writer.BeginArray("LargeArray"sv); + for (int Index = 0; Index < 256; ++Index) + { + Writer.AddInteger(Index - 128); + } + Writer.EndArray(); + + Writer.BeginArray("LargeUniformArray"sv); + for (int Index = 0; Index < 256; ++Index) + { + Writer.AddInteger(Index); + } + Writer.EndArray(); + + Writer.BeginArray("NestedUniformArray"sv); + for (int Index = 0; Index < 16; ++Index) + { + Writer.BeginArray(); + for (int Value = 0; Value < 4; ++Value) + { + Writer.AddInteger(Value); + } + Writer.EndArray(); + } + Writer.EndArray(); + + Writer.EndObject(); + Object = Writer.Save().AsObject(); + } + CHECK(ValidateCompactBinary(Object.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + } +} + +TEST_CASE("usonbuilder.stream") +{ + using namespace std::literals; + + FixedCbWriter<256> Writer; + + SUBCASE("basic") + { + CbObject Object; + { + Writer.BeginObject(); + + const uint8_t LocalField[] = {uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName), 1, 'I', 42}; + Writer << "FieldCopy"sv << CbFieldView(LocalField); + + const uint8_t LocalObject[] = {uint8_t(CbFieldType::Object | CbFieldType::HasFieldName), + 1, + 'O', + 7, + uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName), + 1, + 'I', + 42, + uint8_t(CbFieldType::Null | CbFieldType::HasFieldName), + 1, + 'N'}; + Writer << "ObjectCopy"sv << CbObjectView(LocalObject); + + const uint8_t LocalArray[] = {uint8_t(CbFieldType::UniformArray | CbFieldType::HasFieldName), + 1, + 'A', + 4, + 2, + uint8_t(CbFieldType::IntegerPositive), + 42, + 21}; + Writer << "ArrayCopy"sv << CbArrayView(LocalArray); + + Writer << "Null"sv << nullptr; + + Writer << "Strings"sv; + Writer.BeginObject(); + Writer << "AnsiString"sv + << "AnsiValue"sv + << "AnsiStringLiteral"sv + << "AnsiValue" + << "WideString"sv << L"WideValue"sv << "WideStringLiteral"sv << L"WideValue"; + Writer.EndObject(); + + Writer << "Integers"sv; + Writer.BeginArray(); + Writer << int32_t(-1) << int64_t(-1) << uint32_t(1) << uint64_t(1); + Writer.EndArray(); + + Writer << "Float32"sv << 1.0f; + Writer << "Float64"sv << 2.0; + + Writer << "False"sv << false << "True"sv << true; + + Writer << "Attachment"sv << CbAttachment(); + + Writer << "Hash"sv << IoHash(); + Writer << "Uuid"sv << Guid(); + + Writer << "DateTime"sv << DateTime(2020, 5, 13, 15, 10); + Writer << "TimeSpan"sv << TimeSpan(1, 2, 4, 8); + + Writer << "LiteralName" << nullptr; + + Writer.EndObject(); + Object = Writer.Save().AsObject(); + } + + CHECK(ValidateCompactBinary(Object.GetBuffer(), CbValidateMode::All) == CbValidateError::None); + } +} +#endif + +} // namespace zen diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp new file mode 100644 index 000000000..a4fa38a1d --- /dev/null +++ b/src/zencore/compactbinarypackage.cpp @@ -0,0 +1,1350 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/compactbinarypackage.h" +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/endian.h> +#include <zencore/stream.h> +#include <zencore/testing.h> + +namespace zen { + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +CbAttachment::CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash) : CbAttachment(InValue.MakeOwned(), Hash) +{ +} + +CbAttachment::CbAttachment(const SharedBuffer& InValue) : CbAttachment(CompositeBuffer(InValue)) +{ +} + +CbAttachment::CbAttachment(const SharedBuffer& InValue, const IoHash& InHash) : CbAttachment(CompositeBuffer(InValue), InHash) +{ +} + +CbAttachment::CbAttachment(const CompositeBuffer& InValue) +: Hash(InValue.IsNull() ? IoHash::Zero : IoHash::HashBuffer(InValue)) +, Value(InValue) +{ + if (std::get<CompositeBuffer>(Value).IsNull()) + { + Value.emplace<std::nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompositeBuffer&& InValue) +: Hash(InValue.IsNull() ? IoHash::Zero : IoHash::HashBuffer(InValue)) +, Value(std::move(InValue)) + +{ + if (std::get<CompositeBuffer>(Value).IsNull()) + { + Value.emplace<std::nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompositeBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(InValue) +{ + if (std::get<CompositeBuffer>(Value).IsNull()) + { + Value.emplace<std::nullptr_t>(); + } +} + +CbAttachment::CbAttachment(CompressedBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(InValue) +{ + if (std::get<CompressedBuffer>(Value).IsNull()) + { + Value.emplace<std::nullptr_t>(); + } +} + +CbAttachment::CbAttachment(const CbObject& InValue, const IoHash* const InHash) +{ + auto SetValue = [&](const CbObject& ValueToSet) { + if (InHash) + { + Value.emplace<CbObject>(ValueToSet); + Hash = *InHash; + } + else + { + Value.emplace<CbObject>(ValueToSet); + Hash = ValueToSet.GetHash(); + } + }; + + MemoryView View; + if (!InValue.IsOwned() || !InValue.TryGetSerializedView(View)) + { + SetValue(CbObject::Clone(InValue)); + } + else + { + SetValue(InValue); + } +} + +bool +CbAttachment::TryLoad(IoBuffer& InBuffer, BufferAllocator Allocator) +{ + BinaryReader Reader(InBuffer.Data(), InBuffer.Size()); + + return TryLoad(Reader, Allocator); +} + +bool +CbAttachment::TryLoad(CbFieldIterator& Fields) +{ + if (const CbObjectView ObjectView = Fields.AsObjectView(); !Fields.HasError()) + { + // Is a null object or object not prefixed with a precomputed hash value + Value.emplace<CbObject>(CbObject(ObjectView, Fields.GetOuterBuffer())); + Hash = ObjectView.GetHash(); + ++Fields; + } + else if (const IoHash ObjectAttachmentHash = Fields.AsObjectAttachment(); !Fields.HasError()) + { + // Is an object + ++Fields; + const CbObjectView InnerObjectView = Fields.AsObjectView(); + if (Fields.HasError()) + { + return false; + } + Value.emplace<CbObject>(CbObject(InnerObjectView, Fields.GetOuterBuffer())); + Hash = ObjectAttachmentHash; + ++Fields; + } + else if (const IoHash BinaryAttachmentHash = Fields.AsBinaryAttachment(); !Fields.HasError()) + { + // Is an uncompressed binary blob + ++Fields; + MemoryView BinaryView = Fields.AsBinaryView(); + if (Fields.HasError()) + { + return false; + } + Value.emplace<CompositeBuffer>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer())); + Hash = BinaryAttachmentHash; + ++Fields; + } + else if (MemoryView BinaryView = Fields.AsBinaryView(); !Fields.HasError()) + { + if (BinaryView.GetSize() > 0) + { + // Is a compressed binary blob + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = + CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()), RawHash, RawSize).MakeOwned(); + Value.emplace<CompressedBuffer>(Compressed); + Hash = RawHash; + ++Fields; + } + else + { + // Is an uncompressed empty binary blob + Value.emplace<CompositeBuffer>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer())); + Hash = IoHash::HashBuffer(nullptr, 0); + ++Fields; + } + } + else + { + return false; + } + + return true; +} + +static bool +TryLoad_ArchiveFieldIntoAttachment(CbAttachment& TargetAttachment, CbField&& Field, BinaryReader& Reader, BufferAllocator Allocator) +{ + if (const CbObjectView ObjectView = Field.AsObjectView(); !Field.HasError()) + { + // Is a null object or object not prefixed with a precomputed hash value + TargetAttachment = CbAttachment(CbObject(ObjectView, std::move(Field)), ObjectView.GetHash()); + } + else if (const IoHash ObjectAttachmentHash = Field.AsObjectAttachment(); !Field.HasError()) + { + // Is an object + Field = LoadCompactBinary(Reader, Allocator); + if (!Field.IsObject()) + { + return false; + } + TargetAttachment = CbAttachment(std::move(Field).AsObject(), ObjectAttachmentHash); + } + else if (const IoHash BinaryAttachmentHash = Field.AsBinaryAttachment(); !Field.HasError()) + { + // Is an uncompressed binary blob + Field = LoadCompactBinary(Reader, Allocator); + SharedBuffer Buffer = Field.AsBinary(); + if (Field.HasError()) + { + return false; + } + TargetAttachment = CbAttachment(CompositeBuffer(Buffer), BinaryAttachmentHash); + } + else if (SharedBuffer Buffer = Field.AsBinary(); !Field.HasError()) + { + if (Buffer.GetSize() > 0) + { + // Is a compressed binary blob + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(std::move(Buffer), RawHash, RawSize); + TargetAttachment = CbAttachment(Compressed, RawHash); + } + else + { + // Is an uncompressed empty binary blob + TargetAttachment = CbAttachment(CompositeBuffer(Buffer), IoHash::HashBuffer(nullptr, 0)); + } + } + else + { + return false; + } + + return true; +} + +bool +CbAttachment::TryLoad(BinaryReader& Reader, BufferAllocator Allocator) +{ + CbField Field = LoadCompactBinary(Reader, Allocator); + return TryLoad_ArchiveFieldIntoAttachment(*this, std::move(Field), Reader, Allocator); +} + +void +CbAttachment::Save(CbWriter& Writer) const +{ + if (const CbObject* Object = std::get_if<CbObject>(&Value)) + { + if (*Object) + { + Writer.AddObjectAttachment(Hash); + } + Writer.AddObject(*Object); + } + else if (const CompositeBuffer* Binary = std::get_if<CompositeBuffer>(&Value)) + { + if (Binary->GetSize() > 0) + { + Writer.AddBinaryAttachment(Hash); + } + Writer.AddBinary(*Binary); + } + else if (const CompressedBuffer* Compressed = std::get_if<CompressedBuffer>(&Value)) + { + Writer.AddBinary(Compressed->GetCompressed()); + } +} + +void +CbAttachment::Save(BinaryWriter& Writer) const +{ + CbWriter TempWriter; + Save(TempWriter); + TempWriter.Save(Writer); +} + +bool +CbAttachment::IsNull() const +{ + return std::holds_alternative<std::nullptr_t>(Value); +} + +bool +CbAttachment::IsBinary() const +{ + return std::holds_alternative<CompositeBuffer>(Value); +} + +bool +CbAttachment::IsCompressedBinary() const +{ + return std::holds_alternative<CompressedBuffer>(Value); +} + +bool +CbAttachment::IsObject() const +{ + return std::holds_alternative<CbObject>(Value); +} + +IoHash +CbAttachment::GetHash() const +{ + return Hash; +} + +CompositeBuffer +CbAttachment::AsCompositeBinary() const +{ + if (const CompositeBuffer* BinValue = std::get_if<CompositeBuffer>(&Value)) + { + return *BinValue; + } + + return CompositeBuffer::Null; +} + +SharedBuffer +CbAttachment::AsBinary() const +{ + if (const CompositeBuffer* BinValue = std::get_if<CompositeBuffer>(&Value)) + { + return BinValue->Flatten(); + } + + return {}; +} + +CompressedBuffer +CbAttachment::AsCompressedBinary() const +{ + if (const CompressedBuffer* CompValue = std::get_if<CompressedBuffer>(&Value)) + { + return *CompValue; + } + + return CompressedBuffer::Null; +} + +/** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */ +CbObject +CbAttachment::AsObject() const +{ + if (const CbObject* ObjectValue = std::get_if<CbObject>(&Value)) + { + return *ObjectValue; + } + + return {}; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void +CbPackage::SetObject(CbObject InObject, const IoHash* InObjectHash, AttachmentResolver* InResolver) +{ + if (InObject) + { + Object = InObject.IsOwned() ? std::move(InObject) : CbObject::Clone(InObject); + if (InObjectHash) + { + ObjectHash = *InObjectHash; + ZEN_ASSERT_SLOW(ObjectHash == Object.GetHash()); + } + else + { + ObjectHash = Object.GetHash(); + } + if (InResolver) + { + GatherAttachments(Object, *InResolver); + } + } + else + { + Object.Reset(); + ObjectHash = IoHash::Zero; + } +} + +void +CbPackage::AddAttachment(const CbAttachment& Attachment, AttachmentResolver* Resolver) +{ + if (!Attachment.IsNull()) + { + auto It = std::lower_bound(begin(Attachments), end(Attachments), Attachment); + if (It != Attachments.end() && *It == Attachment) + { + CbAttachment& Existing = *It; + Existing = Attachment; + } + else + { + Attachments.insert(It, Attachment); + } + + if (Attachment.IsObject() && Resolver) + { + GatherAttachments(Attachment.AsObject(), *Resolver); + } + } +} + +void +CbPackage::AddAttachments(std::span<const CbAttachment> InAttachments) +{ + if (InAttachments.empty()) + { + return; + } + // Assume we have no duplicates! + Attachments.insert(Attachments.end(), InAttachments.begin(), InAttachments.end()); + std::sort(Attachments.begin(), Attachments.end()); + ZEN_ASSERT_SLOW(std::unique(Attachments.begin(), Attachments.end()) == Attachments.end()); +} + +int32_t +CbPackage::RemoveAttachment(const IoHash& Hash) +{ + return gsl::narrow_cast<int32_t>( + std::erase_if(Attachments, [&Hash](const CbAttachment& Attachment) -> bool { return Attachment.GetHash() == Hash; })); +} + +bool +CbPackage::Equals(const CbPackage& Package) const +{ + return ObjectHash == Package.ObjectHash && Attachments == Package.Attachments; +} + +const CbAttachment* +CbPackage::FindAttachment(const IoHash& Hash) const +{ + auto It = std::find_if(begin(Attachments), end(Attachments), [&Hash](const CbAttachment& Attachment) -> bool { + return Attachment.GetHash() == Hash; + }); + + if (It == end(Attachments)) + return nullptr; + + return &*It; +} + +void +CbPackage::GatherAttachments(const CbObject& Value, AttachmentResolver Resolver) +{ + Value.IterateAttachments([this, &Resolver](CbFieldView Field) { + const IoHash& Hash = Field.AsAttachment(); + + if (SharedBuffer Buffer = Resolver(Hash)) + { + if (Field.IsObjectAttachment()) + { + AddAttachment(CbAttachment(CbObject(std::move(Buffer)), Hash), &Resolver); + } + else + { + AddAttachment(CbAttachment(std::move(Buffer))); + } + } + }); +} + +bool +CbPackage::TryLoad(IoBuffer InBuffer, BufferAllocator Allocator, AttachmentResolver* Mapper) +{ + BinaryReader Reader(InBuffer.Data(), InBuffer.Size()); + + return TryLoad(Reader, Allocator, Mapper); +} + +bool +CbPackage::TryLoad(CbFieldIterator& Fields) +{ + *this = CbPackage(); + + while (Fields) + { + if (Fields.IsNull()) + { + ++Fields; + break; + } + else if (IoHash Hash = Fields.AsHash(); !Fields.HasError() && !Fields.IsAttachment()) + { + ++Fields; + CbObjectView ObjectView = Fields.AsObjectView(); + if (Fields.HasError() || Hash != ObjectView.GetHash()) + { + return false; + } + Object = CbObject(ObjectView, Fields.GetOuterBuffer()); + Object.MakeOwned(); + ObjectHash = Hash; + ++Fields; + } + else + { + CbAttachment Attachment; + if (!Attachment.TryLoad(Fields)) + { + return false; + } + AddAttachment(Attachment); + } + } + return true; +} + +bool +CbPackage::TryLoad(BinaryReader& Reader, BufferAllocator Allocator, AttachmentResolver* Mapper) +{ + // TODO: this needs to re-grow the ability to accept a reference to an attachment which is + // not embedded + + ZEN_UNUSED(Mapper); + +#if 1 + *this = CbPackage(); + for (;;) + { + CbField Field = LoadCompactBinary(Reader, Allocator); + if (!Field) + { + return false; + } + + if (Field.IsNull()) + { + return true; + } + else if (IoHash Hash = Field.AsHash(); !Field.HasError() && !Field.IsAttachment()) + { + Field = LoadCompactBinary(Reader, Allocator); + CbObjectView ObjectView = Field.AsObjectView(); + if (Field.HasError() || Hash != ObjectView.GetHash()) + { + return false; + } + Object = CbObject(ObjectView, Field.GetOuterBuffer()); + ObjectHash = Hash; + } + else + { + CbAttachment Attachment; + if (!TryLoad_ArchiveFieldIntoAttachment(Attachment, std::move(Field), Reader, Allocator)) + { + return false; + } + AddAttachment(Attachment); + } + } +#else + uint8_t StackBuffer[64]; + const auto StackAllocator = [&Allocator, &StackBuffer](uint64_t Size) -> UniqueBuffer { + if (Size <= sizeof(StackBuffer)) + { + return UniqueBuffer::MakeMutableView(StackBuffer, Size); + } + + return Allocator(Size); + }; + + *this = CbPackage(); + + for (;;) + { + CbField ValueField = LoadCompactBinary(Reader, StackAllocator); + if (!ValueField) + { + return false; + } + if (ValueField.IsNull()) + { + return true; + } + else if (ValueField.IsBinary()) + { + const MemoryView View = ValueField.AsBinaryView(); + if (View.GetSize() > 0) + { + SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned(); + CbField HashField = LoadCompactBinary(Reader, StackAllocator); + const IoHash& Hash = HashField.AsAttachment(); + ZEN_ASSERT(!HashField.HasError(), "Attachments must be a non-empty binary value with a content hash."); + if (HashField.IsObjectAttachment()) + { + AddAttachment(CbAttachment(CbObject(std::move(Buffer)), Hash)); + } + else + { + AddAttachment(CbAttachment(std::move(Buffer), Hash)); + } + } + } + else if (ValueField.IsHash()) + { + const IoHash Hash = ValueField.AsHash(); + + ZEN_ASSERT(Mapper); + + AddAttachment(CbAttachment((*Mapper)(Hash), Hash)); + } + else + { + Object = ValueField.AsObject(); + if (ValueField.HasError()) + { + return false; + } + Object.MakeOwned(); + if (Object) + { + CbField HashField = LoadCompactBinary(Reader, StackAllocator); + ObjectHash = HashField.AsObjectAttachment(); + if (HashField.HasError() || Object.GetHash() != ObjectHash) + { + return false; + } + } + else + { + Object.Reset(); + } + } + } +#endif +} + +void +CbPackage::Save(CbWriter& Writer) const +{ + if (Object) + { + Writer.AddHash(ObjectHash); + Writer.AddObject(Object); + } + for (const CbAttachment& Attachment : Attachments) + { + Attachment.Save(Writer); + } + Writer.AddNull(); +} + +void +CbPackage::Save(BinaryWriter& StreamWriter) const +{ + CbWriter Writer; + Save(Writer); + Writer.Save(StreamWriter); +} + +////////////////////////////////////////////////////////////////////////// +// +// Legacy package serialization support +// + +namespace legacy { + + void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer) + { + if (Attachment.IsObject()) + { + CbObject Object = Attachment.AsObject(); + Writer.AddBinary(Object.GetBuffer()); + if (Object) + { + Writer.AddObjectAttachment(Attachment.GetHash()); + } + } + else if (Attachment.IsBinary()) + { + Writer.AddBinary(Attachment.AsBinary()); + Writer.AddBinaryAttachment(Attachment.GetHash()); + } + else if (Attachment.IsCompressedBinary()) + { + Writer.AddBinary(Attachment.AsCompressedBinary().GetCompressed()); + Writer.AddBinaryAttachment(Attachment.GetHash()); + } + else if (Attachment.IsNull()) + { + Writer.AddBinary(MemoryView()); + } + else + { + ZEN_NOT_IMPLEMENTED("Compressed binary is not supported in this serialization format"); + } + } + + void SaveCbPackage(const CbPackage& Package, CbWriter& Writer) + { + if (const CbObject& RootObject = Package.GetObject()) + { + Writer.AddObject(RootObject); + Writer.AddObjectAttachment(Package.GetObjectHash()); + } + for (const CbAttachment& Attachment : Package.GetAttachments()) + { + SaveCbAttachment(Attachment, Writer); + } + Writer.AddNull(); + } + + void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar) + { + CbWriter Writer; + SaveCbPackage(Package, Writer); + Writer.Save(Ar); + } + + bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + { + BinaryReader Reader(InBuffer.Data(), InBuffer.Size()); + + return TryLoadCbPackage(Package, Reader, Allocator, Mapper); + } + + bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper) + { + Package = CbPackage(); + for (;;) + { + CbField ValueField = LoadCompactBinary(Reader, Allocator); + if (!ValueField) + { + return false; + } + if (ValueField.IsNull()) + { + return true; + } + if (ValueField.IsBinary()) + { + const MemoryView View = ValueField.AsBinaryView(); + if (View.GetSize() > 0) + { + SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned(); + CbField HashField = LoadCompactBinary(Reader, Allocator); + const IoHash& Hash = HashField.AsAttachment(); + if (HashField.HasError()) + { + return false; + } + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(Buffer, RawHash, RawSize)) + { + if (RawHash != Hash) + { + return false; + } + Package.AddAttachment(CbAttachment(Compressed, Hash)); + } + else + { + if (IoHash::HashBuffer(Buffer) != Hash) + { + return false; + } + if (HashField.IsObjectAttachment()) + { + Package.AddAttachment(CbAttachment(CbObject(std::move(Buffer)), Hash)); + } + else + { + Package.AddAttachment(CbAttachment(CompositeBuffer(std::move(Buffer)), Hash)); + } + } + } + } + else if (ValueField.IsHash()) + { + const IoHash Hash = ValueField.AsHash(); + + ZEN_ASSERT(Mapper); + if (SharedBuffer AttachmentData = (*Mapper)(Hash)) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(AttachmentData, RawHash, RawSize)) + { + if (RawHash != Hash) + { + return false; + } + Package.AddAttachment(CbAttachment(Compressed, Hash)); + } + else + { + const CbValidateError ValidationResult = ValidateCompactBinary(AttachmentData.GetView(), CbValidateMode::All); + if (ValidationResult != CbValidateError::None) + { + return false; + } + Package.AddAttachment(CbAttachment(CbObject(std::move(AttachmentData)), Hash)); + } + } + } + else + { + CbObject Object = ValueField.AsObject(); + if (ValueField.HasError()) + { + return false; + } + + if (Object) + { + CbField HashField = LoadCompactBinary(Reader, Allocator); + IoHash ObjectHash = HashField.AsObjectAttachment(); + if (HashField.HasError() || Object.GetHash() != ObjectHash) + { + return false; + } + Package.SetObject(Object, ObjectHash); + } + } + } + } + +} // namespace legacy + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +void +usonpackage_forcelink() +{ +} + +TEST_CASE("usonpackage") +{ + using namespace std::literals; + + const auto TestSaveLoadValidate = [&](const char* Test, const CbAttachment& Attachment) { + ZEN_UNUSED(Test); + + CbWriter Writer; + Attachment.Save(Writer); + CbFieldIterator Fields = Writer.Save(); + + BinaryWriter StreamWriter; + Attachment.Save(StreamWriter); + + CHECK(MakeMemoryView(StreamWriter).EqualBytes(Fields.GetRangeBuffer().GetView())); + CHECK(ValidateCompactBinaryRange(MakeMemoryView(StreamWriter), CbValidateMode::All) == CbValidateError::None); + CHECK(ValidateObjectAttachment(MakeMemoryView(StreamWriter), CbValidateMode::All) == CbValidateError::None); + + CbAttachment FromFields; + FromFields.TryLoad(Fields); + CHECK(!bool(Fields)); + CHECK(FromFields == Attachment); + + CbAttachment FromArchive; + BinaryReader Reader(MakeMemoryView(StreamWriter)); + FromArchive.TryLoad(Reader); + CHECK(Reader.CurrentOffset() == Reader.Size()); + CHECK(FromArchive == Attachment); + }; + + SUBCASE("Empty Attachment") + { + CbAttachment Attachment; + CHECK(Attachment.IsNull()); + CHECK_FALSE(bool(Attachment)); + CHECK_FALSE(bool(Attachment.AsBinary())); + CHECK_FALSE(bool(Attachment.AsObject())); + CHECK_FALSE(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsObject()); + CHECK(Attachment.GetHash() == IoHash::Zero); + } + + SUBCASE("Binary Attachment") + { + const SharedBuffer Buffer = SharedBuffer::Clone(MakeMemoryView<uint8_t>({0, 1, 2, 3})); + CbAttachment Attachment(Buffer); + CHECK_FALSE(Attachment.IsNull()); + CHECK(bool(Attachment)); + CHECK(Attachment.AsBinary() == Buffer); + CHECK_FALSE(bool(Attachment.AsObject())); + CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsObject()); + CHECK(Attachment.GetHash() == IoHash::HashBuffer(Buffer)); + TestSaveLoadValidate("Binary", Attachment); + } + + SUBCASE("Object Attachment") + { + CbWriter Writer; + Writer.BeginObject(); + Writer << "Name"sv << 42; + Writer.EndObject(); + CbObject Object = Writer.Save().AsObject(); + CbAttachment Attachment(Object); + + CHECK_FALSE(Attachment.IsNull()); + CHECK(bool(Attachment)); + CHECK(Attachment.AsBinary() == SharedBuffer()); + CHECK(Attachment.AsObject().Equals(Object)); + CHECK_FALSE(Attachment.IsBinary()); + CHECK(Attachment.IsObject()); + CHECK(Attachment.GetHash() == Object.GetHash()); + TestSaveLoadValidate("Object", Attachment); + } + + SUBCASE("Binary View") + { + const uint8_t Value[]{0, 1, 2, 3}; + SharedBuffer Buffer = SharedBuffer::MakeView(MakeMemoryView(Value)); + CbAttachment Attachment(Buffer); + CHECK_FALSE(Attachment.IsNull()); + CHECK(bool(Attachment)); + CHECK(Attachment.AsBinary().GetView().EqualBytes(Buffer.GetView())); + CHECK_FALSE(bool(Attachment.AsObject())); + CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsObject()); + CHECK(Attachment.GetHash() == IoHash::HashBuffer(Buffer)); + } + + SUBCASE("Object View") + { + CbWriter Writer; + Writer.BeginObject(); + Writer << "Name"sv << 42; + Writer.EndObject(); + CbObject Object = Writer.Save().AsObject(); + CbObject ObjectView = CbObject::MakeView(Object); + CbAttachment Attachment(ObjectView); + + CHECK_FALSE(Attachment.IsNull()); + CHECK(bool(Attachment)); + + CHECK(Attachment.AsBinary() != ObjectView.GetBuffer()); + CHECK(Attachment.AsObject().Equals(Object)); + CHECK_FALSE(Attachment.IsBinary()); + CHECK(Attachment.IsObject()); + CHECK(Attachment.GetHash() == IoHash(Object.GetHash())); + } + + SUBCASE("Binary Load from View") + { + const uint8_t Value[]{0, 1, 2, 3}; + const SharedBuffer Buffer = SharedBuffer::MakeView(MakeMemoryView(Value)); + CbAttachment Attachment(Buffer); + + CbWriter Writer; + Attachment.Save(Writer); + CbFieldIterator Fields = Writer.Save(); + CbFieldIterator FieldsView = CbFieldIterator::MakeRangeView(CbFieldViewIterator(Fields)); + Attachment.TryLoad(FieldsView); + + CHECK_FALSE(Attachment.IsNull()); + CHECK(bool(Attachment)); + CHECK_FALSE(FieldsView.GetRangeBuffer().GetView().Contains(Attachment.AsBinary().GetView())); + CHECK(Attachment.AsBinary().GetView().EqualBytes(Buffer.GetView())); + CHECK_FALSE(Attachment.AsObject()); + CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsObject()); + CHECK(Attachment.GetHash() == IoHash::HashBuffer(MakeMemoryView(Value))); + } + + SUBCASE("Object Load from View") + { + CbWriter ValueWriter; + ValueWriter.BeginObject(); + ValueWriter << "Name"sv << 42; + ValueWriter.EndObject(); + const CbObject Value = ValueWriter.Save().AsObject(); + + CHECK(ValidateCompactBinaryRange(Value.GetView(), CbValidateMode::All) == CbValidateError::None); + CbAttachment Attachment(Value); + + CbWriter Writer; + Attachment.Save(Writer); + CbFieldIterator Fields = Writer.Save(); + CbFieldIterator FieldsView = CbFieldIterator::MakeRangeView(CbFieldViewIterator(Fields)); + + Attachment.TryLoad(FieldsView); + MemoryView View; + + CHECK_FALSE(Attachment.IsNull()); + CHECK(bool(Attachment)); + CHECK(Attachment.AsBinary().GetView().EqualBytes(MemoryView())); + CHECK_FALSE((!Attachment.AsObject().TryGetSerializedView(View) || FieldsView.GetOuterBuffer().GetView().Contains(View))); + CHECK_FALSE(Attachment.IsBinary()); + CHECK(Attachment.IsObject()); + CHECK(Attachment.GetHash() == Value.GetHash()); + } + + SUBCASE("Binary Null") + { + const CbAttachment Attachment(SharedBuffer{}); + + CHECK(Attachment.IsNull()); + CHECK_FALSE(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsObject()); + CHECK(Attachment.GetHash() == IoHash::Zero); + } + + SUBCASE("Binary Empty") + { + const CbAttachment Attachment(UniqueBuffer::Alloc(0).MoveToShared()); + + CHECK_FALSE(Attachment.IsNull()); + CHECK(Attachment.IsBinary()); + CHECK_FALSE(Attachment.IsObject()); + CHECK(Attachment.GetHash() == IoHash::HashBuffer(SharedBuffer{})); + } + + SUBCASE("Object Empty") + { + const CbAttachment Attachment(CbObject{}); + + CHECK_FALSE(Attachment.IsNull()); + CHECK_FALSE(Attachment.IsBinary()); + CHECK(Attachment.IsObject()); + CHECK(Attachment.GetHash() == CbObject().GetHash()); + } +} + +TEST_CASE("usonpackage.serialization") +{ + using namespace std::literals; + + const auto TestSaveLoadValidate = [&](const char* Test, CbPackage& InOutPackage) { + ZEN_UNUSED(Test); + + CbWriter Writer; + InOutPackage.Save(Writer); + CbFieldIterator Fields = Writer.Save(); + + BinaryWriter MemStream; + InOutPackage.Save(MemStream); + + CHECK(MakeMemoryView(MemStream).EqualBytes(Fields.GetRangeBuffer().GetView())); + CHECK(ValidateCompactBinaryRange(MakeMemoryView(MemStream), CbValidateMode::All) == CbValidateError::None); + CHECK(ValidateCompactBinaryPackage(MakeMemoryView(MemStream), CbValidateMode::All) == CbValidateError::None); + + CbPackage FromFields; + FromFields.TryLoad(Fields); + CHECK_FALSE(bool(Fields)); + CHECK(FromFields == InOutPackage); + + CbPackage FromArchive; + BinaryReader ReadAr(MakeMemoryView(MemStream)); + FromArchive.TryLoad(ReadAr); + CHECK(ReadAr.CurrentOffset() == ReadAr.Size()); + CHECK(FromArchive == InOutPackage); + InOutPackage = FromArchive; + }; + + SUBCASE("Empty") + { + CbPackage Package; + CHECK(Package.IsNull()); + CHECK_FALSE(bool(Package)); + CHECK(Package.GetAttachments().size() == 0); + TestSaveLoadValidate("Empty", Package); + } + + SUBCASE("Object Only") + { + CbWriter Writer; + Writer.BeginObject(); + Writer << "Field" << 42; + Writer.EndObject(); + + const CbObject Object = Writer.Save().AsObject(); + CbPackage Package(Object); + CHECK_FALSE(Package.IsNull()); + CHECK(bool(Package)); + CHECK(Package.GetAttachments().size() == 0); + CHECK(Package.GetObject().GetOuterBuffer() == Object.GetOuterBuffer()); + CHECK(Package.GetObject()["Field"].AsInt32() == 42); + CHECK(Package.GetObjectHash() == Package.GetObject().GetHash()); + TestSaveLoadValidate("Object", Package); + } + + // Object View Only + { + CbWriter Writer; + Writer.BeginObject(); + Writer << "Field" << 42; + Writer.EndObject(); + + const CbObject Object = Writer.Save().AsObject(); + CbPackage Package(CbObject::MakeView(Object)); + CHECK_FALSE(Package.IsNull()); + CHECK(bool(Package)); + CHECK(Package.GetAttachments().size() == 0); + CHECK(Package.GetObject().GetOuterBuffer() != Object.GetOuterBuffer()); + CHECK(Package.GetObject()["Field"].AsInt32() == 42); + CHECK(Package.GetObjectHash() == Package.GetObject().GetHash()); + TestSaveLoadValidate("Object", Package); + } + + // Attachment Only + { + CbObject Object1; + { + CbWriter Writer; + Writer.BeginObject(); + Writer << "Field1" << 42; + Writer.EndObject(); + Object1 = Writer.Save().AsObject(); + } + CbObject Object2; + { + CbWriter Writer; + Writer.BeginObject(); + Writer << "Field2" << 42; + Writer.EndObject(); + Object2 = Writer.Save().AsObject(); + } + + CbPackage Package; + Package.AddAttachment(CbAttachment(Object1)); + Package.AddAttachment(CbAttachment(Object2.GetBuffer())); + + CHECK_FALSE(Package.IsNull()); + CHECK(bool(Package)); + CHECK(Package.GetAttachments().size() == 2); + CHECK(Package.GetObject().Equals(CbObject())); + CHECK(Package.GetObjectHash() == IoHash::Zero); + TestSaveLoadValidate("Attachments", Package); + + const CbAttachment* const Object1Attachment = Package.FindAttachment(Object1.GetHash()); + const CbAttachment* const Object2Attachment = Package.FindAttachment(Object2.GetHash()); + + CHECK((Object1Attachment && Object1Attachment->AsObject().Equals(Object1))); + CHECK((Object2Attachment && Object2Attachment->AsBinary().GetView().EqualBytes(Object2.GetBuffer().GetView()))); + + SharedBuffer Object1ClonedBuffer = SharedBuffer::Clone(Object1.GetOuterBuffer()); + Package.AddAttachment(CbAttachment(Object1ClonedBuffer)); + Package.AddAttachment(CbAttachment(CbObject::Clone(Object2))); + + CHECK(Package.GetAttachments().size() == 2); + CHECK(Package.FindAttachment(Object1.GetHash()) == Object1Attachment); + CHECK(Package.FindAttachment(Object2.GetHash()) == Object2Attachment); + + CHECK((Object1Attachment && Object1Attachment->AsBinary() == Object1ClonedBuffer)); + CHECK((Object2Attachment && Object2Attachment->AsObject().Equals(Object2))); + + CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments()))); + } + + // Shared Values + const uint8_t Level4Values[]{0, 1, 2, 3}; + SharedBuffer Level4 = SharedBuffer::MakeView(MakeMemoryView(Level4Values)); + const IoHash Level4Hash = IoHash::HashBuffer(Level4); + + CbObject Level3; + { + CbWriter Writer; + Writer.BeginObject(); + Writer.AddBinaryAttachment("Level4", Level4Hash); + Writer.EndObject(); + Level3 = Writer.Save().AsObject(); + } + const IoHash Level3Hash = Level3.GetHash(); + + CbObject Level2; + { + CbWriter Writer; + Writer.BeginObject(); + Writer.AddObjectAttachment("Level3", Level3Hash); + Writer.EndObject(); + Level2 = Writer.Save().AsObject(); + } + const IoHash Level2Hash = Level2.GetHash(); + + CbObject Level1; + { + CbWriter Writer; + Writer.BeginObject(); + Writer.AddObjectAttachment("Level2", Level2Hash); + Writer.EndObject(); + Level1 = Writer.Save().AsObject(); + } + const IoHash Level1Hash = Level1.GetHash(); + + const auto Resolver = [&Level2, &Level2Hash, &Level3, &Level3Hash, &Level4, &Level4Hash](const IoHash& Hash) -> SharedBuffer { + return Hash == Level2Hash ? Level2.GetOuterBuffer() + : Hash == Level3Hash ? Level3.GetOuterBuffer() + : Hash == Level4Hash ? Level4 + : SharedBuffer(); + }; + + // Object + Attachments + { + CbPackage Package; + Package.SetObject(Level1, Level1Hash, Resolver); + + CHECK_FALSE(Package.IsNull()); + CHECK(bool(Package)); + CHECK(Package.GetAttachments().size() == 3); + CHECK(Package.GetObject().GetBuffer() == Level1.GetBuffer()); + CHECK(Package.GetObjectHash() == Level1Hash); + TestSaveLoadValidate("Object+Attachments", Package); + + const CbAttachment* const Level2Attachment = Package.FindAttachment(Level2Hash); + const CbAttachment* const Level3Attachment = Package.FindAttachment(Level3Hash); + const CbAttachment* const Level4Attachment = Package.FindAttachment(Level4Hash); + CHECK((Level2Attachment && Level2Attachment->AsObject().Equals(Level2))); + CHECK((Level3Attachment && Level3Attachment->AsObject().Equals(Level3))); + REQUIRE(Level4Attachment); + CHECK(Level4Attachment->AsBinary() != Level4); + CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView())); + + CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments()))); + + const CbPackage PackageCopy = Package; + CHECK(PackageCopy == Package); + + CHECK(Package.RemoveAttachment(Level1Hash) == 0); + CHECK(Package.RemoveAttachment(Level2Hash) == 1); + CHECK(Package.RemoveAttachment(Level3Hash) == 1); + CHECK(Package.RemoveAttachment(Level4Hash) == 1); + CHECK(Package.RemoveAttachment(Level4Hash) == 0); + CHECK(Package.GetAttachments().size() == 0); + + CHECK(PackageCopy != Package); + Package = PackageCopy; + CHECK(PackageCopy == Package); + Package.SetObject(CbObject()); + CHECK(PackageCopy != Package); + CHECK(Package.GetObjectHash() == IoHash()); + } + + // Out of Order + { + CbWriter Writer; + CbAttachment Attachment2(Level2, Level2Hash); + Attachment2.Save(Writer); + CbAttachment Attachment4(Level4); + Attachment4.Save(Writer); + Writer.AddHash(Level1Hash); + Writer.AddObject(Level1); + CbAttachment Attachment3(Level3, Level3Hash); + Attachment3.Save(Writer); + Writer.AddNull(); + + CbFieldIterator Fields = Writer.Save(); + CbPackage FromFields; + FromFields.TryLoad(Fields); + + const CbAttachment* const Level2Attachment = FromFields.FindAttachment(Level2Hash); + REQUIRE(Level2Attachment); + const CbAttachment* const Level3Attachment = FromFields.FindAttachment(Level3Hash); + REQUIRE(Level3Attachment); + const CbAttachment* const Level4Attachment = FromFields.FindAttachment(Level4Hash); + REQUIRE(Level4Attachment); + + CHECK(FromFields.GetObject().Equals(Level1)); + CHECK(FromFields.GetObject().GetOuterBuffer() == Fields.GetOuterBuffer()); + CHECK(FromFields.GetObjectHash() == Level1Hash); + + const MemoryView FieldsOuterBufferView = Fields.GetOuterBuffer().GetView(); + + CHECK(Level2Attachment->AsObject().Equals(Level2)); + CHECK(Level2Attachment->GetHash() == Level2Hash); + + CHECK(Level3Attachment->AsObject().Equals(Level3)); + CHECK(Level3Attachment->GetHash() == Level3Hash); + + CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView())); + CHECK(FieldsOuterBufferView.Contains(Level4Attachment->AsBinary().GetView())); + CHECK(Level4Attachment->GetHash() == Level4Hash); + + BinaryWriter WriteStream; + Writer.Save(WriteStream); + CbPackage FromArchive; + BinaryReader ReadAr(MakeMemoryView(WriteStream)); + FromArchive.TryLoad(ReadAr); + + Writer.Reset(); + FromArchive.Save(Writer); + CbFieldIterator Saved = Writer.Save(); + + CHECK(Saved.AsHash() == Level1Hash); + ++Saved; + CHECK(Saved.AsObject().Equals(Level1)); + ++Saved; + CHECK_EQ(Saved.AsObjectAttachment(), Level2Hash); + ++Saved; + CHECK(Saved.AsObject().Equals(Level2)); + ++Saved; + CHECK_EQ(Saved.AsObjectAttachment(), Level3Hash); + ++Saved; + CHECK(Saved.AsObject().Equals(Level3)); + ++Saved; + CHECK_EQ(Saved.AsBinaryAttachment(), Level4Hash); + ++Saved; + SharedBuffer SavedLevel4Buffer = SharedBuffer::MakeView(Saved.AsBinaryView()); + CHECK(SavedLevel4Buffer.GetView().EqualBytes(Level4.GetView())); + ++Saved; + CHECK(Saved.IsNull()); + ++Saved; + CHECK(!Saved); + } + + // Null Attachment + { + CbAttachment NullAttachment; + CbPackage Package; + Package.AddAttachment(NullAttachment); + CHECK(Package.IsNull()); + CHECK_FALSE(bool(Package)); + CHECK(Package.GetAttachments().size() == 0); + CHECK_FALSE(Package.FindAttachment(NullAttachment)); + } + + // Resolve After Merge + { + bool bResolved = false; + CbPackage Package; + Package.AddAttachment(CbAttachment(Level3.GetBuffer())); + Package.AddAttachment(CbAttachment(Level3), [&bResolved](const IoHash& Hash) -> SharedBuffer { + ZEN_UNUSED(Hash); + bResolved = true; + return SharedBuffer(); + }); + CHECK(bResolved); + } +} + +TEST_CASE("usonpackage.invalidpackage") +{ + const auto TestLoad = [](std::initializer_list<uint8_t> RawData, BufferAllocator Allocator = UniqueBuffer::Alloc) { + const MemoryView RawView = MakeMemoryView(RawData); + CbPackage FromArchive; + BinaryReader ReadAr(RawView); + CHECK_FALSE(FromArchive.TryLoad(ReadAr, Allocator)); + }; + const auto AllocFail = [](uint64_t) -> UniqueBuffer { + FAIL_CHECK("Allocation is not expected"); + return UniqueBuffer(); + }; + SUBCASE("Empty") { TestLoad({}, AllocFail); } + SUBCASE("Invalid Initial Field") + { + TestLoad({uint8_t(CbFieldType::None)}); + TestLoad({uint8_t(CbFieldType::Array)}); + TestLoad({uint8_t(CbFieldType::UniformArray)}); + TestLoad({uint8_t(CbFieldType::Binary)}); + TestLoad({uint8_t(CbFieldType::String)}); + TestLoad({uint8_t(CbFieldType::IntegerPositive)}); + TestLoad({uint8_t(CbFieldType::IntegerNegative)}); + TestLoad({uint8_t(CbFieldType::Float32)}); + TestLoad({uint8_t(CbFieldType::Float64)}); + TestLoad({uint8_t(CbFieldType::BoolFalse)}); + TestLoad({uint8_t(CbFieldType::BoolTrue)}); + TestLoad({uint8_t(CbFieldType::ObjectAttachment)}); + TestLoad({uint8_t(CbFieldType::BinaryAttachment)}); + TestLoad({uint8_t(CbFieldType::Uuid)}); + TestLoad({uint8_t(CbFieldType::DateTime)}); + TestLoad({uint8_t(CbFieldType::TimeSpan)}); + TestLoad({uint8_t(CbFieldType::ObjectId)}); + TestLoad({uint8_t(CbFieldType::CustomById)}); + TestLoad({uint8_t(CbFieldType::CustomByName)}); + } + SUBCASE("Size Out Of Bounds") + { + TestLoad({uint8_t(CbFieldType::Object), 1}, AllocFail); + TestLoad({uint8_t(CbFieldType::Object), 0xff, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0}, AllocFail); + } +} + +#endif + +} // namespace zen diff --git a/src/zencore/compactbinaryvalidation.cpp b/src/zencore/compactbinaryvalidation.cpp new file mode 100644 index 000000000..02148d96a --- /dev/null +++ b/src/zencore/compactbinaryvalidation.cpp @@ -0,0 +1,664 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/compactbinaryvalidation.h" + +#include <zencore/compactbinarypackage.h> +#include <zencore/endian.h> +#include <zencore/memory.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <algorithm> + +namespace zen { + +namespace CbValidationPrivate { + + template<typename T> + static constexpr inline T ReadUnaligned(const void* const Memory) + { +#if ZEN_PLATFORM_SUPPORTS_UNALIGNED_LOADS + return *static_cast<const T*>(Memory); +#else + T Value; + memcpy(&Value, Memory, sizeof(Value)); + return Value; +#endif + } + +} // namespace CbValidationPrivate + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * Adds the given error(s) to the error mask. + * + * This function exists to make validation errors easier to debug by providing one location to set a breakpoint. + */ +ZEN_NOINLINE static void +AddError(CbValidateError& OutError, const CbValidateError InError) +{ + OutError |= InError; +} + +/** + * Validate and read a field type from the view. + * + * A type argument with the HasFieldType flag indicates that the type will not be read from the view. + */ +static CbFieldType +ValidateCbFieldType(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType Type = CbFieldType::HasFieldType) +{ + ZEN_UNUSED(Mode); + if (CbFieldTypeOps::HasFieldType(Type)) + { + if (View.GetSize() >= 1) + { + Type = *static_cast<const CbFieldType*>(View.GetData()); + View += 1; + if (CbFieldTypeOps::HasFieldType(Type)) + { + AddError(Error, CbValidateError::InvalidType); + } + } + else + { + AddError(Error, CbValidateError::OutOfBounds); + View.Reset(); + return CbFieldType::None; + } + } + + if (CbFieldTypeOps::GetSerializedType(Type) != Type) + { + AddError(Error, CbValidateError::InvalidType); + View.Reset(); + } + + return Type; +} + +/** + * Validate and read an unsigned integer from the view. + * + * Modifies the view to start at the end of the value, and adds error flags if applicable. + */ +static uint64_t +ValidateCbUInt(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + if (View.GetSize() > 0 && View.GetSize() >= MeasureVarUInt(View.GetData())) + { + uint32_t ValueByteCount; + const uint64_t Value = ReadVarUInt(View.GetData(), ValueByteCount); + if (EnumHasAnyFlags(Mode, CbValidateMode::Format) && ValueByteCount > MeasureVarUInt(Value)) + { + AddError(Error, CbValidateError::InvalidInteger); + } + View += ValueByteCount; + return Value; + } + else + { + AddError(Error, CbValidateError::OutOfBounds); + View.Reset(); + return 0; + } +} + +/** + * Validate a 64-bit floating point value from the view. + * + * Modifies the view to start at the end of the value, and adds error flags if applicable. + */ +static void +ValidateCbFloat64(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + if (View.GetSize() >= sizeof(double)) + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Format)) + { + const uint64_t RawValue = FromNetworkOrder(CbValidationPrivate::ReadUnaligned<uint64_t>(View.GetData())); + const double Value = reinterpret_cast<const double&>(RawValue); + if (Value == double(float(Value))) + { + AddError(Error, CbValidateError::InvalidFloat); + } + } + View += sizeof(double); + } + else + { + AddError(Error, CbValidateError::OutOfBounds); + View.Reset(); + } +} + +/** + * Validate and read a string from the view. + * + * Modifies the view to start at the end of the string, and adds error flags if applicable. + */ +static std::string_view +ValidateCbString(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + const uint64_t NameSize = ValidateCbUInt(View, Mode, Error); + if (View.GetSize() >= NameSize) + { + const std::string_view Name(static_cast<const char*>(View.GetData()), static_cast<int32_t>(NameSize)); + View += NameSize; + return Name; + } + else + { + AddError(Error, CbValidateError::OutOfBounds); + View.Reset(); + return std::string_view(); + } +} + +static CbFieldView ValidateCbField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType ExternalType); + +/** A type that checks whether all validated fields are of the same type. */ +class CbUniformFieldsValidator +{ +public: + inline explicit CbUniformFieldsValidator(CbFieldType InExternalType) : ExternalType(InExternalType) {} + + inline CbFieldView ValidateField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) + { + const void* const FieldData = View.GetData(); + if (CbFieldView Field = ValidateCbField(View, Mode, Error, ExternalType)) + { + ++FieldCount; + if (CbFieldTypeOps::HasFieldType(ExternalType)) + { + const CbFieldType FieldType = *static_cast<const CbFieldType*>(FieldData); + if (FieldCount == 1) + { + FirstType = FieldType; + } + else if (FieldType != FirstType) + { + bUniform = false; + } + } + return Field; + } + + // It may not safe to check for uniformity if the field was invalid. + bUniform = false; + return CbFieldView(); + } + + inline bool IsUniform() const { return FieldCount > 0 && bUniform; } + +private: + uint32_t FieldCount = 0; + bool bUniform = true; + CbFieldType FirstType = CbFieldType::None; + CbFieldType ExternalType; +}; + +static void +ValidateCbObject(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType ObjectType) +{ + const uint64_t Size = ValidateCbUInt(View, Mode, Error); + MemoryView ObjectView = View.Left(Size); + View += Size; + + if (Size > 0) + { + std::vector<std::string_view> Names; + + const bool bUniformObject = CbFieldTypeOps::GetType(ObjectType) == CbFieldType::UniformObject; + const CbFieldType ExternalType = bUniformObject ? ValidateCbFieldType(ObjectView, Mode, Error) : CbFieldType::HasFieldType; + CbUniformFieldsValidator UniformValidator(ExternalType); + do + { + if (CbFieldView Field = UniformValidator.ValidateField(ObjectView, Mode, Error)) + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Names)) + { + if (Field.HasName()) + { + Names.push_back(Field.GetName()); + } + else + { + AddError(Error, CbValidateError::MissingName); + } + } + } + } while (!ObjectView.IsEmpty()); + + if (EnumHasAnyFlags(Mode, CbValidateMode::Names) && Names.size() > 1) + { + std::sort(begin(Names), end(Names), [](std::string_view L, std::string_view R) { return L.compare(R) < 0; }); + + for (const std::string_view *NamesIt = Names.data(), *NamesEnd = NamesIt + Names.size() - 1; NamesIt != NamesEnd; ++NamesIt) + { + if (NamesIt[0] == NamesIt[1]) + { + AddError(Error, CbValidateError::DuplicateName); + break; + } + } + } + + if (!bUniformObject && EnumHasAnyFlags(Mode, CbValidateMode::Format) && UniformValidator.IsUniform()) + { + AddError(Error, CbValidateError::NonUniformObject); + } + } +} + +static void +ValidateCbArray(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType ArrayType) +{ + const uint64_t Size = ValidateCbUInt(View, Mode, Error); + MemoryView ArrayView = View.Left(Size); + View += Size; + + const uint64_t Count = ValidateCbUInt(ArrayView, Mode, Error); + const uint64_t FieldsSize = ArrayView.GetSize(); + const bool bUniformArray = CbFieldTypeOps::GetType(ArrayType) == CbFieldType::UniformArray; + const CbFieldType ExternalType = bUniformArray ? ValidateCbFieldType(ArrayView, Mode, Error) : CbFieldType::HasFieldType; + CbUniformFieldsValidator UniformValidator(ExternalType); + + for (uint64_t Index = 0; Index < Count; ++Index) + { + if (CbFieldView Field = UniformValidator.ValidateField(ArrayView, Mode, Error)) + { + if (Field.HasName() && EnumHasAnyFlags(Mode, CbValidateMode::Names)) + { + AddError(Error, CbValidateError::ArrayName); + } + } + } + + if (!bUniformArray && EnumHasAnyFlags(Mode, CbValidateMode::Format) && UniformValidator.IsUniform() && FieldsSize > Count) + { + AddError(Error, CbValidateError::NonUniformArray); + } +} + +static CbFieldView +ValidateCbField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, const CbFieldType ExternalType = CbFieldType::HasFieldType) +{ + const MemoryView FieldView = View; + const CbFieldType Type = ValidateCbFieldType(View, Mode, Error, ExternalType); + [[maybe_unused]] const std::string_view Name = + CbFieldTypeOps::HasFieldName(Type) ? ValidateCbString(View, Mode, Error) : std::string_view(); + + auto ValidateFixedPayload = [&View, &Error](uint32_t PayloadSize) { + if (View.GetSize() >= PayloadSize) + { + View += PayloadSize; + } + else + { + AddError(Error, CbValidateError::OutOfBounds); + View.Reset(); + } + }; + + if (EnumHasAnyFlags(Error, CbValidateError::OutOfBounds | CbValidateError::InvalidType)) + { + return CbFieldView(); + } + + switch (CbFieldType FieldType = CbFieldTypeOps::GetType(Type)) + { + default: + case CbFieldType::None: + AddError(Error, CbValidateError::InvalidType); + View.Reset(); + break; + case CbFieldType::Null: + case CbFieldType::BoolFalse: + case CbFieldType::BoolTrue: + if (FieldView == View) + { + // Reset the view because a zero-sized field can cause infinite field iteration. + AddError(Error, CbValidateError::InvalidType); + View.Reset(); + } + break; + case CbFieldType::Object: + case CbFieldType::UniformObject: + ValidateCbObject(View, Mode, Error, FieldType); + break; + case CbFieldType::Array: + case CbFieldType::UniformArray: + ValidateCbArray(View, Mode, Error, FieldType); + break; + case CbFieldType::Binary: + { + const uint64_t ValueSize = ValidateCbUInt(View, Mode, Error); + if (View.GetSize() < ValueSize) + { + AddError(Error, CbValidateError::OutOfBounds); + View.Reset(); + } + else + { + View += ValueSize; + } + break; + } + case CbFieldType::String: + ValidateCbString(View, Mode, Error); + break; + case CbFieldType::IntegerPositive: + ValidateCbUInt(View, Mode, Error); + break; + case CbFieldType::IntegerNegative: + ValidateCbUInt(View, Mode, Error); + break; + case CbFieldType::Float32: + ValidateFixedPayload(4); + break; + case CbFieldType::Float64: + ValidateCbFloat64(View, Mode, Error); + break; + case CbFieldType::ObjectAttachment: + case CbFieldType::BinaryAttachment: + case CbFieldType::Hash: + ValidateFixedPayload(20); + break; + case CbFieldType::Uuid: + ValidateFixedPayload(16); + break; + case CbFieldType::DateTime: + case CbFieldType::TimeSpan: + ValidateFixedPayload(8); + break; + case CbFieldType::ObjectId: + ValidateFixedPayload(12); + break; + case CbFieldType::CustomById: + case CbFieldType::CustomByName: + ZEN_NOT_IMPLEMENTED(); // TODO: FIX! + break; + } + + if (EnumHasAnyFlags(Error, CbValidateError::OutOfBounds | CbValidateError::InvalidType)) + { + return CbFieldView(); + } + + return CbFieldView(FieldView.GetData(), ExternalType); +} + +static CbFieldView +ValidateCbPackageField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + if (View.IsEmpty()) + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + return CbFieldView(); + } + if (CbFieldView Field = ValidateCbField(View, Mode, Error)) + { + if (Field.HasName() && EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + return Field; + } + return CbFieldView(); +} + +static IoHash +ValidateCbPackageAttachment(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + if (const CbObjectView ObjectView = Value.AsObjectView(); !Value.HasError()) + { + return ObjectView.GetHash(); + } + + if (const IoHash ObjectAttachmentHash = Value.AsObjectAttachment(); !Value.HasError()) + { + if (CbFieldView ObjectField = ValidateCbPackageField(View, Mode, Error)) + { + const CbObjectView InnerObjectView = ObjectField.AsObjectView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && ObjectField.HasError()) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + else if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (ObjectAttachmentHash != InnerObjectView.GetHash())) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + return ObjectAttachmentHash; + } + } + else if (const IoHash BinaryAttachmentHash = Value.AsBinaryAttachment(); !Value.HasError()) + { + if (CbFieldView BinaryField = ValidateCbPackageField(View, Mode, Error)) + { + const MemoryView BinaryView = BinaryField.AsBinaryView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryField.HasError()) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryView.IsEmpty()) + { + AddError(Error, CbValidateError::NullPackageAttachment); + } + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (BinaryAttachmentHash != IoHash::HashBuffer(BinaryView))) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + } + return BinaryAttachmentHash; + } + } + else if (const MemoryView BinaryView = Value.AsBinaryView(); !Value.HasError()) + { + if (BinaryView.GetSize() > 0) + { + IoHash DecodedHash; + uint64_t DecodedRawSize; + CompressedBuffer Buffer = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView), DecodedHash, DecodedRawSize); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && Buffer.IsNull()) + { + AddError(Error, CbValidateError::NullPackageAttachment); + } + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (DecodedHash != IoHash::HashBuffer(Buffer.DecompressToComposite()))) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + return DecodedHash; + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::NullPackageAttachment); + } + return IoHash::HashBuffer(MemoryView()); + } + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + } + + return IoHash(); +} + +static IoHash +ValidateCbPackageObject(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error) +{ + if (IoHash RootObjectHash = Value.AsHash(); !Value.HasError() && !Value.IsAttachment()) + { + CbFieldView RootObjectField = ValidateCbPackageField(View, Mode, Error); + + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + if (RootObjectField.HasError()) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + } + + const CbObjectView RootObjectView = RootObjectField.AsObjectView(); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + if (!RootObjectView) + { + AddError(Error, CbValidateError::NullPackageObject); + } + } + + if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (RootObjectHash != RootObjectView.GetHash())) + { + AddError(Error, CbValidateError::InvalidPackageHash); + } + + return RootObjectHash; + } + else + { + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + } + + return IoHash(); +} + +CbValidateError +ValidateCompactBinary(MemoryView View, CbValidateMode Mode, CbFieldType Type) +{ + CbValidateError Error = CbValidateError::None; + if (EnumHasAnyFlags(Mode, CbValidateMode::All)) + { + ValidateCbField(View, Mode, Error, Type); + if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding)) + { + AddError(Error, CbValidateError::Padding); + } + } + return Error; +} + +CbValidateError +ValidateCompactBinaryRange(MemoryView View, CbValidateMode Mode) +{ + CbValidateError Error = CbValidateError::None; + if (EnumHasAnyFlags(Mode, CbValidateMode::All)) + { + while (!View.IsEmpty()) + { + ValidateCbField(View, Mode, Error); + } + } + return Error; +} + +CbValidateError +ValidateObjectAttachment(MemoryView View, CbValidateMode Mode) +{ + CbValidateError Error = CbValidateError::None; + if (EnumHasAnyFlags(Mode, CbValidateMode::All)) + { + if (CbFieldView Value = ValidateCbPackageField(View, Mode, Error)) + { + ValidateCbPackageAttachment(Value, View, Mode, Error); + } + if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding)) + { + AddError(Error, CbValidateError::Padding); + } + } + return Error; +} + +CbValidateError +ValidateCompactBinaryPackage(MemoryView View, CbValidateMode Mode) +{ + std::vector<IoHash> Attachments; + CbValidateError Error = CbValidateError::None; + if (EnumHasAnyFlags(Mode, CbValidateMode::All)) + { + uint32_t ObjectCount = 0; + while (CbFieldView Value = ValidateCbPackageField(View, Mode, Error)) + { + if (Value.IsHash() && !Value.IsAttachment()) + { + ValidateCbPackageObject(Value, View, Mode, Error); + if (++ObjectCount > 1 && EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::MultiplePackageObjects); + } + } + else if (Value.IsBinary() || Value.IsAttachment() || Value.IsObject()) + { + const IoHash Hash = ValidateCbPackageAttachment(Value, View, Mode, Error); + if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + Attachments.push_back(Hash); + } + } + else if (Value.IsNull()) + { + break; + } + else if (EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + AddError(Error, CbValidateError::InvalidPackageFormat); + } + + if (EnumHasAnyFlags(Error, CbValidateError::OutOfBounds)) + { + break; + } + } + + if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding)) + { + AddError(Error, CbValidateError::Padding); + } + + if (Attachments.size() && EnumHasAnyFlags(Mode, CbValidateMode::Package)) + { + std::sort(begin(Attachments), end(Attachments)); + for (const IoHash *It = Attachments.data(), *End = It + Attachments.size() - 1; It != End; ++It) + { + if (It[0] == It[1]) + { + AddError(Error, CbValidateError::DuplicateAttachments); + break; + } + } + } + } + return Error; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS +void +usonvalidation_forcelink() +{ +} + +TEST_CASE("usonvalidation") +{ + SUBCASE("Basic") {} +} +#endif + +} // namespace zen diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp new file mode 100644 index 000000000..735020451 --- /dev/null +++ b/src/zencore/compositebuffer.cpp @@ -0,0 +1,446 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/compositebuffer.h> + +#include <zencore/sharedbuffer.h> +#include <zencore/testing.h> + +namespace zen { + +const CompositeBuffer CompositeBuffer::Null; + +void +CompositeBuffer::Reset() +{ + m_Segments.clear(); +} + +uint64_t +CompositeBuffer::GetSize() const +{ + uint64_t Accum = 0; + + for (const SharedBuffer& It : m_Segments) + { + Accum += It.GetSize(); + } + + return Accum; +} + +bool +CompositeBuffer::IsOwned() const +{ + for (const SharedBuffer& It : m_Segments) + { + if (It.IsOwned() == false) + { + return false; + } + } + return true; +} + +CompositeBuffer +CompositeBuffer::MakeOwned() const& +{ + return CompositeBuffer(*this).MakeOwned(); +} + +CompositeBuffer +CompositeBuffer::MakeOwned() && +{ + for (SharedBuffer& Segment : m_Segments) + { + Segment = std::move(Segment).MakeOwned(); + } + return std::move(*this); +} + +SharedBuffer +CompositeBuffer::Flatten() const& +{ + switch (m_Segments.size()) + { + case 0: + return SharedBuffer(); + case 1: + return m_Segments[0]; + default: + UniqueBuffer Buffer = UniqueBuffer::Alloc(GetSize()); + MutableMemoryView OutView = Buffer.GetMutableView(); + + for (const SharedBuffer& Segment : m_Segments) + { + OutView.CopyFrom(Segment.GetView()); + OutView += Segment.GetSize(); + } + + return Buffer.MoveToShared(); + } +} + +SharedBuffer +CompositeBuffer::Flatten() && +{ + return m_Segments.size() == 1 ? std::move(m_Segments[0]) : std::as_const(*this).Flatten(); +} + +CompositeBuffer +CompositeBuffer::Mid(uint64_t Offset, uint64_t Size) const +{ + const uint64_t BufferSize = GetSize(); + Offset = Min(Offset, BufferSize); + Size = Min(Size, BufferSize - Offset); + CompositeBuffer Buffer; + IterateRange(Offset, Size, [&Buffer](MemoryView View, const SharedBuffer& ViewOuter) { + Buffer.m_Segments.push_back(SharedBuffer::MakeView(View, ViewOuter)); + }); + return Buffer; +} + +MemoryView +CompositeBuffer::ViewOrCopyRange(uint64_t Offset, uint64_t Size, UniqueBuffer& CopyBuffer) const +{ + MemoryView View; + IterateRange(Offset, Size, [Size, &View, &CopyBuffer, WriteView = MutableMemoryView()](MemoryView Segment) mutable { + if (Size == Segment.GetSize()) + { + View = Segment; + } + else + { + if (WriteView.IsEmpty()) + { + if (CopyBuffer.GetSize() < Size) + { + CopyBuffer = UniqueBuffer::Alloc(Size); + } + View = WriteView = CopyBuffer.GetMutableView().Left(Size); + } + WriteView = WriteView.CopyFrom(Segment); + } + }); + return View; +} + +CompositeBuffer::Iterator +CompositeBuffer::GetIterator(uint64_t Offset) const +{ + size_t SegmentCount = m_Segments.size(); + size_t SegmentIndex = 0; + while (SegmentIndex < SegmentCount) + { + size_t SegmentSize = m_Segments[SegmentIndex].GetSize(); + if (Offset < SegmentSize) + { + return {.SegmentIndex = SegmentIndex, .OffsetInSegment = Offset}; + } + Offset -= SegmentSize; + SegmentIndex++; + } + return {.SegmentIndex = ~0ull, .OffsetInSegment = ~0ull}; +} + +MemoryView +CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& CopyBuffer) const +{ + // We use a sub range IoBuffer when we want to copy data from a segment. + // This means we will only materialize that range of the segment when doing + // GetView() rather than the full segment. + // A hot path for this code is when we call CompressedBuffer::FromCompressed which + // is only interested in reading the header (first 64 bytes or so) and then throws + // away the materialized data. + MutableMemoryView WriteView; + size_t SegmentCount = m_Segments.size(); + ZEN_ASSERT(It.SegmentIndex < SegmentCount); + uint64_t SizeLeft = Size; + while (SizeLeft > 0 && It.SegmentIndex < SegmentCount) + { + const SharedBuffer& Segment = m_Segments[It.SegmentIndex]; + size_t SegmentSize = Segment.GetSize(); + if (Size == SizeLeft && Size <= (SegmentSize - It.OffsetInSegment)) + { + IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, SizeLeft); + MemoryView View = SubSegment.GetView(); + It.OffsetInSegment += SizeLeft; + ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize); + if (It.OffsetInSegment == SegmentSize) + { + It.SegmentIndex++; + It.OffsetInSegment = 0; + } + return View; + } + if (WriteView.GetSize() == 0) + { + if (CopyBuffer.GetSize() < Size) + { + CopyBuffer = UniqueBuffer::Alloc(Size); + } + WriteView = CopyBuffer.GetMutableView(); + } + size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft); + IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); + MemoryView ReadView = SubSegment.GetView(); + WriteView = WriteView.CopyFrom(ReadView); + It.OffsetInSegment += CopySize; + ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize); + if (It.OffsetInSegment == SegmentSize) + { + It.SegmentIndex++; + It.OffsetInSegment = 0; + } + SizeLeft -= CopySize; + } + return CopyBuffer.GetView().Left(Size - SizeLeft); +} + +void +CompositeBuffer::CopyTo(MutableMemoryView WriteView, Iterator& It) const +{ + // We use a sub range IoBuffer when we want to copy data from a segment. + // This means we will only materialize that range of the segment when doing + // GetView() rather than the full segment. + // A hot path for this code is when we call CompressedBuffer::FromCompressed which + // is only interested in reading the header (first 64 bytes or so) and then throws + // away the materialized data. + + size_t SizeLeft = WriteView.GetSize(); + size_t SegmentCount = m_Segments.size(); + ZEN_ASSERT(It.SegmentIndex < SegmentCount); + while (WriteView.GetSize() > 0 && It.SegmentIndex < SegmentCount) + { + const SharedBuffer& Segment = m_Segments[It.SegmentIndex]; + size_t SegmentSize = Segment.GetSize(); + size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft); + IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize); + MemoryView ReadView = SubSegment.GetView(); + WriteView = WriteView.CopyFrom(ReadView); + It.OffsetInSegment += CopySize; + ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize); + if (It.OffsetInSegment == SegmentSize) + { + It.SegmentIndex++; + It.OffsetInSegment = 0; + } + SizeLeft -= CopySize; + } +} + +void +CompositeBuffer::CopyTo(MutableMemoryView Target, uint64_t Offset) const +{ + IterateRange(Offset, Target.GetSize(), [Target](MemoryView View, [[maybe_unused]] const SharedBuffer& ViewOuter) mutable { + Target = Target.CopyFrom(View); + }); +} + +void +CompositeBuffer::IterateRange(uint64_t Offset, uint64_t Size, std::function<void(MemoryView View)> Visitor) const +{ + IterateRange(Offset, Size, [Visitor](MemoryView View, [[maybe_unused]] const SharedBuffer& ViewOuter) { Visitor(View); }); +} + +void +CompositeBuffer::IterateRange(uint64_t Offset, + uint64_t Size, + std::function<void(MemoryView View, const SharedBuffer& ViewOuter)> Visitor) const +{ + ZEN_ASSERT(Offset + Size <= GetSize()); + for (const SharedBuffer& Segment : m_Segments) + { + if (const uint64_t SegmentSize = Segment.GetSize(); Offset <= SegmentSize) + { + const MemoryView View = Segment.GetView().Mid(Offset, Size); + Offset = 0; + if (Size == 0 || !View.IsEmpty()) + { + Visitor(View, Segment); + } + Size -= View.GetSize(); + if (Size == 0) + { + break; + } + } + else + { + Offset -= SegmentSize; + } + } +} + +#if ZEN_WITH_TESTS +TEST_CASE("CompositeBuffer Null") +{ + CompositeBuffer Buffer; + CHECK(Buffer.IsNull()); + CHECK(Buffer.IsOwned()); + CHECK(Buffer.MakeOwned().IsNull()); + CHECK(Buffer.Flatten().IsNull()); + CHECK(Buffer.Mid(0, 0).IsNull()); + CHECK(Buffer.GetSize() == 0); + CHECK(Buffer.GetSegments().size() == 0); + + UniqueBuffer CopyBuffer; + CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer).IsEmpty()); + CHECK(CopyBuffer.IsNull()); + + MutableMemoryView CopyView; + Buffer.CopyTo(CopyView); + + uint32_t VisitCount = 0; + Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; }); + CHECK(VisitCount == 0); +} + +TEST_CASE("CompositeBuffer Empty") +{ + const uint8_t EmptyArray[]{0}; + const SharedBuffer EmptyView = SharedBuffer::MakeView(EmptyArray, 0); + CompositeBuffer Buffer(EmptyView); + CHECK(Buffer.IsNull() == false); + CHECK(Buffer.IsOwned() == false); + CHECK(Buffer.MakeOwned().IsNull() == false); + CHECK(Buffer.MakeOwned().IsOwned() == true); + CHECK(Buffer.Flatten() == EmptyView); + CHECK(Buffer.Mid(0, 0).Flatten() == EmptyView); + CHECK(Buffer.GetSize() == 0); + CHECK(Buffer.GetSegments().size() == 1); + CHECK(Buffer.GetSegments()[0] == EmptyView); + + UniqueBuffer CopyBuffer; + CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer) == EmptyView.GetView()); + CHECK(CopyBuffer.IsNull()); + + MutableMemoryView CopyView; + Buffer.CopyTo(CopyView); + + uint32_t VisitCount = 0; + Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; }); + CHECK(VisitCount == 1); +} + +TEST_CASE("CompositeBuffer Empty[1]") +{ + const uint8_t EmptyArray[1]{}; + const SharedBuffer EmptyView1 = SharedBuffer::MakeView(EmptyArray, 0); + const SharedBuffer EmptyView2 = SharedBuffer::MakeView(EmptyArray + 1, 0); + CompositeBuffer Buffer(EmptyView1, EmptyView2); + CHECK(Buffer.Mid(0, 0).Flatten() == EmptyView1); + CHECK(Buffer.GetSize() == 0); + CHECK(Buffer.GetSegments().size() == 2); + CHECK(Buffer.GetSegments()[0] == EmptyView1); + CHECK(Buffer.GetSegments()[1] == EmptyView2); + + UniqueBuffer CopyBuffer; + CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer) == EmptyView1.GetView()); + CHECK(CopyBuffer.IsNull()); + + MutableMemoryView CopyView; + Buffer.CopyTo(CopyView); + + uint32_t VisitCount = 0; + Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; }); + CHECK(VisitCount == 1); +} + +TEST_CASE("CompositeBuffer Flat") +{ + const uint8_t FlatArray[]{1, 2, 3, 4, 5, 6, 7, 8}; + const SharedBuffer FlatView = SharedBuffer::Clone(MakeMemoryView(FlatArray)); + CompositeBuffer Buffer(FlatView); + + CHECK(Buffer.IsNull() == false); + CHECK(Buffer.IsOwned() == true); + CHECK(Buffer.Flatten() == FlatView); + CHECK(Buffer.MakeOwned().Flatten() == FlatView); + CHECK(Buffer.Mid(0).Flatten() == FlatView); + CHECK(Buffer.Mid(4).Flatten().GetView() == FlatView.GetView().Mid(4)); + CHECK(Buffer.Mid(8).Flatten().GetView() == FlatView.GetView().Mid(8)); + CHECK(Buffer.Mid(4, 2).Flatten().GetView() == FlatView.GetView().Mid(4, 2)); + CHECK(Buffer.Mid(8, 0).Flatten().GetView() == FlatView.GetView().Mid(8, 0)); + CHECK(Buffer.GetSize() == sizeof(FlatArray)); + CHECK(Buffer.GetSegments().size() == 1); + CHECK(Buffer.GetSegments()[0] == FlatView); + + UniqueBuffer CopyBuffer; + CHECK(Buffer.ViewOrCopyRange(0, sizeof(FlatArray), CopyBuffer) == FlatView.GetView()); + CHECK(CopyBuffer.IsNull()); + + uint8_t CopyArray[sizeof(FlatArray) - 3]; + Buffer.CopyTo(MakeMutableMemoryView(CopyArray), 3); + CHECK(MakeMemoryView(CopyArray).EqualBytes(MakeMemoryView(FlatArray) + 3)); + + uint32_t VisitCount = 0; + Buffer.IterateRange(0, sizeof(FlatArray), [&VisitCount](MemoryView) { ++VisitCount; }); + CHECK(VisitCount == 1); +} + +TEST_CASE("CompositeBuffer Composite") +{ + const uint8_t FlatArray[]{1, 2, 3, 4, 5, 6, 7, 8}; + const SharedBuffer FlatView1 = SharedBuffer::MakeView(MakeMemoryView(FlatArray).Left(4)); + const SharedBuffer FlatView2 = SharedBuffer::MakeView(MakeMemoryView(FlatArray).Right(4)); + CompositeBuffer Buffer(FlatView1, FlatView2); + + CHECK(Buffer.IsNull() == false); + CHECK(Buffer.IsOwned() == false); + CHECK(Buffer.Flatten().GetView().EqualBytes(MakeMemoryView(FlatArray))); + CHECK(Buffer.Mid(2, 4).Flatten().GetView().EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4))); + CHECK(Buffer.Mid(0, 4).Flatten() == FlatView1); + CHECK(Buffer.Mid(4, 4).Flatten() == FlatView2); + CHECK(Buffer.GetSize() == sizeof(FlatArray)); + CHECK(Buffer.GetSegments().size() == 2); + CHECK(Buffer.GetSegments()[0] == FlatView1); + CHECK(Buffer.GetSegments()[1] == FlatView2); + + UniqueBuffer CopyBuffer; + + CHECK(Buffer.ViewOrCopyRange(0, 4, CopyBuffer) == FlatView1.GetView()); + CHECK(CopyBuffer.IsNull() == true); + CHECK(Buffer.ViewOrCopyRange(4, 4, CopyBuffer) == FlatView2.GetView()); + CHECK(CopyBuffer.IsNull() == true); + CHECK(Buffer.ViewOrCopyRange(3, 2, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(3, 2))); + CHECK(CopyBuffer.GetSize() == 2); + CHECK(Buffer.ViewOrCopyRange(1, 6, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(1, 6))); + CHECK(CopyBuffer.GetSize() == 6); + CHECK(Buffer.ViewOrCopyRange(2, 4, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4))); + CHECK(CopyBuffer.GetSize() == 6); + + uint8_t CopyArray[4]; + Buffer.CopyTo(MakeMutableMemoryView(CopyArray), 2); + CHECK(MakeMemoryView(CopyArray).EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4))); + + uint32_t VisitCount = 0; + Buffer.IterateRange(0, sizeof(FlatArray), [&VisitCount](MemoryView) { ++VisitCount; }); + CHECK(VisitCount == 2); + + const auto TestIterateRange = + [&Buffer](uint64_t Offset, uint64_t Size, MemoryView ExpectedView, const SharedBuffer& ExpectedViewOuter) { + uint32_t VisitCount = 0; + MemoryView ActualView; + SharedBuffer ActualViewOuter; + Buffer.IterateRange(Offset, Size, [&VisitCount, &ActualView, &ActualViewOuter](MemoryView View, const SharedBuffer& ViewOuter) { + ++VisitCount; + ActualView = View; + ActualViewOuter = ViewOuter; + }); + CHECK(VisitCount == 1); + CHECK(ActualView == ExpectedView); + CHECK(ActualViewOuter == ExpectedViewOuter); + }; + TestIterateRange(0, 4, MakeMemoryView(FlatArray).Mid(0, 4), FlatView1); + TestIterateRange(4, 0, MakeMemoryView(FlatArray).Mid(4, 0), FlatView1); + TestIterateRange(4, 4, MakeMemoryView(FlatArray).Mid(4, 4), FlatView2); + TestIterateRange(8, 0, MakeMemoryView(FlatArray).Mid(8, 0), FlatView2); +} + +void +compositebuffer_forcelink() +{ +} +#endif + +} // namespace zen diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp new file mode 100644 index 000000000..632e0e8f3 --- /dev/null +++ b/src/zencore/compress.cpp @@ -0,0 +1,1353 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/compress.h> + +#include <zencore/blake3.h> +#include <zencore/compositebuffer.h> +#include <zencore/crc32.h> +#include <zencore/endian.h> +#include <zencore/iohash.h> +#include <zencore/testing.h> + +#include "../../thirdparty/Oodle/include/oodle2.h" +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "oo2core_win64.lib") +#endif + +#include <lz4.h> +#include <functional> +#include <limits> + +namespace zen::detail { + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr uint64_t DefaultBlockSize = 256 * 1024; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** Method used to compress the data in a compressed buffer. */ +enum class CompressionMethod : uint8_t +{ + /** Header is followed by one uncompressed block. */ + None = 0, + /** Header is followed by an array of compressed block sizes then the compressed blocks. */ + Oodle = 3, + /** Header is followed by an array of compressed block sizes then the compressed blocks. */ + LZ4 = 4, +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** Header used on every compressed buffer. Always stored in big-endian format. */ +struct BufferHeader +{ + static constexpr uint32_t ExpectedMagic = 0xb7756362; // <dot>ucb + + uint32_t Magic = ExpectedMagic; // A magic number to identify a compressed buffer. Always 0xb7756362. + uint32_t Crc32 = 0; // A CRC-32 used to check integrity of the buffer. Uses the polynomial 0x04c11db7 + CompressionMethod Method = + CompressionMethod::None; // The method used to compress the buffer. Affects layout of data following the header + uint8_t Compressor = 0; // The method-specific compressor used to compress the buffer. + uint8_t CompressionLevel = 0; // The method-specific compression level used to compress the buffer. + uint8_t BlockSizeExponent = 0; // The power of two size of every uncompressed block except the last. Size is 1 << BlockSizeExponent + uint32_t BlockCount = 0; // The number of blocks that follow the header + uint64_t TotalRawSize = 0; // The total size of the uncompressed data + uint64_t TotalCompressedSize = 0; // The total size of the compressed data including the header + BLAKE3 RawHash; // The hash of the uncompressed data + + /** Checks validity of the buffer based on the magic number, method, and CRC-32. */ + static bool IsValid(const CompositeBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize); + static bool IsValid(const SharedBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) + { + return IsValid(CompositeBuffer(CompressedData), OutRawHash, OutRawSize); + } + + /** Read a header from a buffer that is at least sizeof(BufferHeader) without any validation. */ + static BufferHeader Read(const CompositeBuffer& CompressedData) + { + BufferHeader Header; + if (sizeof(BufferHeader) <= CompressedData.GetSize()) + { + // if (CompressedData.GetSegments()[0].AsIoBuffer().IsWholeFile()) + // { + // ZEN_ASSERT(true); + // } + CompositeBuffer::Iterator It; + CompressedData.CopyTo(MakeMutableMemoryView(&Header, &Header + 1), It); + Header.ByteSwap(); + } + return Header; + } + + /** + * Write a header to a memory view that is at least sizeof(BufferHeader). + * + * @param HeaderView View of the header to write, including any method-specific header data. + */ + void Write(MutableMemoryView HeaderView) const + { + BufferHeader Header = *this; + Header.ByteSwap(); + HeaderView.CopyFrom(MakeMemoryView(&Header, &Header + 1)); + Header.ByteSwap(); + Header.Crc32 = CalculateCrc32(HeaderView); + Header.ByteSwap(); + HeaderView.CopyFrom(MakeMemoryView(&Header, &Header + 1)); + } + + void ByteSwap() + { + Magic = zen::ByteSwap(Magic); + Crc32 = zen::ByteSwap(Crc32); + BlockCount = zen::ByteSwap(BlockCount); + TotalRawSize = zen::ByteSwap(TotalRawSize); + TotalCompressedSize = zen::ByteSwap(TotalCompressedSize); + } + + /** Calculate the CRC-32 from a view of a header including any method-specific header data. */ + static uint32_t CalculateCrc32(MemoryView HeaderView) + { + uint32_t Crc32 = 0; + constexpr uint64_t MethodOffset = offsetof(BufferHeader, Method); + for (MemoryView View = HeaderView + MethodOffset; const uint64_t ViewSize = View.GetSize();) + { + const int32_t Size = static_cast<int32_t>(zen::Min<uint64_t>(ViewSize, /* INT_MAX */ 2147483647u)); + Crc32 = zen::MemCrc32(View.GetData(), Size, Crc32); + View += Size; + } + return Crc32; + } +}; + +static_assert(sizeof(BufferHeader) == 64, "BufferHeader is the wrong size."); + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class BaseEncoder +{ +public: + virtual CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t BlockSize = DefaultBlockSize) const = 0; +}; + +class BaseDecoder +{ +public: + virtual CompositeBuffer Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const = 0; + virtual bool TryDecompressTo(const BufferHeader& Header, + const CompositeBuffer& CompressedData, + MutableMemoryView RawView, + uint64_t RawOffset) const = 0; + virtual uint64_t GetHeaderSize(const BufferHeader& Header) const = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class NoneEncoder final : public BaseEncoder +{ +public: + [[nodiscard]] CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t /* BlockSize */) const final + { + BufferHeader Header; + Header.Method = CompressionMethod::None; + Header.BlockCount = 1; + Header.TotalRawSize = RawData.GetSize(); + Header.TotalCompressedSize = Header.TotalRawSize + sizeof(BufferHeader); + Header.RawHash = BLAKE3::HashBuffer(RawData); + + UniqueBuffer HeaderData = UniqueBuffer::Alloc(sizeof(BufferHeader)); + Header.Write(HeaderData); + return CompositeBuffer(HeaderData.MoveToShared(), RawData.MakeOwned()); + } +}; + +class NoneDecoder final : public BaseDecoder +{ +public: + [[nodiscard]] CompositeBuffer Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const final + { + if (Header.Method == CompressionMethod::None && Header.TotalCompressedSize == CompressedData.GetSize() && + Header.TotalCompressedSize == Header.TotalRawSize + sizeof(BufferHeader)) + { + return CompressedData.Mid(sizeof(BufferHeader), Header.TotalRawSize).MakeOwned(); + } + return CompositeBuffer(); + } + + [[nodiscard]] bool TryDecompressTo(const BufferHeader& Header, + const CompositeBuffer& CompressedData, + MutableMemoryView RawView, + uint64_t RawOffset) const final + { + if (Header.Method == CompressionMethod::None && RawOffset + RawView.GetSize() <= Header.TotalRawSize && + Header.TotalCompressedSize == CompressedData.GetSize() && + Header.TotalCompressedSize == Header.TotalRawSize + sizeof(BufferHeader)) + { + CompressedData.CopyTo(RawView, sizeof(BufferHeader) + RawOffset); + return true; + } + return false; + } + + [[nodiscard]] uint64_t GetHeaderSize(const BufferHeader&) const final { return sizeof(BufferHeader); } +}; + +////////////////////////////////////////////////////////////////////////// + +class BlockEncoder : public BaseEncoder +{ +public: + CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t BlockSize = DefaultBlockSize) const final; + +protected: + virtual CompressionMethod GetMethod() const = 0; + virtual uint8_t GetCompressor() const = 0; + virtual uint8_t GetCompressionLevel() const = 0; + virtual uint64_t CompressBlockBound(uint64_t RawSize) const = 0; + virtual bool CompressBlock(MutableMemoryView& CompressedData, MemoryView RawData) const = 0; + +private: + uint64_t GetCompressedBlocksBound(uint64_t BlockCount, uint64_t BlockSize, uint64_t RawSize) const + { + switch (BlockCount) + { + case 0: + return 0; + case 1: + return CompressBlockBound(RawSize); + default: + return CompressBlockBound(BlockSize) - BlockSize + RawSize; + } + } +}; + +CompositeBuffer +BlockEncoder::Compress(const CompositeBuffer& RawData, const uint64_t BlockSize) const +{ + ZEN_ASSERT(IsPow2(BlockSize) && (BlockSize <= (1u << 31))); + + const uint64_t RawSize = RawData.GetSize(); + BLAKE3Stream RawHash; + + const uint64_t BlockCount = RoundUp(RawSize, BlockSize) / BlockSize; + ZEN_ASSERT(BlockCount <= ~uint32_t(0)); + + // Allocate the buffer for the header, metadata, and compressed blocks. + const uint64_t MetaSize = BlockCount * sizeof(uint32_t); + const uint64_t CompressedDataSize = sizeof(BufferHeader) + MetaSize + GetCompressedBlocksBound(BlockCount, BlockSize, RawSize); + UniqueBuffer CompressedData = UniqueBuffer::Alloc(CompressedDataSize); + + // Compress the raw data in blocks and store the raw data for incompressible blocks. + std::vector<uint32_t> CompressedBlockSizes; + CompressedBlockSizes.reserve(BlockCount); + uint64_t CompressedSize = 0; + { + UniqueBuffer RawBlockCopy; + MutableMemoryView CompressedBlocksView = CompressedData.GetMutableView() + sizeof(BufferHeader) + MetaSize; + + CompositeBuffer::Iterator It = RawData.GetIterator(0); + + for (uint64_t RawOffset = 0; RawOffset < RawSize;) + { + const uint64_t RawBlockSize = zen::Min(RawSize - RawOffset, BlockSize); + const MemoryView RawBlock = RawData.ViewOrCopyRange(It, RawBlockSize, RawBlockCopy); + RawHash.Append(RawBlock); + + MutableMemoryView CompressedBlock = CompressedBlocksView; + if (!CompressBlock(CompressedBlock, RawBlock)) + { + return CompositeBuffer(); + } + + uint64_t CompressedBlockSize = CompressedBlock.GetSize(); + if (RawBlockSize <= CompressedBlockSize) + { + CompressedBlockSize = RawBlockSize; + CompressedBlocksView = CompressedBlocksView.CopyFrom(RawBlock); + } + else + { + CompressedBlocksView += CompressedBlockSize; + } + + CompressedBlockSizes.push_back(static_cast<uint32_t>(CompressedBlockSize)); + CompressedSize += CompressedBlockSize; + RawOffset += RawBlockSize; + } + } + + // Return an uncompressed buffer if the compressed data is larger than the raw data. + if (RawSize <= MetaSize + CompressedSize) + { + CompressedData.Reset(); + return NoneEncoder().Compress(RawData, BlockSize); + } + + // Write the header and calculate the CRC-32. + for (uint32_t& Size : CompressedBlockSizes) + { + Size = ByteSwap(Size); + } + CompressedData.GetMutableView().Mid(sizeof(BufferHeader), MetaSize).CopyFrom(MakeMemoryView(CompressedBlockSizes)); + + BufferHeader Header; + Header.Method = GetMethod(); + Header.Compressor = GetCompressor(); + Header.CompressionLevel = GetCompressionLevel(); + Header.BlockSizeExponent = static_cast<uint8_t>(zen::FloorLog2_64(BlockSize)); + Header.BlockCount = static_cast<uint32_t>(BlockCount); + Header.TotalRawSize = RawSize; + Header.TotalCompressedSize = sizeof(BufferHeader) + MetaSize + CompressedSize; + Header.RawHash = RawHash.GetHash(); + Header.Write(CompressedData.GetMutableView().Left(sizeof(BufferHeader) + MetaSize)); + + const MemoryView CompositeView = CompressedData.GetView().Left(Header.TotalCompressedSize); + return CompositeBuffer(SharedBuffer::MakeView(CompositeView, CompressedData.MoveToShared())); +} + +class BlockDecoder : public BaseDecoder +{ +public: + CompositeBuffer Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const final; + [[nodiscard]] bool TryDecompressTo(const BufferHeader& Header, + const CompositeBuffer& CompressedData, + MutableMemoryView RawView, + uint64_t RawOffset) const final; + [[nodiscard]] uint64_t GetHeaderSize(const BufferHeader& Header) const final + { + return sizeof(BufferHeader) + sizeof(uint32_t) * uint64_t(Header.BlockCount); + } + +protected: + virtual bool DecompressBlock(MutableMemoryView RawData, MemoryView CompressedData) const = 0; +}; + +CompositeBuffer +BlockDecoder::Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const +{ + if (Header.BlockCount == 0 || Header.TotalCompressedSize != CompressedData.GetSize()) + { + return CompositeBuffer(); + } + + // The raw data cannot reference the compressed data unless it is owned. + // An empty raw buffer requires an empty segment, which this path creates. + if (!CompressedData.IsOwned() || Header.TotalRawSize == 0) + { + UniqueBuffer Buffer = UniqueBuffer::Alloc(Header.TotalRawSize); + return TryDecompressTo(Header, CompressedData, Buffer, 0) ? CompositeBuffer(Buffer.MoveToShared()) : CompositeBuffer(); + } + + std::vector<uint32_t> CompressedBlockSizes; + CompressedBlockSizes.resize(Header.BlockCount); + CompressedData.CopyTo(MakeMutableMemoryView(CompressedBlockSizes), sizeof(BufferHeader)); + + for (uint32_t& Size : CompressedBlockSizes) + { + Size = ByteSwap(Size); + } + + // Allocate the buffer for the raw blocks that were compressed. + SharedBuffer RawData; + MutableMemoryView RawDataView; + const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; + { + uint64_t RawDataSize = 0; + uint64_t RemainingRawSize = Header.TotalRawSize; + for (const uint32_t CompressedBlockSize : CompressedBlockSizes) + { + const uint64_t RawBlockSize = zen::Min(RemainingRawSize, BlockSize); + if (CompressedBlockSize < BlockSize) + { + RawDataSize += RawBlockSize; + } + RemainingRawSize -= RawBlockSize; + } + UniqueBuffer RawDataBuffer = UniqueBuffer::Alloc(RawDataSize); + RawDataView = RawDataBuffer; + RawData = RawDataBuffer.MoveToShared(); + } + + // Decompress the compressed data in blocks and reference the uncompressed blocks. + uint64_t PendingCompressedSegmentOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t); + uint64_t PendingCompressedSegmentSize = 0; + uint64_t PendingRawSegmentOffset = 0; + uint64_t PendingRawSegmentSize = 0; + std::vector<SharedBuffer> Segments; + + const auto CommitPendingCompressedSegment = + [&PendingCompressedSegmentOffset, &PendingCompressedSegmentSize, &CompressedData, &Segments] { + if (PendingCompressedSegmentSize) + { + CompressedData.IterateRange(PendingCompressedSegmentOffset, + PendingCompressedSegmentSize, + [&Segments](MemoryView View, const SharedBuffer& ViewOuter) { + Segments.push_back(SharedBuffer::MakeView(View, ViewOuter)); + }); + PendingCompressedSegmentOffset += PendingCompressedSegmentSize; + PendingCompressedSegmentSize = 0; + } + }; + + const auto CommitPendingRawSegment = [&PendingRawSegmentOffset, &PendingRawSegmentSize, &RawData, &Segments] { + if (PendingRawSegmentSize) + { + const MemoryView PendingSegment = RawData.GetView().Mid(PendingRawSegmentOffset, PendingRawSegmentSize); + Segments.push_back(SharedBuffer::MakeView(PendingSegment, RawData)); + PendingRawSegmentOffset += PendingRawSegmentSize; + PendingRawSegmentSize = 0; + } + }; + + UniqueBuffer CompressedBlockCopy; + uint64_t RemainingRawSize = Header.TotalRawSize; + uint64_t RemainingCompressedSize = CompressedData.GetSize(); + for (const uint32_t CompressedBlockSize : CompressedBlockSizes) + { + if (RemainingCompressedSize < CompressedBlockSize) + { + return CompositeBuffer(); + } + + const uint64_t RawBlockSize = zen::Min(RemainingRawSize, BlockSize); + if (RawBlockSize == CompressedBlockSize) + { + CommitPendingRawSegment(); + PendingCompressedSegmentSize += RawBlockSize; + } + else + { + CommitPendingCompressedSegment(); + const MemoryView CompressedBlock = + CompressedData.ViewOrCopyRange(PendingCompressedSegmentOffset, CompressedBlockSize, CompressedBlockCopy); + if (!DecompressBlock(RawDataView.Left(RawBlockSize), CompressedBlock)) + { + return CompositeBuffer(); + } + PendingCompressedSegmentOffset += CompressedBlockSize; + PendingRawSegmentSize += RawBlockSize; + RawDataView += RawBlockSize; + } + + RemainingCompressedSize -= CompressedBlockSize; + RemainingRawSize -= RawBlockSize; + } + + CommitPendingCompressedSegment(); + CommitPendingRawSegment(); + + return CompositeBuffer(std::move(Segments)); +} + +bool +BlockDecoder::TryDecompressTo(const BufferHeader& Header, + const CompositeBuffer& CompressedData, + MutableMemoryView RawView, + uint64_t RawOffset) const +{ + if (Header.TotalRawSize < RawOffset + RawView.GetSize() || Header.TotalCompressedSize != CompressedData.GetSize()) + { + return false; + } + + const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; + + UniqueBuffer BlockSizeBuffer; + MemoryView BlockSizeView = CompressedData.ViewOrCopyRange(sizeof(BufferHeader), Header.BlockCount * sizeof(uint32_t), BlockSizeBuffer); + std::span<uint32_t const> CompressedBlockSizes(reinterpret_cast<const uint32_t*>(BlockSizeView.GetData()), Header.BlockCount); + + UniqueBuffer CompressedBlockCopy; + UniqueBuffer UncompressedBlockCopy; + + const size_t FirstBlockIndex = uint64_t(RawOffset / BlockSize); + const size_t LastBlockIndex = uint64_t((RawOffset + RawView.GetSize() - 1) / BlockSize); + const uint64_t LastBlockSize = BlockSize - ((Header.BlockCount * BlockSize) - Header.TotalRawSize); + uint64_t OffsetInFirstBlock = RawOffset % BlockSize; + uint64_t CompressedOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t); + uint64_t RemainingRawSize = RawView.GetSize(); + + for (size_t BlockIndex = 0; BlockIndex < FirstBlockIndex; BlockIndex++) + { + const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + CompressedOffset += CompressedBlockSize; + } + + for (size_t BlockIndex = FirstBlockIndex; BlockIndex <= LastBlockIndex; BlockIndex++) + { + const uint64_t UncompressedBlockSize = BlockIndex == Header.BlockCount - 1 ? LastBlockSize : BlockSize; + const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + const bool IsCompressed = CompressedBlockSize < UncompressedBlockSize; + + const uint64_t BytesToUncompress = OffsetInFirstBlock > 0 ? zen::Min(RawView.GetSize(), UncompressedBlockSize - OffsetInFirstBlock) + : zen::Min(RemainingRawSize, BlockSize); + + MemoryView CompressedBlock = CompressedData.ViewOrCopyRange(CompressedOffset, CompressedBlockSize, CompressedBlockCopy); + + if (IsCompressed) + { + MutableMemoryView UncompressedBlock = RawView.Left(BytesToUncompress); + + const bool IsAligned = BytesToUncompress == UncompressedBlockSize; + if (!IsAligned) + { + // Decompress to a temporary buffer when the first or the last block reads are not aligned with the block boundaries. + if (UncompressedBlockCopy.IsNull()) + { + UncompressedBlockCopy = UniqueBuffer::Alloc(BlockSize); + } + UncompressedBlock = UncompressedBlockCopy.GetMutableView().Mid(0, UncompressedBlockSize); + } + + if (!DecompressBlock(UncompressedBlock, CompressedBlock)) + { + return false; + } + + if (!IsAligned) + { + RawView.CopyFrom(UncompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); + } + } + else + { + RawView.CopyFrom(CompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress)); + } + + OffsetInFirstBlock = 0; + RemainingRawSize -= BytesToUncompress; + CompressedOffset += CompressedBlockSize; + RawView += BytesToUncompress; + } + + return RemainingRawSize == 0; +} + +////////////////////////////////////////////////////////////////////////// + +struct OodleInit +{ + OodleInit() + { + OodleConfigValues Config; + Oodle_GetConfigValues(&Config); + // Always read/write Oodle v9 binary data. + Config.m_OodleLZ_BackwardsCompatible_MajorVersion = 9; + Oodle_SetConfigValues(&Config); + } +}; + +OodleInit InitOodle; + +class OodleEncoder final : public BlockEncoder +{ +public: + OodleEncoder(OodleCompressor InCompressor, OodleCompressionLevel InCompressionLevel) + : Compressor(InCompressor) + , CompressionLevel(InCompressionLevel) + { + } + +protected: + CompressionMethod GetMethod() const final { return CompressionMethod::Oodle; } + uint8_t GetCompressor() const final { return static_cast<uint8_t>(Compressor); } + uint8_t GetCompressionLevel() const final { return static_cast<uint8_t>(CompressionLevel); } + + uint64_t CompressBlockBound(uint64_t RawSize) const final + { + return static_cast<uint64_t>(OodleLZ_GetCompressedBufferSizeNeeded(OodleLZ_Compressor_Kraken, static_cast<OO_SINTa>(RawSize))); + } + + bool CompressBlock(MutableMemoryView& CompressedData, MemoryView RawData) const final + { + const OodleLZ_Compressor LZCompressor = GetOodleLZCompressor(Compressor); + const OodleLZ_CompressionLevel LZCompressionLevel = GetOodleLZCompressionLevel(CompressionLevel); + if (LZCompressor == OodleLZ_Compressor_Invalid || LZCompressionLevel == OodleLZ_CompressionLevel_Invalid || + LZCompressionLevel == OodleLZ_CompressionLevel_None) + { + return false; + } + + const OO_SINTa RawSize = static_cast<OO_SINTa>(RawData.GetSize()); + if (static_cast<OO_SINTa>(CompressedData.GetSize()) < OodleLZ_GetCompressedBufferSizeNeeded(LZCompressor, RawSize)) + { + return false; + } + + const OO_SINTa Size = OodleLZ_Compress(LZCompressor, RawData.GetData(), RawSize, CompressedData.GetData(), LZCompressionLevel); + CompressedData.LeftInline(static_cast<uint64_t>(Size)); + return Size > 0; + } + + static OodleLZ_Compressor GetOodleLZCompressor(OodleCompressor Compressor) + { + switch (Compressor) + { + case OodleCompressor::Selkie: + return OodleLZ_Compressor_Selkie; + case OodleCompressor::Mermaid: + return OodleLZ_Compressor_Mermaid; + case OodleCompressor::Kraken: + return OodleLZ_Compressor_Kraken; + case OodleCompressor::Leviathan: + return OodleLZ_Compressor_Leviathan; + case OodleCompressor::NotSet: + default: + return OodleLZ_Compressor_Invalid; + } + } + + static OodleLZ_CompressionLevel GetOodleLZCompressionLevel(OodleCompressionLevel Level) + { + const int IntLevel = (int)Level; + if (IntLevel < (int)OodleLZ_CompressionLevel_Min || IntLevel > (int)OodleLZ_CompressionLevel_Max) + { + return OodleLZ_CompressionLevel_Invalid; + } + return OodleLZ_CompressionLevel(IntLevel); + } + +private: + const OodleCompressor Compressor; + const OodleCompressionLevel CompressionLevel; +}; + +class OodleDecoder final : public BlockDecoder +{ +protected: + bool DecompressBlock(MutableMemoryView RawData, MemoryView CompressedData) const final + { + const OO_SINTa RawSize = static_cast<OO_SINTa>(RawData.GetSize()); + const OO_SINTa Size = OodleLZ_Decompress(CompressedData.GetData(), + static_cast<OO_SINTa>(CompressedData.GetSize()), + RawData.GetData(), + RawSize, + OodleLZ_FuzzSafe_Yes, + OodleLZ_CheckCRC_Yes, + OodleLZ_Verbosity_None); + return Size == RawSize; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class LZ4Decoder final : public BlockDecoder +{ +protected: + bool DecompressBlock(MutableMemoryView RawData, MemoryView CompressedData) const final + { + if (CompressedData.GetSize() <= std::numeric_limits<int>::max()) + { + const int Size = LZ4_decompress_safe(static_cast<const char*>(CompressedData.GetData()), + static_cast<char*>(RawData.GetData()), + static_cast<int>(CompressedData.GetSize()), + static_cast<int>(zen::Min<uint64_t>(RawData.GetSize(), uint64_t(LZ4_MAX_INPUT_SIZE)))); + return static_cast<uint64_t>(Size) == RawData.GetSize(); + } + return false; + } +}; + +////////////////////////////////////////////////////////////////////////// + +static const BaseDecoder* +GetDecoder(CompressionMethod Method) +{ + static NoneDecoder None; + static OodleDecoder Oodle; + static LZ4Decoder LZ4; + + switch (Method) + { + default: + return nullptr; + case CompressionMethod::None: + return &None; + case CompressionMethod::Oodle: + return &Oodle; + case CompressionMethod::LZ4: + return &LZ4; + } +} + +////////////////////////////////////////////////////////////////////////// + +bool +BufferHeader::IsValid(const CompositeBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + uint64_t Size = CompressedData.GetSize(); + if (Size < sizeof(BufferHeader)) + { + return false; + } + const size_t StackBufferSize = 256; + uint8_t StackBuffer[StackBufferSize]; + uint64_t ReadSize = Min(Size, StackBufferSize); + BufferHeader* Header = reinterpret_cast<BufferHeader*>(StackBuffer); + { + CompositeBuffer::Iterator It; + CompressedData.CopyTo(MutableMemoryView(StackBuffer, StackBuffer + StackBufferSize), It); + } + Header->ByteSwap(); + if (Header->Magic != BufferHeader::ExpectedMagic) + { + return false; + } + const BaseDecoder* const Decoder = GetDecoder(Header->Method); + if (!Decoder) + { + return false; + } + uint32_t Crc32 = Header->Crc32; + OutRawHash = IoHash::FromBLAKE3(Header->RawHash); + OutRawSize = Header->TotalRawSize; + uint64_t HeaderSize = Decoder->GetHeaderSize(*Header); + Header->ByteSwap(); + + if (HeaderSize > ReadSize) + { + // 0.004% of cases on a Fortnite hot cache cook + UniqueBuffer HeaderCopy = UniqueBuffer::Alloc(HeaderSize); + CompositeBuffer::Iterator It; + CompressedData.CopyTo(HeaderCopy.GetMutableView(), It); + const MemoryView HeaderView = HeaderCopy.GetView(); + if (Crc32 != BufferHeader::CalculateCrc32(HeaderView)) + { + return false; + } + } + else + { + MemoryView FullHeaderView(StackBuffer, StackBuffer + HeaderSize); + if (Crc32 != BufferHeader::CalculateCrc32(FullHeaderView)) + { + return false; + } + } + return true; +} + +////////////////////////////////////////////////////////////////////////// + +template<typename BufferType> +inline CompositeBuffer +ValidBufferOrEmpty(BufferType&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + return BufferHeader::IsValid(CompressedData, OutRawHash, OutRawSize) ? CompositeBuffer(std::forward<BufferType>(CompressedData)) + : CompositeBuffer(); +} + +CompositeBuffer +CopyCompressedRange(const BufferHeader& Header, const CompositeBuffer& CompressedData, uint64_t RawOffset, uint64_t RawSize) +{ + if (Header.TotalRawSize < RawOffset + RawSize) + { + return CompositeBuffer(); + } + + if (Header.Method == CompressionMethod::None) + { + UniqueBuffer NewCompressedData = UniqueBuffer::Alloc(RawSize); + CompressedData.CopyTo(NewCompressedData.GetMutableView(), sizeof(Header) + RawOffset); + + BufferHeader NewHeader = Header; + NewHeader.Crc32 = 0; + NewHeader.TotalRawSize = RawSize; + NewHeader.TotalCompressedSize = NewHeader.TotalRawSize + sizeof(BufferHeader); + NewHeader.RawHash = BLAKE3(); + + UniqueBuffer HeaderData = UniqueBuffer::Alloc(sizeof(BufferHeader)); + NewHeader.Write(HeaderData); + + return CompositeBuffer(HeaderData.MoveToShared(), NewCompressedData.MoveToShared()); + } + else + { + UniqueBuffer BlockSizeBuffer; + MemoryView BlockSizeView = + CompressedData.ViewOrCopyRange(sizeof(BufferHeader), Header.BlockCount * sizeof(uint32_t), BlockSizeBuffer); + std::span<uint32_t const> CompressedBlockSizes(reinterpret_cast<const uint32_t*>(BlockSizeView.GetData()), Header.BlockCount); + + const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent; + const uint64_t LastBlockSize = BlockSize - ((Header.BlockCount * BlockSize) - Header.TotalRawSize); + const size_t FirstBlock = uint64_t(RawOffset / BlockSize); + const size_t LastBlock = uint64_t((RawOffset + RawSize - 1) / BlockSize); + uint64_t CompressedOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t); + + const uint64_t NewBlockCount = LastBlock - FirstBlock + 1; + const uint64_t NewMetaSize = NewBlockCount * sizeof(uint32_t); + uint64_t NewCompressedSize = 0; + uint64_t NewTotalRawSize = 0; + std::vector<uint32_t> NewCompressedBlockSizes; + + NewCompressedBlockSizes.reserve(NewBlockCount); + for (size_t BlockIndex = FirstBlock; BlockIndex <= LastBlock; ++BlockIndex) + { + const uint64_t UncompressedBlockSize = (BlockIndex == Header.BlockCount - 1) ? LastBlockSize : BlockSize; + NewTotalRawSize += UncompressedBlockSize; + + const uint32_t CompressedBlockSize = CompressedBlockSizes[BlockIndex]; + NewCompressedBlockSizes.push_back(CompressedBlockSize); + NewCompressedSize += ByteSwap(CompressedBlockSize); + } + + const uint64_t NewTotalCompressedSize = sizeof(BufferHeader) + NewBlockCount * sizeof(uint32_t) + NewCompressedSize; + UniqueBuffer NewCompressedData = UniqueBuffer::Alloc(NewTotalCompressedSize); + MutableMemoryView NewCompressedBlocks = NewCompressedData.GetMutableView() + sizeof(BufferHeader) + NewMetaSize; + + // Seek to first compressed block + for (size_t BlockIndex = 0; BlockIndex < FirstBlock; ++BlockIndex) + { + const uint64_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]); + CompressedOffset += CompressedBlockSize; + } + + // Copy blocks + UniqueBuffer CompressedBlockCopy; + const MemoryView CompressedRange = CompressedData.ViewOrCopyRange(CompressedOffset, NewCompressedSize, CompressedBlockCopy); + NewCompressedBlocks.CopyFrom(CompressedRange); + + // Copy block sizes + NewCompressedData.GetMutableView().Mid(sizeof(BufferHeader), NewMetaSize).CopyFrom(MakeMemoryView(NewCompressedBlockSizes)); + + BufferHeader NewHeader; + NewHeader.Crc32 = 0; + NewHeader.Method = Header.Method; + NewHeader.Compressor = Header.Compressor; + NewHeader.CompressionLevel = Header.CompressionLevel; + NewHeader.BlockSizeExponent = Header.BlockSizeExponent; + NewHeader.BlockCount = static_cast<uint32_t>(NewBlockCount); + NewHeader.TotalRawSize = NewTotalRawSize; + NewHeader.TotalCompressedSize = NewTotalCompressedSize; + NewHeader.RawHash = BLAKE3(); + NewHeader.Write(NewCompressedData.GetMutableView().Left(sizeof(BufferHeader) + NewMetaSize)); + + return CompositeBuffer(NewCompressedData.MoveToShared()); + } +} + +} // namespace zen::detail + +namespace zen { + +const CompressedBuffer CompressedBuffer::Null; + +CompressedBuffer +CompressedBuffer::Compress(const CompositeBuffer& RawData, + OodleCompressor Compressor, + OodleCompressionLevel CompressionLevel, + uint64_t BlockSize) +{ + using namespace detail; + + if (BlockSize == 0) + { + BlockSize = DefaultBlockSize; + } + + CompressedBuffer Local; + if (CompressionLevel == OodleCompressionLevel::None) + { + Local.CompressedData = NoneEncoder().Compress(RawData, BlockSize); + } + else + { + Local.CompressedData = OodleEncoder(Compressor, CompressionLevel).Compress(RawData, BlockSize); + } + return Local; +} + +CompressedBuffer +CompressedBuffer::Compress(const SharedBuffer& RawData, + OodleCompressor Compressor, + OodleCompressionLevel CompressionLevel, + uint64_t BlockSize) +{ + return Compress(CompositeBuffer(RawData), Compressor, CompressionLevel, BlockSize); +} + +CompressedBuffer +CompressedBuffer::FromCompressed(const CompositeBuffer& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + CompressedBuffer Local; + Local.CompressedData = detail::ValidBufferOrEmpty(InCompressedData, OutRawHash, OutRawSize); + return Local; +} + +CompressedBuffer +CompressedBuffer::FromCompressed(CompositeBuffer&& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + CompressedBuffer Local; + Local.CompressedData = detail::ValidBufferOrEmpty(std::move(InCompressedData), OutRawHash, OutRawSize); + return Local; +} + +CompressedBuffer +CompressedBuffer::FromCompressed(const SharedBuffer& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + CompressedBuffer Local; + Local.CompressedData = detail::ValidBufferOrEmpty(InCompressedData, OutRawHash, OutRawSize); + return Local; +} + +CompressedBuffer +CompressedBuffer::FromCompressed(SharedBuffer&& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + CompressedBuffer Local; + Local.CompressedData = detail::ValidBufferOrEmpty(std::move(InCompressedData), OutRawHash, OutRawSize); + return Local; +} + +CompressedBuffer +CompressedBuffer::FromCompressedNoValidate(IoBuffer&& InCompressedData) +{ + if (InCompressedData.GetSize() <= sizeof(detail::BufferHeader)) + { + return CompressedBuffer(); + } + CompressedBuffer Local; + Local.CompressedData = CompositeBuffer(SharedBuffer(std::move(InCompressedData))); + return Local; +} + +CompressedBuffer +CompressedBuffer::FromCompressedNoValidate(CompositeBuffer&& InCompressedData) +{ + if (InCompressedData.GetSize() <= sizeof(detail::BufferHeader)) + { + return CompressedBuffer(); + } + CompressedBuffer Local; + Local.CompressedData = std::move(InCompressedData); + return Local; +} + +bool +CompressedBuffer::ValidateCompressedHeader(IoBuffer&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + return detail::BufferHeader::IsValid(SharedBuffer(std::move(CompressedData)), OutRawHash, OutRawSize); +} + +bool +CompressedBuffer::ValidateCompressedHeader(const IoBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize) +{ + return detail::BufferHeader::IsValid(SharedBuffer(CompressedData), OutRawHash, OutRawSize); +} + +uint64_t +CompressedBuffer::DecodeRawSize() const +{ + return CompressedData ? detail::BufferHeader::Read(CompressedData).TotalRawSize : 0; +} + +IoHash +CompressedBuffer::DecodeRawHash() const +{ + return CompressedData ? IoHash::FromBLAKE3(detail::BufferHeader::Read(CompressedData).RawHash) : IoHash(); +} + +CompressedBuffer +CompressedBuffer::CopyRange(uint64_t RawOffset, uint64_t RawSize) const +{ + using namespace detail; + const BufferHeader Header = BufferHeader::Read(CompressedData); + const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset; + + CompressedBuffer Range; + Range.CompressedData = CopyCompressedRange(Header, CompressedData, RawOffset, TotalRawSize); + + return Range; +} + +bool +CompressedBuffer::TryDecompressTo(MutableMemoryView RawView, uint64_t RawOffset) const +{ + using namespace detail; + if (CompressedData) + { + const BufferHeader Header = BufferHeader::Read(CompressedData); + if (Header.Magic == BufferHeader::ExpectedMagic) + { + if (const BaseDecoder* const Decoder = GetDecoder(Header.Method)) + { + return Decoder->TryDecompressTo(Header, CompressedData, RawView, RawOffset); + } + } + } + return false; +} + +SharedBuffer +CompressedBuffer::Decompress(uint64_t RawOffset, uint64_t RawSize) const +{ + using namespace detail; + if (CompressedData && RawSize > 0) + { + const BufferHeader Header = BufferHeader::Read(CompressedData); + if (Header.Magic == BufferHeader::ExpectedMagic) + { + if (const BaseDecoder* const Decoder = GetDecoder(Header.Method)) + { + const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset; + UniqueBuffer RawData = UniqueBuffer::Alloc(TotalRawSize); + if (Decoder->TryDecompressTo(Header, CompressedData, RawData, RawOffset)) + { + return RawData.MoveToShared(); + } + } + } + } + return SharedBuffer(); +} + +CompositeBuffer +CompressedBuffer::DecompressToComposite() const +{ + using namespace detail; + if (CompressedData) + { + const BufferHeader Header = BufferHeader::Read(CompressedData); + if (Header.Magic == BufferHeader::ExpectedMagic) + { + if (const BaseDecoder* const Decoder = GetDecoder(Header.Method)) + { + return Decoder->Decompress(Header, CompressedData); + } + } + } + return CompositeBuffer(); +} + +bool +CompressedBuffer::TryGetCompressParameters(OodleCompressor& OutCompressor, + OodleCompressionLevel& OutCompressionLevel, + uint64_t& OutBlockSize) const +{ + using namespace detail; + if (CompressedData) + { + switch (const BufferHeader Header = BufferHeader::Read(CompressedData); Header.Method) + { + case CompressionMethod::None: + OutCompressor = OodleCompressor::NotSet; + OutCompressionLevel = OodleCompressionLevel::None; + OutBlockSize = 0; + return true; + case CompressionMethod::Oodle: + OutCompressor = OodleCompressor(Header.Compressor); + OutCompressionLevel = OodleCompressionLevel(Header.CompressionLevel); + OutBlockSize = uint64_t(1) << Header.BlockSizeExponent; + return true; + default: + break; + } + } + return false; +} + +/** + ______________________ _____________________________ + \__ ___/\_ _____// _____/\__ ___/ _____/ + | | | __)_ \_____ \ | | \_____ \ + | | | \/ \ | | / \ + |____| /_______ /_______ / |____| /_______ / + \/ \/ \/ + */ + +#if ZEN_WITH_TESTS + +TEST_CASE("CompressedBuffer") +{ + uint8_t Zeroes[1024]{}; + uint8_t Ones[1024]; + memset(Ones, 1, sizeof Ones); + + { + CompressedBuffer Buffer = CompressedBuffer::Compress(CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes))), + OodleCompressor::NotSet, + OodleCompressionLevel::None); + + CHECK(Buffer.DecodeRawSize() == sizeof(Zeroes)); + CHECK(Buffer.GetCompressedSize() == (sizeof(Zeroes) + sizeof(detail::BufferHeader))); + + CompositeBuffer Compressed = Buffer.GetCompressed(); + IoHash DecodedHash; + uint64_t DecodedRawSize; + CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize); + + CHECK(BufferD.IsNull() == false); + + CompositeBuffer Decomp = BufferD.DecompressToComposite(); + + CHECK(Decomp.GetSize() == DecodedRawSize); + CHECK(IoHash::HashBuffer(Decomp) == DecodedHash); + } + + { + CompressedBuffer Buffer = CompressedBuffer::Compress( + CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes)), SharedBuffer::MakeView(MakeMemoryView(Ones))), + OodleCompressor::NotSet, + OodleCompressionLevel::None); + + CHECK(Buffer.DecodeRawSize() == (sizeof(Zeroes) + sizeof(Ones))); + CHECK(Buffer.GetCompressedSize() == (sizeof(Zeroes) + sizeof(Ones) + sizeof(detail::BufferHeader))); + + CompositeBuffer Compressed = Buffer.GetCompressed(); + IoHash DecodedHash; + uint64_t DecodedRawSize; + CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize); + + CHECK(BufferD.IsNull() == false); + + CompositeBuffer Decomp = BufferD.DecompressToComposite(); + + CHECK(Decomp.GetSize() == DecodedRawSize); + CHECK(IoHash::HashBuffer(Decomp) == DecodedHash); + } + + { + CompressedBuffer Buffer = CompressedBuffer::Compress(CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes)))); + + CHECK(Buffer.DecodeRawSize() == sizeof(Zeroes)); + CHECK(Buffer.GetCompressedSize() < sizeof(Zeroes)); + + CompositeBuffer Compressed = Buffer.GetCompressed(); + IoHash DecodedHash; + uint64_t DecodedRawSize; + CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize); + + CHECK(BufferD.IsNull() == false); + + CompositeBuffer Decomp = BufferD.DecompressToComposite(); + + CHECK(Decomp.GetSize() == DecodedRawSize); + CHECK(IoHash::HashBuffer(Decomp) == DecodedHash); + } + + { + CompressedBuffer Buffer = CompressedBuffer::Compress( + CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes)), SharedBuffer::MakeView(MakeMemoryView(Ones)))); + + CHECK(Buffer.DecodeRawSize() == (sizeof(Zeroes) + sizeof(Ones))); + CHECK(Buffer.GetCompressedSize() < (sizeof(Zeroes) + sizeof(Ones))); + + CompositeBuffer Compressed = Buffer.GetCompressed(); + IoHash DecodedHash; + uint64_t DecodedRawSize; + CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize); + + CHECK(BufferD.IsNull() == false); + + CompositeBuffer Decomp = BufferD.DecompressToComposite(); + + CHECK(Decomp.GetSize() == DecodedRawSize); + CHECK(IoHash::HashBuffer(Decomp) == DecodedHash); + } + + auto GenerateData = [](uint64_t N) -> std::vector<uint64_t> { + std::vector<uint64_t> Data; + Data.resize(N); + for (size_t Idx = 0; Idx < Data.size(); ++Idx) + { + Data[Idx] = Idx; + } + return Data; + }; + + auto ValidateData = [](std::span<uint64_t const> Values, std::span<uint64_t const> ExpectedValues, uint64_t Offset) { + for (size_t Idx = Offset; uint64_t Value : Values) + { + const uint64_t ExpectedValue = ExpectedValues[Idx++]; + CHECK(Value == ExpectedValue); + } + }; + + SUBCASE("decompress with offset and size") + { + auto UncompressAndValidate = [&ValidateData](CompressedBuffer Compressed, + uint64_t OffsetCount, + uint64_t Count, + const std::vector<uint64_t>& ExpectedValues) { + SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)); + CHECK(Uncompressed.GetSize() == Count * sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + ValidateData(Values, ExpectedValues, OffsetCount); + }; + + const uint64_t BlockSize = 64 * sizeof(uint64_t); + const uint64_t N = 5000; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::Mermaid, + OodleCompressionLevel::Optimal4, + BlockSize); + UncompressAndValidate(Compressed, 0, N, ExpectedValues); + UncompressAndValidate(Compressed, 1, N - 1, ExpectedValues); + UncompressAndValidate(Compressed, N - 1, 1, ExpectedValues); + UncompressAndValidate(Compressed, 0, 1, ExpectedValues); + UncompressAndValidate(Compressed, 2, 4, ExpectedValues); + UncompressAndValidate(Compressed, 0, 512, ExpectedValues); + UncompressAndValidate(Compressed, 3, 514, ExpectedValues); + UncompressAndValidate(Compressed, 256, 512, ExpectedValues); + UncompressAndValidate(Compressed, 512, 512, ExpectedValues); + } + + SUBCASE("decompress with offset only") + { + const uint64_t BlockSize = 64 * sizeof(uint64_t); + const uint64_t N = 1000; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::Mermaid, + OodleCompressionLevel::Optimal4, + BlockSize); + const uint64_t OffsetCount = 150; + SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + SUBCASE("decompress buffer with one block") + { + const uint64_t BlockSize = 256 * sizeof(uint64_t); + const uint64_t N = 100; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::Mermaid, + OodleCompressionLevel::Optimal4, + BlockSize); + const uint64_t OffsetCount = 2; + const uint64_t Count = 50; + SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + SUBCASE("decompress uncompressed buffer") + { + const uint64_t N = 4242; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::NotSet, + OodleCompressionLevel::None); + { + const uint64_t OffsetCount = 0; + const uint64_t Count = N; + SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 21; + const uint64_t Count = 999; + SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + ValidateData(Values, ExpectedValues, OffsetCount); + } + } + + SUBCASE("copy range") + { + const uint64_t BlockSize = 64 * sizeof(uint64_t); + const uint64_t N = 1000; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::Mermaid, + OodleCompressionLevel::Optimal4, + BlockSize); + + { + const uint64_t OffsetCount = 0; + const uint64_t Count = N; + SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 64; + const uint64_t Count = N - 64; + SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 64 * 2 + 32; + const uint64_t Count = N - OffsetCount; + const uint64_t RawOffset = OffsetCount * sizeof(uint64_t); + const uint64_t RawSize = Count * sizeof(uint64_t); + uint64_t FirstBlockOffset = RawOffset % BlockSize; + + SharedBuffer Uncompressed = Compressed.CopyRange(RawOffset, RawSize).Decompress(); + std::span<uint64_t const> AllValues((const uint64_t*)Uncompressed.GetData(), RawSize / sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)(((const uint8_t*)(Uncompressed.GetData()) + FirstBlockOffset)), + RawSize / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 64 * 2 + 63; + const uint64_t Count = N - OffsetCount - 5; + const uint64_t RawOffset = OffsetCount * sizeof(uint64_t); + const uint64_t RawSize = Count * sizeof(uint64_t); + uint64_t FirstBlockOffset = RawOffset % BlockSize; + + SharedBuffer Uncompressed = Compressed.CopyRange(RawOffset, RawSize).Decompress(); + std::span<uint64_t const> AllValues((const uint64_t*)Uncompressed.GetData(), RawSize / sizeof(uint64_t)); + std::span<uint64_t const> Values((const uint64_t*)(((const uint8_t*)(Uncompressed.GetData()) + FirstBlockOffset)), + RawSize / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + } + + SUBCASE("copy uncompressed range") + { + const uint64_t N = 1000; + std::vector<uint64_t> ExpectedValues = GenerateData(N); + + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)), + OodleCompressor::NotSet, + OodleCompressionLevel::None); + + { + const uint64_t OffsetCount = 0; + const uint64_t Count = N; + SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 1; + const uint64_t Count = N - OffsetCount; + SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + + { + const uint64_t OffsetCount = 42; + const uint64_t Count = 100; + SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress(); + std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t)); + CHECK(Values.size() == Count); + ValidateData(Values, ExpectedValues, OffsetCount); + } + } +} + +void +compress_forcelink() +{ +} +#endif + +} // namespace zen diff --git a/src/zencore/crc32.cpp b/src/zencore/crc32.cpp new file mode 100644 index 000000000..d4a3cac57 --- /dev/null +++ b/src/zencore/crc32.cpp @@ -0,0 +1,545 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/crc32.h" + +namespace CRC32 { + +static const uint32_t CRCTable_DEPRECATED[256] = { + 0x00000000, 0x04C11DB7, 0x09823B6E, 0x0D4326D9, 0x130476DC, 0x17C56B6B, 0x1A864DB2, 0x1E475005, 0x2608EDB8, 0x22C9F00F, 0x2F8AD6D6, + 0x2B4BCB61, 0x350C9B64, 0x31CD86D3, 0x3C8EA00A, 0x384FBDBD, 0x4C11DB70, 0x48D0C6C7, 0x4593E01E, 0x4152FDA9, 0x5F15ADAC, 0x5BD4B01B, + 0x569796C2, 0x52568B75, 0x6A1936C8, 0x6ED82B7F, 0x639B0DA6, 0x675A1011, 0x791D4014, 0x7DDC5DA3, 0x709F7B7A, 0x745E66CD, 0x9823B6E0, + 0x9CE2AB57, 0x91A18D8E, 0x95609039, 0x8B27C03C, 0x8FE6DD8B, 0x82A5FB52, 0x8664E6E5, 0xBE2B5B58, 0xBAEA46EF, 0xB7A96036, 0xB3687D81, + 0xAD2F2D84, 0xA9EE3033, 0xA4AD16EA, 0xA06C0B5D, 0xD4326D90, 0xD0F37027, 0xDDB056FE, 0xD9714B49, 0xC7361B4C, 0xC3F706FB, 0xCEB42022, + 0xCA753D95, 0xF23A8028, 0xF6FB9D9F, 0xFBB8BB46, 0xFF79A6F1, 0xE13EF6F4, 0xE5FFEB43, 0xE8BCCD9A, 0xEC7DD02D, 0x34867077, 0x30476DC0, + 0x3D044B19, 0x39C556AE, 0x278206AB, 0x23431B1C, 0x2E003DC5, 0x2AC12072, 0x128E9DCF, 0x164F8078, 0x1B0CA6A1, 0x1FCDBB16, 0x018AEB13, + 0x054BF6A4, 0x0808D07D, 0x0CC9CDCA, 0x7897AB07, 0x7C56B6B0, 0x71159069, 0x75D48DDE, 0x6B93DDDB, 0x6F52C06C, 0x6211E6B5, 0x66D0FB02, + 0x5E9F46BF, 0x5A5E5B08, 0x571D7DD1, 0x53DC6066, 0x4D9B3063, 0x495A2DD4, 0x44190B0D, 0x40D816BA, 0xACA5C697, 0xA864DB20, 0xA527FDF9, + 0xA1E6E04E, 0xBFA1B04B, 0xBB60ADFC, 0xB6238B25, 0xB2E29692, 0x8AAD2B2F, 0x8E6C3698, 0x832F1041, 0x87EE0DF6, 0x99A95DF3, 0x9D684044, + 0x902B669D, 0x94EA7B2A, 0xE0B41DE7, 0xE4750050, 0xE9362689, 0xEDF73B3E, 0xF3B06B3B, 0xF771768C, 0xFA325055, 0xFEF34DE2, 0xC6BCF05F, + 0xC27DEDE8, 0xCF3ECB31, 0xCBFFD686, 0xD5B88683, 0xD1799B34, 0xDC3ABDED, 0xD8FBA05A, 0x690CE0EE, 0x6DCDFD59, 0x608EDB80, 0x644FC637, + 0x7A089632, 0x7EC98B85, 0x738AAD5C, 0x774BB0EB, 0x4F040D56, 0x4BC510E1, 0x46863638, 0x42472B8F, 0x5C007B8A, 0x58C1663D, 0x558240E4, + 0x51435D53, 0x251D3B9E, 0x21DC2629, 0x2C9F00F0, 0x285E1D47, 0x36194D42, 0x32D850F5, 0x3F9B762C, 0x3B5A6B9B, 0x0315D626, 0x07D4CB91, + 0x0A97ED48, 0x0E56F0FF, 0x1011A0FA, 0x14D0BD4D, 0x19939B94, 0x1D528623, 0xF12F560E, 0xF5EE4BB9, 0xF8AD6D60, 0xFC6C70D7, 0xE22B20D2, + 0xE6EA3D65, 0xEBA91BBC, 0xEF68060B, 0xD727BBB6, 0xD3E6A601, 0xDEA580D8, 0xDA649D6F, 0xC423CD6A, 0xC0E2D0DD, 0xCDA1F604, 0xC960EBB3, + 0xBD3E8D7E, 0xB9FF90C9, 0xB4BCB610, 0xB07DABA7, 0xAE3AFBA2, 0xAAFBE615, 0xA7B8C0CC, 0xA379DD7B, 0x9B3660C6, 0x9FF77D71, 0x92B45BA8, + 0x9675461F, 0x8832161A, 0x8CF30BAD, 0x81B02D74, 0x857130C3, 0x5D8A9099, 0x594B8D2E, 0x5408ABF7, 0x50C9B640, 0x4E8EE645, 0x4A4FFBF2, + 0x470CDD2B, 0x43CDC09C, 0x7B827D21, 0x7F436096, 0x7200464F, 0x76C15BF8, 0x68860BFD, 0x6C47164A, 0x61043093, 0x65C52D24, 0x119B4BE9, + 0x155A565E, 0x18197087, 0x1CD86D30, 0x029F3D35, 0x065E2082, 0x0B1D065B, 0x0FDC1BEC, 0x3793A651, 0x3352BBE6, 0x3E119D3F, 0x3AD08088, + 0x2497D08D, 0x2056CD3A, 0x2D15EBE3, 0x29D4F654, 0xC5A92679, 0xC1683BCE, 0xCC2B1D17, 0xC8EA00A0, 0xD6AD50A5, 0xD26C4D12, 0xDF2F6BCB, + 0xDBEE767C, 0xE3A1CBC1, 0xE760D676, 0xEA23F0AF, 0xEEE2ED18, 0xF0A5BD1D, 0xF464A0AA, 0xF9278673, 0xFDE69BC4, 0x89B8FD09, 0x8D79E0BE, + 0x803AC667, 0x84FBDBD0, 0x9ABC8BD5, 0x9E7D9662, 0x933EB0BB, 0x97FFAD0C, 0xAFB010B1, 0xAB710D06, 0xA6322BDF, 0xA2F33668, 0xBCB4666D, + 0xB8757BDA, 0xB5365D03, 0xB1F740B4}; + +static const uint32_t CRCTablesSB8_DEPRECATED[8][256] = { + {0x00000000, 0xb71dc104, 0x6e3b8209, 0xd926430d, 0xdc760413, 0x6b6bc517, 0xb24d861a, 0x0550471e, 0xb8ed0826, 0x0ff0c922, 0xd6d68a2f, + 0x61cb4b2b, 0x649b0c35, 0xd386cd31, 0x0aa08e3c, 0xbdbd4f38, 0x70db114c, 0xc7c6d048, 0x1ee09345, 0xa9fd5241, 0xacad155f, 0x1bb0d45b, + 0xc2969756, 0x758b5652, 0xc836196a, 0x7f2bd86e, 0xa60d9b63, 0x11105a67, 0x14401d79, 0xa35ddc7d, 0x7a7b9f70, 0xcd665e74, 0xe0b62398, + 0x57abe29c, 0x8e8da191, 0x39906095, 0x3cc0278b, 0x8bdde68f, 0x52fba582, 0xe5e66486, 0x585b2bbe, 0xef46eaba, 0x3660a9b7, 0x817d68b3, + 0x842d2fad, 0x3330eea9, 0xea16ada4, 0x5d0b6ca0, 0x906d32d4, 0x2770f3d0, 0xfe56b0dd, 0x494b71d9, 0x4c1b36c7, 0xfb06f7c3, 0x2220b4ce, + 0x953d75ca, 0x28803af2, 0x9f9dfbf6, 0x46bbb8fb, 0xf1a679ff, 0xf4f63ee1, 0x43ebffe5, 0x9acdbce8, 0x2dd07dec, 0x77708634, 0xc06d4730, + 0x194b043d, 0xae56c539, 0xab068227, 0x1c1b4323, 0xc53d002e, 0x7220c12a, 0xcf9d8e12, 0x78804f16, 0xa1a60c1b, 0x16bbcd1f, 0x13eb8a01, + 0xa4f64b05, 0x7dd00808, 0xcacdc90c, 0x07ab9778, 0xb0b6567c, 0x69901571, 0xde8dd475, 0xdbdd936b, 0x6cc0526f, 0xb5e61162, 0x02fbd066, + 0xbf469f5e, 0x085b5e5a, 0xd17d1d57, 0x6660dc53, 0x63309b4d, 0xd42d5a49, 0x0d0b1944, 0xba16d840, 0x97c6a5ac, 0x20db64a8, 0xf9fd27a5, + 0x4ee0e6a1, 0x4bb0a1bf, 0xfcad60bb, 0x258b23b6, 0x9296e2b2, 0x2f2bad8a, 0x98366c8e, 0x41102f83, 0xf60dee87, 0xf35da999, 0x4440689d, + 0x9d662b90, 0x2a7bea94, 0xe71db4e0, 0x500075e4, 0x892636e9, 0x3e3bf7ed, 0x3b6bb0f3, 0x8c7671f7, 0x555032fa, 0xe24df3fe, 0x5ff0bcc6, + 0xe8ed7dc2, 0x31cb3ecf, 0x86d6ffcb, 0x8386b8d5, 0x349b79d1, 0xedbd3adc, 0x5aa0fbd8, 0xeee00c69, 0x59fdcd6d, 0x80db8e60, 0x37c64f64, + 0x3296087a, 0x858bc97e, 0x5cad8a73, 0xebb04b77, 0x560d044f, 0xe110c54b, 0x38368646, 0x8f2b4742, 0x8a7b005c, 0x3d66c158, 0xe4408255, + 0x535d4351, 0x9e3b1d25, 0x2926dc21, 0xf0009f2c, 0x471d5e28, 0x424d1936, 0xf550d832, 0x2c769b3f, 0x9b6b5a3b, 0x26d61503, 0x91cbd407, + 0x48ed970a, 0xfff0560e, 0xfaa01110, 0x4dbdd014, 0x949b9319, 0x2386521d, 0x0e562ff1, 0xb94beef5, 0x606dadf8, 0xd7706cfc, 0xd2202be2, + 0x653deae6, 0xbc1ba9eb, 0x0b0668ef, 0xb6bb27d7, 0x01a6e6d3, 0xd880a5de, 0x6f9d64da, 0x6acd23c4, 0xddd0e2c0, 0x04f6a1cd, 0xb3eb60c9, + 0x7e8d3ebd, 0xc990ffb9, 0x10b6bcb4, 0xa7ab7db0, 0xa2fb3aae, 0x15e6fbaa, 0xccc0b8a7, 0x7bdd79a3, 0xc660369b, 0x717df79f, 0xa85bb492, + 0x1f467596, 0x1a163288, 0xad0bf38c, 0x742db081, 0xc3307185, 0x99908a5d, 0x2e8d4b59, 0xf7ab0854, 0x40b6c950, 0x45e68e4e, 0xf2fb4f4a, + 0x2bdd0c47, 0x9cc0cd43, 0x217d827b, 0x9660437f, 0x4f460072, 0xf85bc176, 0xfd0b8668, 0x4a16476c, 0x93300461, 0x242dc565, 0xe94b9b11, + 0x5e565a15, 0x87701918, 0x306dd81c, 0x353d9f02, 0x82205e06, 0x5b061d0b, 0xec1bdc0f, 0x51a69337, 0xe6bb5233, 0x3f9d113e, 0x8880d03a, + 0x8dd09724, 0x3acd5620, 0xe3eb152d, 0x54f6d429, 0x7926a9c5, 0xce3b68c1, 0x171d2bcc, 0xa000eac8, 0xa550add6, 0x124d6cd2, 0xcb6b2fdf, + 0x7c76eedb, 0xc1cba1e3, 0x76d660e7, 0xaff023ea, 0x18ede2ee, 0x1dbda5f0, 0xaaa064f4, 0x738627f9, 0xc49be6fd, 0x09fdb889, 0xbee0798d, + 0x67c63a80, 0xd0dbfb84, 0xd58bbc9a, 0x62967d9e, 0xbbb03e93, 0x0cadff97, 0xb110b0af, 0x060d71ab, 0xdf2b32a6, 0x6836f3a2, 0x6d66b4bc, + 0xda7b75b8, 0x035d36b5, 0xb440f7b1}, + {0x00000000, 0xdcc119d2, 0x0f9ef2a0, 0xd35feb72, 0xa9212445, 0x75e03d97, 0xa6bfd6e5, 0x7a7ecf37, 0x5243488a, 0x8e825158, 0x5dddba2a, + 0x811ca3f8, 0xfb626ccf, 0x27a3751d, 0xf4fc9e6f, 0x283d87bd, 0x139b5110, 0xcf5a48c2, 0x1c05a3b0, 0xc0c4ba62, 0xbaba7555, 0x667b6c87, + 0xb52487f5, 0x69e59e27, 0x41d8199a, 0x9d190048, 0x4e46eb3a, 0x9287f2e8, 0xe8f93ddf, 0x3438240d, 0xe767cf7f, 0x3ba6d6ad, 0x2636a320, + 0xfaf7baf2, 0x29a85180, 0xf5694852, 0x8f178765, 0x53d69eb7, 0x808975c5, 0x5c486c17, 0x7475ebaa, 0xa8b4f278, 0x7beb190a, 0xa72a00d8, + 0xdd54cfef, 0x0195d63d, 0xd2ca3d4f, 0x0e0b249d, 0x35adf230, 0xe96cebe2, 0x3a330090, 0xe6f21942, 0x9c8cd675, 0x404dcfa7, 0x931224d5, + 0x4fd33d07, 0x67eebaba, 0xbb2fa368, 0x6870481a, 0xb4b151c8, 0xcecf9eff, 0x120e872d, 0xc1516c5f, 0x1d90758d, 0x4c6c4641, 0x90ad5f93, + 0x43f2b4e1, 0x9f33ad33, 0xe54d6204, 0x398c7bd6, 0xead390a4, 0x36128976, 0x1e2f0ecb, 0xc2ee1719, 0x11b1fc6b, 0xcd70e5b9, 0xb70e2a8e, + 0x6bcf335c, 0xb890d82e, 0x6451c1fc, 0x5ff71751, 0x83360e83, 0x5069e5f1, 0x8ca8fc23, 0xf6d63314, 0x2a172ac6, 0xf948c1b4, 0x2589d866, + 0x0db45fdb, 0xd1754609, 0x022aad7b, 0xdeebb4a9, 0xa4957b9e, 0x7854624c, 0xab0b893e, 0x77ca90ec, 0x6a5ae561, 0xb69bfcb3, 0x65c417c1, + 0xb9050e13, 0xc37bc124, 0x1fbad8f6, 0xcce53384, 0x10242a56, 0x3819adeb, 0xe4d8b439, 0x37875f4b, 0xeb464699, 0x913889ae, 0x4df9907c, + 0x9ea67b0e, 0x426762dc, 0x79c1b471, 0xa500ada3, 0x765f46d1, 0xaa9e5f03, 0xd0e09034, 0x0c2189e6, 0xdf7e6294, 0x03bf7b46, 0x2b82fcfb, + 0xf743e529, 0x241c0e5b, 0xf8dd1789, 0x82a3d8be, 0x5e62c16c, 0x8d3d2a1e, 0x51fc33cc, 0x98d88c82, 0x44199550, 0x97467e22, 0x4b8767f0, + 0x31f9a8c7, 0xed38b115, 0x3e675a67, 0xe2a643b5, 0xca9bc408, 0x165addda, 0xc50536a8, 0x19c42f7a, 0x63bae04d, 0xbf7bf99f, 0x6c2412ed, + 0xb0e50b3f, 0x8b43dd92, 0x5782c440, 0x84dd2f32, 0x581c36e0, 0x2262f9d7, 0xfea3e005, 0x2dfc0b77, 0xf13d12a5, 0xd9009518, 0x05c18cca, + 0xd69e67b8, 0x0a5f7e6a, 0x7021b15d, 0xace0a88f, 0x7fbf43fd, 0xa37e5a2f, 0xbeee2fa2, 0x622f3670, 0xb170dd02, 0x6db1c4d0, 0x17cf0be7, + 0xcb0e1235, 0x1851f947, 0xc490e095, 0xecad6728, 0x306c7efa, 0xe3339588, 0x3ff28c5a, 0x458c436d, 0x994d5abf, 0x4a12b1cd, 0x96d3a81f, + 0xad757eb2, 0x71b46760, 0xa2eb8c12, 0x7e2a95c0, 0x04545af7, 0xd8954325, 0x0bcaa857, 0xd70bb185, 0xff363638, 0x23f72fea, 0xf0a8c498, + 0x2c69dd4a, 0x5617127d, 0x8ad60baf, 0x5989e0dd, 0x8548f90f, 0xd4b4cac3, 0x0875d311, 0xdb2a3863, 0x07eb21b1, 0x7d95ee86, 0xa154f754, + 0x720b1c26, 0xaeca05f4, 0x86f78249, 0x5a369b9b, 0x896970e9, 0x55a8693b, 0x2fd6a60c, 0xf317bfde, 0x204854ac, 0xfc894d7e, 0xc72f9bd3, + 0x1bee8201, 0xc8b16973, 0x147070a1, 0x6e0ebf96, 0xb2cfa644, 0x61904d36, 0xbd5154e4, 0x956cd359, 0x49adca8b, 0x9af221f9, 0x4633382b, + 0x3c4df71c, 0xe08ceece, 0x33d305bc, 0xef121c6e, 0xf28269e3, 0x2e437031, 0xfd1c9b43, 0x21dd8291, 0x5ba34da6, 0x87625474, 0x543dbf06, + 0x88fca6d4, 0xa0c12169, 0x7c0038bb, 0xaf5fd3c9, 0x739eca1b, 0x09e0052c, 0xd5211cfe, 0x067ef78c, 0xdabfee5e, 0xe11938f3, 0x3dd82121, + 0xee87ca53, 0x3246d381, 0x48381cb6, 0x94f90564, 0x47a6ee16, 0x9b67f7c4, 0xb35a7079, 0x6f9b69ab, 0xbcc482d9, 0x60059b0b, 0x1a7b543c, + 0xc6ba4dee, 0x15e5a69c, 0xc924bf4e}, + {0x00000000, 0x87acd801, 0x0e59b103, 0x89f56902, 0x1cb26207, 0x9b1eba06, 0x12ebd304, 0x95470b05, 0x3864c50e, 0xbfc81d0f, 0x363d740d, + 0xb191ac0c, 0x24d6a709, 0xa37a7f08, 0x2a8f160a, 0xad23ce0b, 0x70c88a1d, 0xf764521c, 0x7e913b1e, 0xf93de31f, 0x6c7ae81a, 0xebd6301b, + 0x62235919, 0xe58f8118, 0x48ac4f13, 0xcf009712, 0x46f5fe10, 0xc1592611, 0x541e2d14, 0xd3b2f515, 0x5a479c17, 0xddeb4416, 0xe090153b, + 0x673ccd3a, 0xeec9a438, 0x69657c39, 0xfc22773c, 0x7b8eaf3d, 0xf27bc63f, 0x75d71e3e, 0xd8f4d035, 0x5f580834, 0xd6ad6136, 0x5101b937, + 0xc446b232, 0x43ea6a33, 0xca1f0331, 0x4db3db30, 0x90589f26, 0x17f44727, 0x9e012e25, 0x19adf624, 0x8ceafd21, 0x0b462520, 0x82b34c22, + 0x051f9423, 0xa83c5a28, 0x2f908229, 0xa665eb2b, 0x21c9332a, 0xb48e382f, 0x3322e02e, 0xbad7892c, 0x3d7b512d, 0xc0212b76, 0x478df377, + 0xce789a75, 0x49d44274, 0xdc934971, 0x5b3f9170, 0xd2caf872, 0x55662073, 0xf845ee78, 0x7fe93679, 0xf61c5f7b, 0x71b0877a, 0xe4f78c7f, + 0x635b547e, 0xeaae3d7c, 0x6d02e57d, 0xb0e9a16b, 0x3745796a, 0xbeb01068, 0x391cc869, 0xac5bc36c, 0x2bf71b6d, 0xa202726f, 0x25aeaa6e, + 0x888d6465, 0x0f21bc64, 0x86d4d566, 0x01780d67, 0x943f0662, 0x1393de63, 0x9a66b761, 0x1dca6f60, 0x20b13e4d, 0xa71de64c, 0x2ee88f4e, + 0xa944574f, 0x3c035c4a, 0xbbaf844b, 0x325aed49, 0xb5f63548, 0x18d5fb43, 0x9f792342, 0x168c4a40, 0x91209241, 0x04679944, 0x83cb4145, + 0x0a3e2847, 0x8d92f046, 0x5079b450, 0xd7d56c51, 0x5e200553, 0xd98cdd52, 0x4ccbd657, 0xcb670e56, 0x42926754, 0xc53ebf55, 0x681d715e, + 0xefb1a95f, 0x6644c05d, 0xe1e8185c, 0x74af1359, 0xf303cb58, 0x7af6a25a, 0xfd5a7a5b, 0x804356ec, 0x07ef8eed, 0x8e1ae7ef, 0x09b63fee, + 0x9cf134eb, 0x1b5decea, 0x92a885e8, 0x15045de9, 0xb82793e2, 0x3f8b4be3, 0xb67e22e1, 0x31d2fae0, 0xa495f1e5, 0x233929e4, 0xaacc40e6, + 0x2d6098e7, 0xf08bdcf1, 0x772704f0, 0xfed26df2, 0x797eb5f3, 0xec39bef6, 0x6b9566f7, 0xe2600ff5, 0x65ccd7f4, 0xc8ef19ff, 0x4f43c1fe, + 0xc6b6a8fc, 0x411a70fd, 0xd45d7bf8, 0x53f1a3f9, 0xda04cafb, 0x5da812fa, 0x60d343d7, 0xe77f9bd6, 0x6e8af2d4, 0xe9262ad5, 0x7c6121d0, + 0xfbcdf9d1, 0x723890d3, 0xf59448d2, 0x58b786d9, 0xdf1b5ed8, 0x56ee37da, 0xd142efdb, 0x4405e4de, 0xc3a93cdf, 0x4a5c55dd, 0xcdf08ddc, + 0x101bc9ca, 0x97b711cb, 0x1e4278c9, 0x99eea0c8, 0x0ca9abcd, 0x8b0573cc, 0x02f01ace, 0x855cc2cf, 0x287f0cc4, 0xafd3d4c5, 0x2626bdc7, + 0xa18a65c6, 0x34cd6ec3, 0xb361b6c2, 0x3a94dfc0, 0xbd3807c1, 0x40627d9a, 0xc7cea59b, 0x4e3bcc99, 0xc9971498, 0x5cd01f9d, 0xdb7cc79c, + 0x5289ae9e, 0xd525769f, 0x7806b894, 0xffaa6095, 0x765f0997, 0xf1f3d196, 0x64b4da93, 0xe3180292, 0x6aed6b90, 0xed41b391, 0x30aaf787, + 0xb7062f86, 0x3ef34684, 0xb95f9e85, 0x2c189580, 0xabb44d81, 0x22412483, 0xa5edfc82, 0x08ce3289, 0x8f62ea88, 0x0697838a, 0x813b5b8b, + 0x147c508e, 0x93d0888f, 0x1a25e18d, 0x9d89398c, 0xa0f268a1, 0x275eb0a0, 0xaeabd9a2, 0x290701a3, 0xbc400aa6, 0x3becd2a7, 0xb219bba5, + 0x35b563a4, 0x9896adaf, 0x1f3a75ae, 0x96cf1cac, 0x1163c4ad, 0x8424cfa8, 0x038817a9, 0x8a7d7eab, 0x0dd1a6aa, 0xd03ae2bc, 0x57963abd, + 0xde6353bf, 0x59cf8bbe, 0xcc8880bb, 0x4b2458ba, 0xc2d131b8, 0x457de9b9, 0xe85e27b2, 0x6ff2ffb3, 0xe60796b1, 0x61ab4eb0, 0xf4ec45b5, + 0x73409db4, 0xfab5f4b6, 0x7d192cb7}, + {0x00000000, 0xb79a6ddc, 0xd9281abc, 0x6eb27760, 0x054cf57c, 0xb2d698a0, 0xdc64efc0, 0x6bfe821c, 0x0a98eaf9, 0xbd028725, 0xd3b0f045, + 0x642a9d99, 0x0fd41f85, 0xb84e7259, 0xd6fc0539, 0x616668e5, 0xa32d14f7, 0x14b7792b, 0x7a050e4b, 0xcd9f6397, 0xa661e18b, 0x11fb8c57, + 0x7f49fb37, 0xc8d396eb, 0xa9b5fe0e, 0x1e2f93d2, 0x709de4b2, 0xc707896e, 0xacf90b72, 0x1b6366ae, 0x75d111ce, 0xc24b7c12, 0xf146e9ea, + 0x46dc8436, 0x286ef356, 0x9ff49e8a, 0xf40a1c96, 0x4390714a, 0x2d22062a, 0x9ab86bf6, 0xfbde0313, 0x4c446ecf, 0x22f619af, 0x956c7473, + 0xfe92f66f, 0x49089bb3, 0x27baecd3, 0x9020810f, 0x526bfd1d, 0xe5f190c1, 0x8b43e7a1, 0x3cd98a7d, 0x57270861, 0xe0bd65bd, 0x8e0f12dd, + 0x39957f01, 0x58f317e4, 0xef697a38, 0x81db0d58, 0x36416084, 0x5dbfe298, 0xea258f44, 0x8497f824, 0x330d95f8, 0x559013d1, 0xe20a7e0d, + 0x8cb8096d, 0x3b2264b1, 0x50dce6ad, 0xe7468b71, 0x89f4fc11, 0x3e6e91cd, 0x5f08f928, 0xe89294f4, 0x8620e394, 0x31ba8e48, 0x5a440c54, + 0xedde6188, 0x836c16e8, 0x34f67b34, 0xf6bd0726, 0x41276afa, 0x2f951d9a, 0x980f7046, 0xf3f1f25a, 0x446b9f86, 0x2ad9e8e6, 0x9d43853a, + 0xfc25eddf, 0x4bbf8003, 0x250df763, 0x92979abf, 0xf96918a3, 0x4ef3757f, 0x2041021f, 0x97db6fc3, 0xa4d6fa3b, 0x134c97e7, 0x7dfee087, + 0xca648d5b, 0xa19a0f47, 0x1600629b, 0x78b215fb, 0xcf287827, 0xae4e10c2, 0x19d47d1e, 0x77660a7e, 0xc0fc67a2, 0xab02e5be, 0x1c988862, + 0x722aff02, 0xc5b092de, 0x07fbeecc, 0xb0618310, 0xded3f470, 0x694999ac, 0x02b71bb0, 0xb52d766c, 0xdb9f010c, 0x6c056cd0, 0x0d630435, + 0xbaf969e9, 0xd44b1e89, 0x63d17355, 0x082ff149, 0xbfb59c95, 0xd107ebf5, 0x669d8629, 0x1d3de6a6, 0xaaa78b7a, 0xc415fc1a, 0x738f91c6, + 0x187113da, 0xafeb7e06, 0xc1590966, 0x76c364ba, 0x17a50c5f, 0xa03f6183, 0xce8d16e3, 0x79177b3f, 0x12e9f923, 0xa57394ff, 0xcbc1e39f, + 0x7c5b8e43, 0xbe10f251, 0x098a9f8d, 0x6738e8ed, 0xd0a28531, 0xbb5c072d, 0x0cc66af1, 0x62741d91, 0xd5ee704d, 0xb48818a8, 0x03127574, + 0x6da00214, 0xda3a6fc8, 0xb1c4edd4, 0x065e8008, 0x68ecf768, 0xdf769ab4, 0xec7b0f4c, 0x5be16290, 0x355315f0, 0x82c9782c, 0xe937fa30, + 0x5ead97ec, 0x301fe08c, 0x87858d50, 0xe6e3e5b5, 0x51798869, 0x3fcbff09, 0x885192d5, 0xe3af10c9, 0x54357d15, 0x3a870a75, 0x8d1d67a9, + 0x4f561bbb, 0xf8cc7667, 0x967e0107, 0x21e46cdb, 0x4a1aeec7, 0xfd80831b, 0x9332f47b, 0x24a899a7, 0x45cef142, 0xf2549c9e, 0x9ce6ebfe, + 0x2b7c8622, 0x4082043e, 0xf71869e2, 0x99aa1e82, 0x2e30735e, 0x48adf577, 0xff3798ab, 0x9185efcb, 0x261f8217, 0x4de1000b, 0xfa7b6dd7, + 0x94c91ab7, 0x2353776b, 0x42351f8e, 0xf5af7252, 0x9b1d0532, 0x2c8768ee, 0x4779eaf2, 0xf0e3872e, 0x9e51f04e, 0x29cb9d92, 0xeb80e180, + 0x5c1a8c5c, 0x32a8fb3c, 0x853296e0, 0xeecc14fc, 0x59567920, 0x37e40e40, 0x807e639c, 0xe1180b79, 0x568266a5, 0x383011c5, 0x8faa7c19, + 0xe454fe05, 0x53ce93d9, 0x3d7ce4b9, 0x8ae68965, 0xb9eb1c9d, 0x0e717141, 0x60c30621, 0xd7596bfd, 0xbca7e9e1, 0x0b3d843d, 0x658ff35d, + 0xd2159e81, 0xb373f664, 0x04e99bb8, 0x6a5becd8, 0xddc18104, 0xb63f0318, 0x01a56ec4, 0x6f1719a4, 0xd88d7478, 0x1ac6086a, 0xad5c65b6, + 0xc3ee12d6, 0x74747f0a, 0x1f8afd16, 0xa81090ca, 0xc6a2e7aa, 0x71388a76, 0x105ee293, 0xa7c48f4f, 0xc976f82f, 0x7eec95f3, 0x151217ef, + 0xa2887a33, 0xcc3a0d53, 0x7ba0608f}, + {0x00000000, 0x8d670d49, 0x1acf1a92, 0x97a817db, 0x8383f420, 0x0ee4f969, 0x994ceeb2, 0x142be3fb, 0x0607e941, 0x8b60e408, 0x1cc8f3d3, + 0x91affe9a, 0x85841d61, 0x08e31028, 0x9f4b07f3, 0x122c0aba, 0x0c0ed283, 0x8169dfca, 0x16c1c811, 0x9ba6c558, 0x8f8d26a3, 0x02ea2bea, + 0x95423c31, 0x18253178, 0x0a093bc2, 0x876e368b, 0x10c62150, 0x9da12c19, 0x898acfe2, 0x04edc2ab, 0x9345d570, 0x1e22d839, 0xaf016503, + 0x2266684a, 0xb5ce7f91, 0x38a972d8, 0x2c829123, 0xa1e59c6a, 0x364d8bb1, 0xbb2a86f8, 0xa9068c42, 0x2461810b, 0xb3c996d0, 0x3eae9b99, + 0x2a857862, 0xa7e2752b, 0x304a62f0, 0xbd2d6fb9, 0xa30fb780, 0x2e68bac9, 0xb9c0ad12, 0x34a7a05b, 0x208c43a0, 0xadeb4ee9, 0x3a435932, + 0xb724547b, 0xa5085ec1, 0x286f5388, 0xbfc74453, 0x32a0491a, 0x268baae1, 0xabeca7a8, 0x3c44b073, 0xb123bd3a, 0x5e03ca06, 0xd364c74f, + 0x44ccd094, 0xc9abdddd, 0xdd803e26, 0x50e7336f, 0xc74f24b4, 0x4a2829fd, 0x58042347, 0xd5632e0e, 0x42cb39d5, 0xcfac349c, 0xdb87d767, + 0x56e0da2e, 0xc148cdf5, 0x4c2fc0bc, 0x520d1885, 0xdf6a15cc, 0x48c20217, 0xc5a50f5e, 0xd18eeca5, 0x5ce9e1ec, 0xcb41f637, 0x4626fb7e, + 0x540af1c4, 0xd96dfc8d, 0x4ec5eb56, 0xc3a2e61f, 0xd78905e4, 0x5aee08ad, 0xcd461f76, 0x4021123f, 0xf102af05, 0x7c65a24c, 0xebcdb597, + 0x66aab8de, 0x72815b25, 0xffe6566c, 0x684e41b7, 0xe5294cfe, 0xf7054644, 0x7a624b0d, 0xedca5cd6, 0x60ad519f, 0x7486b264, 0xf9e1bf2d, + 0x6e49a8f6, 0xe32ea5bf, 0xfd0c7d86, 0x706b70cf, 0xe7c36714, 0x6aa46a5d, 0x7e8f89a6, 0xf3e884ef, 0x64409334, 0xe9279e7d, 0xfb0b94c7, + 0x766c998e, 0xe1c48e55, 0x6ca3831c, 0x788860e7, 0xf5ef6dae, 0x62477a75, 0xef20773c, 0xbc06940d, 0x31619944, 0xa6c98e9f, 0x2bae83d6, + 0x3f85602d, 0xb2e26d64, 0x254a7abf, 0xa82d77f6, 0xba017d4c, 0x37667005, 0xa0ce67de, 0x2da96a97, 0x3982896c, 0xb4e58425, 0x234d93fe, + 0xae2a9eb7, 0xb008468e, 0x3d6f4bc7, 0xaac75c1c, 0x27a05155, 0x338bb2ae, 0xbeecbfe7, 0x2944a83c, 0xa423a575, 0xb60fafcf, 0x3b68a286, + 0xacc0b55d, 0x21a7b814, 0x358c5bef, 0xb8eb56a6, 0x2f43417d, 0xa2244c34, 0x1307f10e, 0x9e60fc47, 0x09c8eb9c, 0x84afe6d5, 0x9084052e, + 0x1de30867, 0x8a4b1fbc, 0x072c12f5, 0x1500184f, 0x98671506, 0x0fcf02dd, 0x82a80f94, 0x9683ec6f, 0x1be4e126, 0x8c4cf6fd, 0x012bfbb4, + 0x1f09238d, 0x926e2ec4, 0x05c6391f, 0x88a13456, 0x9c8ad7ad, 0x11eddae4, 0x8645cd3f, 0x0b22c076, 0x190ecacc, 0x9469c785, 0x03c1d05e, + 0x8ea6dd17, 0x9a8d3eec, 0x17ea33a5, 0x8042247e, 0x0d252937, 0xe2055e0b, 0x6f625342, 0xf8ca4499, 0x75ad49d0, 0x6186aa2b, 0xece1a762, + 0x7b49b0b9, 0xf62ebdf0, 0xe402b74a, 0x6965ba03, 0xfecdadd8, 0x73aaa091, 0x6781436a, 0xeae64e23, 0x7d4e59f8, 0xf02954b1, 0xee0b8c88, + 0x636c81c1, 0xf4c4961a, 0x79a39b53, 0x6d8878a8, 0xe0ef75e1, 0x7747623a, 0xfa206f73, 0xe80c65c9, 0x656b6880, 0xf2c37f5b, 0x7fa47212, + 0x6b8f91e9, 0xe6e89ca0, 0x71408b7b, 0xfc278632, 0x4d043b08, 0xc0633641, 0x57cb219a, 0xdaac2cd3, 0xce87cf28, 0x43e0c261, 0xd448d5ba, + 0x592fd8f3, 0x4b03d249, 0xc664df00, 0x51ccc8db, 0xdcabc592, 0xc8802669, 0x45e72b20, 0xd24f3cfb, 0x5f2831b2, 0x410ae98b, 0xcc6de4c2, + 0x5bc5f319, 0xd6a2fe50, 0xc2891dab, 0x4fee10e2, 0xd8460739, 0x55210a70, 0x470d00ca, 0xca6a0d83, 0x5dc21a58, 0xd0a51711, 0xc48ef4ea, + 0x49e9f9a3, 0xde41ee78, 0x5326e331}, + {0x00000000, 0x780d281b, 0xf01a5036, 0x8817782d, 0xe035a06c, 0x98388877, 0x102ff05a, 0x6822d841, 0xc06b40d9, 0xb86668c2, 0x307110ef, + 0x487c38f4, 0x205ee0b5, 0x5853c8ae, 0xd044b083, 0xa8499898, 0x37ca41b6, 0x4fc769ad, 0xc7d01180, 0xbfdd399b, 0xd7ffe1da, 0xaff2c9c1, + 0x27e5b1ec, 0x5fe899f7, 0xf7a1016f, 0x8fac2974, 0x07bb5159, 0x7fb67942, 0x1794a103, 0x6f998918, 0xe78ef135, 0x9f83d92e, 0xd9894268, + 0xa1846a73, 0x2993125e, 0x519e3a45, 0x39bce204, 0x41b1ca1f, 0xc9a6b232, 0xb1ab9a29, 0x19e202b1, 0x61ef2aaa, 0xe9f85287, 0x91f57a9c, + 0xf9d7a2dd, 0x81da8ac6, 0x09cdf2eb, 0x71c0daf0, 0xee4303de, 0x964e2bc5, 0x1e5953e8, 0x66547bf3, 0x0e76a3b2, 0x767b8ba9, 0xfe6cf384, + 0x8661db9f, 0x2e284307, 0x56256b1c, 0xde321331, 0xa63f3b2a, 0xce1de36b, 0xb610cb70, 0x3e07b35d, 0x460a9b46, 0xb21385d0, 0xca1eadcb, + 0x4209d5e6, 0x3a04fdfd, 0x522625bc, 0x2a2b0da7, 0xa23c758a, 0xda315d91, 0x7278c509, 0x0a75ed12, 0x8262953f, 0xfa6fbd24, 0x924d6565, + 0xea404d7e, 0x62573553, 0x1a5a1d48, 0x85d9c466, 0xfdd4ec7d, 0x75c39450, 0x0dcebc4b, 0x65ec640a, 0x1de14c11, 0x95f6343c, 0xedfb1c27, + 0x45b284bf, 0x3dbfaca4, 0xb5a8d489, 0xcda5fc92, 0xa58724d3, 0xdd8a0cc8, 0x559d74e5, 0x2d905cfe, 0x6b9ac7b8, 0x1397efa3, 0x9b80978e, + 0xe38dbf95, 0x8baf67d4, 0xf3a24fcf, 0x7bb537e2, 0x03b81ff9, 0xabf18761, 0xd3fcaf7a, 0x5bebd757, 0x23e6ff4c, 0x4bc4270d, 0x33c90f16, + 0xbbde773b, 0xc3d35f20, 0x5c50860e, 0x245dae15, 0xac4ad638, 0xd447fe23, 0xbc652662, 0xc4680e79, 0x4c7f7654, 0x34725e4f, 0x9c3bc6d7, + 0xe436eecc, 0x6c2196e1, 0x142cbefa, 0x7c0e66bb, 0x04034ea0, 0x8c14368d, 0xf4191e96, 0xd33acba5, 0xab37e3be, 0x23209b93, 0x5b2db388, + 0x330f6bc9, 0x4b0243d2, 0xc3153bff, 0xbb1813e4, 0x13518b7c, 0x6b5ca367, 0xe34bdb4a, 0x9b46f351, 0xf3642b10, 0x8b69030b, 0x037e7b26, + 0x7b73533d, 0xe4f08a13, 0x9cfda208, 0x14eada25, 0x6ce7f23e, 0x04c52a7f, 0x7cc80264, 0xf4df7a49, 0x8cd25252, 0x249bcaca, 0x5c96e2d1, + 0xd4819afc, 0xac8cb2e7, 0xc4ae6aa6, 0xbca342bd, 0x34b43a90, 0x4cb9128b, 0x0ab389cd, 0x72bea1d6, 0xfaa9d9fb, 0x82a4f1e0, 0xea8629a1, + 0x928b01ba, 0x1a9c7997, 0x6291518c, 0xcad8c914, 0xb2d5e10f, 0x3ac29922, 0x42cfb139, 0x2aed6978, 0x52e04163, 0xdaf7394e, 0xa2fa1155, + 0x3d79c87b, 0x4574e060, 0xcd63984d, 0xb56eb056, 0xdd4c6817, 0xa541400c, 0x2d563821, 0x555b103a, 0xfd1288a2, 0x851fa0b9, 0x0d08d894, + 0x7505f08f, 0x1d2728ce, 0x652a00d5, 0xed3d78f8, 0x953050e3, 0x61294e75, 0x1924666e, 0x91331e43, 0xe93e3658, 0x811cee19, 0xf911c602, + 0x7106be2f, 0x090b9634, 0xa1420eac, 0xd94f26b7, 0x51585e9a, 0x29557681, 0x4177aec0, 0x397a86db, 0xb16dfef6, 0xc960d6ed, 0x56e30fc3, + 0x2eee27d8, 0xa6f95ff5, 0xdef477ee, 0xb6d6afaf, 0xcedb87b4, 0x46ccff99, 0x3ec1d782, 0x96884f1a, 0xee856701, 0x66921f2c, 0x1e9f3737, + 0x76bdef76, 0x0eb0c76d, 0x86a7bf40, 0xfeaa975b, 0xb8a00c1d, 0xc0ad2406, 0x48ba5c2b, 0x30b77430, 0x5895ac71, 0x2098846a, 0xa88ffc47, + 0xd082d45c, 0x78cb4cc4, 0x00c664df, 0x88d11cf2, 0xf0dc34e9, 0x98feeca8, 0xe0f3c4b3, 0x68e4bc9e, 0x10e99485, 0x8f6a4dab, 0xf76765b0, + 0x7f701d9d, 0x077d3586, 0x6f5fedc7, 0x1752c5dc, 0x9f45bdf1, 0xe74895ea, 0x4f010d72, 0x370c2569, 0xbf1b5d44, 0xc716755f, 0xaf34ad1e, + 0xd7398505, 0x5f2efd28, 0x2723d533}, + {0x00000000, 0x1168574f, 0x22d0ae9e, 0x33b8f9d1, 0xf3bd9c39, 0xe2d5cb76, 0xd16d32a7, 0xc00565e8, 0xe67b3973, 0xf7136e3c, 0xc4ab97ed, + 0xd5c3c0a2, 0x15c6a54a, 0x04aef205, 0x37160bd4, 0x267e5c9b, 0xccf772e6, 0xdd9f25a9, 0xee27dc78, 0xff4f8b37, 0x3f4aeedf, 0x2e22b990, + 0x1d9a4041, 0x0cf2170e, 0x2a8c4b95, 0x3be41cda, 0x085ce50b, 0x1934b244, 0xd931d7ac, 0xc85980e3, 0xfbe17932, 0xea892e7d, 0x2ff224c8, + 0x3e9a7387, 0x0d228a56, 0x1c4add19, 0xdc4fb8f1, 0xcd27efbe, 0xfe9f166f, 0xeff74120, 0xc9891dbb, 0xd8e14af4, 0xeb59b325, 0xfa31e46a, + 0x3a348182, 0x2b5cd6cd, 0x18e42f1c, 0x098c7853, 0xe305562e, 0xf26d0161, 0xc1d5f8b0, 0xd0bdafff, 0x10b8ca17, 0x01d09d58, 0x32686489, + 0x230033c6, 0x057e6f5d, 0x14163812, 0x27aec1c3, 0x36c6968c, 0xf6c3f364, 0xe7aba42b, 0xd4135dfa, 0xc57b0ab5, 0xe9f98894, 0xf891dfdb, + 0xcb29260a, 0xda417145, 0x1a4414ad, 0x0b2c43e2, 0x3894ba33, 0x29fced7c, 0x0f82b1e7, 0x1eeae6a8, 0x2d521f79, 0x3c3a4836, 0xfc3f2dde, + 0xed577a91, 0xdeef8340, 0xcf87d40f, 0x250efa72, 0x3466ad3d, 0x07de54ec, 0x16b603a3, 0xd6b3664b, 0xc7db3104, 0xf463c8d5, 0xe50b9f9a, + 0xc375c301, 0xd21d944e, 0xe1a56d9f, 0xf0cd3ad0, 0x30c85f38, 0x21a00877, 0x1218f1a6, 0x0370a6e9, 0xc60bac5c, 0xd763fb13, 0xe4db02c2, + 0xf5b3558d, 0x35b63065, 0x24de672a, 0x17669efb, 0x060ec9b4, 0x2070952f, 0x3118c260, 0x02a03bb1, 0x13c86cfe, 0xd3cd0916, 0xc2a55e59, + 0xf11da788, 0xe075f0c7, 0x0afcdeba, 0x1b9489f5, 0x282c7024, 0x3944276b, 0xf9414283, 0xe82915cc, 0xdb91ec1d, 0xcaf9bb52, 0xec87e7c9, + 0xfdefb086, 0xce574957, 0xdf3f1e18, 0x1f3a7bf0, 0x0e522cbf, 0x3dead56e, 0x2c828221, 0x65eed02d, 0x74868762, 0x473e7eb3, 0x565629fc, + 0x96534c14, 0x873b1b5b, 0xb483e28a, 0xa5ebb5c5, 0x8395e95e, 0x92fdbe11, 0xa14547c0, 0xb02d108f, 0x70287567, 0x61402228, 0x52f8dbf9, + 0x43908cb6, 0xa919a2cb, 0xb871f584, 0x8bc90c55, 0x9aa15b1a, 0x5aa43ef2, 0x4bcc69bd, 0x7874906c, 0x691cc723, 0x4f629bb8, 0x5e0accf7, + 0x6db23526, 0x7cda6269, 0xbcdf0781, 0xadb750ce, 0x9e0fa91f, 0x8f67fe50, 0x4a1cf4e5, 0x5b74a3aa, 0x68cc5a7b, 0x79a40d34, 0xb9a168dc, + 0xa8c93f93, 0x9b71c642, 0x8a19910d, 0xac67cd96, 0xbd0f9ad9, 0x8eb76308, 0x9fdf3447, 0x5fda51af, 0x4eb206e0, 0x7d0aff31, 0x6c62a87e, + 0x86eb8603, 0x9783d14c, 0xa43b289d, 0xb5537fd2, 0x75561a3a, 0x643e4d75, 0x5786b4a4, 0x46eee3eb, 0x6090bf70, 0x71f8e83f, 0x424011ee, + 0x532846a1, 0x932d2349, 0x82457406, 0xb1fd8dd7, 0xa095da98, 0x8c1758b9, 0x9d7f0ff6, 0xaec7f627, 0xbfafa168, 0x7faac480, 0x6ec293cf, + 0x5d7a6a1e, 0x4c123d51, 0x6a6c61ca, 0x7b043685, 0x48bccf54, 0x59d4981b, 0x99d1fdf3, 0x88b9aabc, 0xbb01536d, 0xaa690422, 0x40e02a5f, + 0x51887d10, 0x623084c1, 0x7358d38e, 0xb35db666, 0xa235e129, 0x918d18f8, 0x80e54fb7, 0xa69b132c, 0xb7f34463, 0x844bbdb2, 0x9523eafd, + 0x55268f15, 0x444ed85a, 0x77f6218b, 0x669e76c4, 0xa3e57c71, 0xb28d2b3e, 0x8135d2ef, 0x905d85a0, 0x5058e048, 0x4130b707, 0x72884ed6, + 0x63e01999, 0x459e4502, 0x54f6124d, 0x674eeb9c, 0x7626bcd3, 0xb623d93b, 0xa74b8e74, 0x94f377a5, 0x859b20ea, 0x6f120e97, 0x7e7a59d8, + 0x4dc2a009, 0x5caaf746, 0x9caf92ae, 0x8dc7c5e1, 0xbe7f3c30, 0xaf176b7f, 0x896937e4, 0x980160ab, 0xabb9997a, 0xbad1ce35, 0x7ad4abdd, + 0x6bbcfc92, 0x58040543, 0x496c520c}, + {0x00000000, 0xcadca15b, 0x94b943b7, 0x5e65e2ec, 0x9f6e466a, 0x55b2e731, 0x0bd705dd, 0xc10ba486, 0x3edd8cd4, 0xf4012d8f, 0xaa64cf63, + 0x60b86e38, 0xa1b3cabe, 0x6b6f6be5, 0x350a8909, 0xffd62852, 0xcba7d8ad, 0x017b79f6, 0x5f1e9b1a, 0x95c23a41, 0x54c99ec7, 0x9e153f9c, + 0xc070dd70, 0x0aac7c2b, 0xf57a5479, 0x3fa6f522, 0x61c317ce, 0xab1fb695, 0x6a141213, 0xa0c8b348, 0xfead51a4, 0x3471f0ff, 0x2152705f, + 0xeb8ed104, 0xb5eb33e8, 0x7f3792b3, 0xbe3c3635, 0x74e0976e, 0x2a857582, 0xe059d4d9, 0x1f8ffc8b, 0xd5535dd0, 0x8b36bf3c, 0x41ea1e67, + 0x80e1bae1, 0x4a3d1bba, 0x1458f956, 0xde84580d, 0xeaf5a8f2, 0x202909a9, 0x7e4ceb45, 0xb4904a1e, 0x759bee98, 0xbf474fc3, 0xe122ad2f, + 0x2bfe0c74, 0xd4282426, 0x1ef4857d, 0x40916791, 0x8a4dc6ca, 0x4b46624c, 0x819ac317, 0xdfff21fb, 0x152380a0, 0x42a4e0be, 0x887841e5, + 0xd61da309, 0x1cc10252, 0xddcaa6d4, 0x1716078f, 0x4973e563, 0x83af4438, 0x7c796c6a, 0xb6a5cd31, 0xe8c02fdd, 0x221c8e86, 0xe3172a00, + 0x29cb8b5b, 0x77ae69b7, 0xbd72c8ec, 0x89033813, 0x43df9948, 0x1dba7ba4, 0xd766daff, 0x166d7e79, 0xdcb1df22, 0x82d43dce, 0x48089c95, + 0xb7deb4c7, 0x7d02159c, 0x2367f770, 0xe9bb562b, 0x28b0f2ad, 0xe26c53f6, 0xbc09b11a, 0x76d51041, 0x63f690e1, 0xa92a31ba, 0xf74fd356, + 0x3d93720d, 0xfc98d68b, 0x364477d0, 0x6821953c, 0xa2fd3467, 0x5d2b1c35, 0x97f7bd6e, 0xc9925f82, 0x034efed9, 0xc2455a5f, 0x0899fb04, + 0x56fc19e8, 0x9c20b8b3, 0xa851484c, 0x628de917, 0x3ce80bfb, 0xf634aaa0, 0x373f0e26, 0xfde3af7d, 0xa3864d91, 0x695aecca, 0x968cc498, + 0x5c5065c3, 0x0235872f, 0xc8e92674, 0x09e282f2, 0xc33e23a9, 0x9d5bc145, 0x5787601e, 0x33550079, 0xf989a122, 0xa7ec43ce, 0x6d30e295, + 0xac3b4613, 0x66e7e748, 0x388205a4, 0xf25ea4ff, 0x0d888cad, 0xc7542df6, 0x9931cf1a, 0x53ed6e41, 0x92e6cac7, 0x583a6b9c, 0x065f8970, + 0xcc83282b, 0xf8f2d8d4, 0x322e798f, 0x6c4b9b63, 0xa6973a38, 0x679c9ebe, 0xad403fe5, 0xf325dd09, 0x39f97c52, 0xc62f5400, 0x0cf3f55b, + 0x529617b7, 0x984ab6ec, 0x5941126a, 0x939db331, 0xcdf851dd, 0x0724f086, 0x12077026, 0xd8dbd17d, 0x86be3391, 0x4c6292ca, 0x8d69364c, + 0x47b59717, 0x19d075fb, 0xd30cd4a0, 0x2cdafcf2, 0xe6065da9, 0xb863bf45, 0x72bf1e1e, 0xb3b4ba98, 0x79681bc3, 0x270df92f, 0xedd15874, + 0xd9a0a88b, 0x137c09d0, 0x4d19eb3c, 0x87c54a67, 0x46ceeee1, 0x8c124fba, 0xd277ad56, 0x18ab0c0d, 0xe77d245f, 0x2da18504, 0x73c467e8, + 0xb918c6b3, 0x78136235, 0xb2cfc36e, 0xecaa2182, 0x267680d9, 0x71f1e0c7, 0xbb2d419c, 0xe548a370, 0x2f94022b, 0xee9fa6ad, 0x244307f6, + 0x7a26e51a, 0xb0fa4441, 0x4f2c6c13, 0x85f0cd48, 0xdb952fa4, 0x11498eff, 0xd0422a79, 0x1a9e8b22, 0x44fb69ce, 0x8e27c895, 0xba56386a, + 0x708a9931, 0x2eef7bdd, 0xe433da86, 0x25387e00, 0xefe4df5b, 0xb1813db7, 0x7b5d9cec, 0x848bb4be, 0x4e5715e5, 0x1032f709, 0xdaee5652, + 0x1be5f2d4, 0xd139538f, 0x8f5cb163, 0x45801038, 0x50a39098, 0x9a7f31c3, 0xc41ad32f, 0x0ec67274, 0xcfcdd6f2, 0x051177a9, 0x5b749545, + 0x91a8341e, 0x6e7e1c4c, 0xa4a2bd17, 0xfac75ffb, 0x301bfea0, 0xf1105a26, 0x3bccfb7d, 0x65a91991, 0xaf75b8ca, 0x9b044835, 0x51d8e96e, + 0x0fbd0b82, 0xc561aad9, 0x046a0e5f, 0xceb6af04, 0x90d34de8, 0x5a0fecb3, 0xa5d9c4e1, 0x6f0565ba, 0x31608756, 0xfbbc260d, 0x3ab7828b, + 0xf06b23d0, 0xae0ec13c, 0x64d26067}}; + +static const uint32_t CRCTablesSB8[8][256] = { + {0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, + 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, + 0xf4d4b551, 0x83d385c7, 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, 0x3b6e20c8, + 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, + 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, + 0xb8bda50f, 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, 0x76dc4190, 0x01db7106, + 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, + 0x086d3d2d, 0x91646c97, 0xe6635c01, 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, + 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, 0x4db26158, 0x3ab551ce, 0xa3bc0074, + 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, + 0xaa0a4c5f, 0xdd0d7cc9, 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, 0x5edef90e, + 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, + 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, 0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, + 0x7d079eb1, 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, 0xfed41b76, 0x89d32be0, + 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, + 0xa6bc5767, 0x3fb506dd, 0x48b2364b, 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, + 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, + 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, + 0x72076785, 0x05005713, 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, 0x86d3d2d4, + 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, + 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, + 0x3e6e77db, 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, 0xbdbdf21c, 0xcabac28a, + 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, + 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d}, + {0x00000000, 0x191b3141, 0x32366282, 0x2b2d53c3, 0x646cc504, 0x7d77f445, 0x565aa786, 0x4f4196c7, 0xc8d98a08, 0xd1c2bb49, 0xfaefe88a, + 0xe3f4d9cb, 0xacb54f0c, 0xb5ae7e4d, 0x9e832d8e, 0x87981ccf, 0x4ac21251, 0x53d92310, 0x78f470d3, 0x61ef4192, 0x2eaed755, 0x37b5e614, + 0x1c98b5d7, 0x05838496, 0x821b9859, 0x9b00a918, 0xb02dfadb, 0xa936cb9a, 0xe6775d5d, 0xff6c6c1c, 0xd4413fdf, 0xcd5a0e9e, 0x958424a2, + 0x8c9f15e3, 0xa7b24620, 0xbea97761, 0xf1e8e1a6, 0xe8f3d0e7, 0xc3de8324, 0xdac5b265, 0x5d5daeaa, 0x44469feb, 0x6f6bcc28, 0x7670fd69, + 0x39316bae, 0x202a5aef, 0x0b07092c, 0x121c386d, 0xdf4636f3, 0xc65d07b2, 0xed705471, 0xf46b6530, 0xbb2af3f7, 0xa231c2b6, 0x891c9175, + 0x9007a034, 0x179fbcfb, 0x0e848dba, 0x25a9de79, 0x3cb2ef38, 0x73f379ff, 0x6ae848be, 0x41c51b7d, 0x58de2a3c, 0xf0794f05, 0xe9627e44, + 0xc24f2d87, 0xdb541cc6, 0x94158a01, 0x8d0ebb40, 0xa623e883, 0xbf38d9c2, 0x38a0c50d, 0x21bbf44c, 0x0a96a78f, 0x138d96ce, 0x5ccc0009, + 0x45d73148, 0x6efa628b, 0x77e153ca, 0xbabb5d54, 0xa3a06c15, 0x888d3fd6, 0x91960e97, 0xded79850, 0xc7cca911, 0xece1fad2, 0xf5facb93, + 0x7262d75c, 0x6b79e61d, 0x4054b5de, 0x594f849f, 0x160e1258, 0x0f152319, 0x243870da, 0x3d23419b, 0x65fd6ba7, 0x7ce65ae6, 0x57cb0925, + 0x4ed03864, 0x0191aea3, 0x188a9fe2, 0x33a7cc21, 0x2abcfd60, 0xad24e1af, 0xb43fd0ee, 0x9f12832d, 0x8609b26c, 0xc94824ab, 0xd05315ea, + 0xfb7e4629, 0xe2657768, 0x2f3f79f6, 0x362448b7, 0x1d091b74, 0x04122a35, 0x4b53bcf2, 0x52488db3, 0x7965de70, 0x607eef31, 0xe7e6f3fe, + 0xfefdc2bf, 0xd5d0917c, 0xcccba03d, 0x838a36fa, 0x9a9107bb, 0xb1bc5478, 0xa8a76539, 0x3b83984b, 0x2298a90a, 0x09b5fac9, 0x10aecb88, + 0x5fef5d4f, 0x46f46c0e, 0x6dd93fcd, 0x74c20e8c, 0xf35a1243, 0xea412302, 0xc16c70c1, 0xd8774180, 0x9736d747, 0x8e2de606, 0xa500b5c5, + 0xbc1b8484, 0x71418a1a, 0x685abb5b, 0x4377e898, 0x5a6cd9d9, 0x152d4f1e, 0x0c367e5f, 0x271b2d9c, 0x3e001cdd, 0xb9980012, 0xa0833153, + 0x8bae6290, 0x92b553d1, 0xddf4c516, 0xc4eff457, 0xefc2a794, 0xf6d996d5, 0xae07bce9, 0xb71c8da8, 0x9c31de6b, 0x852aef2a, 0xca6b79ed, + 0xd37048ac, 0xf85d1b6f, 0xe1462a2e, 0x66de36e1, 0x7fc507a0, 0x54e85463, 0x4df36522, 0x02b2f3e5, 0x1ba9c2a4, 0x30849167, 0x299fa026, + 0xe4c5aeb8, 0xfdde9ff9, 0xd6f3cc3a, 0xcfe8fd7b, 0x80a96bbc, 0x99b25afd, 0xb29f093e, 0xab84387f, 0x2c1c24b0, 0x350715f1, 0x1e2a4632, + 0x07317773, 0x4870e1b4, 0x516bd0f5, 0x7a468336, 0x635db277, 0xcbfad74e, 0xd2e1e60f, 0xf9ccb5cc, 0xe0d7848d, 0xaf96124a, 0xb68d230b, + 0x9da070c8, 0x84bb4189, 0x03235d46, 0x1a386c07, 0x31153fc4, 0x280e0e85, 0x674f9842, 0x7e54a903, 0x5579fac0, 0x4c62cb81, 0x8138c51f, + 0x9823f45e, 0xb30ea79d, 0xaa1596dc, 0xe554001b, 0xfc4f315a, 0xd7626299, 0xce7953d8, 0x49e14f17, 0x50fa7e56, 0x7bd72d95, 0x62cc1cd4, + 0x2d8d8a13, 0x3496bb52, 0x1fbbe891, 0x06a0d9d0, 0x5e7ef3ec, 0x4765c2ad, 0x6c48916e, 0x7553a02f, 0x3a1236e8, 0x230907a9, 0x0824546a, + 0x113f652b, 0x96a779e4, 0x8fbc48a5, 0xa4911b66, 0xbd8a2a27, 0xf2cbbce0, 0xebd08da1, 0xc0fdde62, 0xd9e6ef23, 0x14bce1bd, 0x0da7d0fc, + 0x268a833f, 0x3f91b27e, 0x70d024b9, 0x69cb15f8, 0x42e6463b, 0x5bfd777a, 0xdc656bb5, 0xc57e5af4, 0xee530937, 0xf7483876, 0xb809aeb1, + 0xa1129ff0, 0x8a3fcc33, 0x9324fd72}, + {0x00000000, 0x01c26a37, 0x0384d46e, 0x0246be59, 0x0709a8dc, 0x06cbc2eb, 0x048d7cb2, 0x054f1685, 0x0e1351b8, 0x0fd13b8f, 0x0d9785d6, + 0x0c55efe1, 0x091af964, 0x08d89353, 0x0a9e2d0a, 0x0b5c473d, 0x1c26a370, 0x1de4c947, 0x1fa2771e, 0x1e601d29, 0x1b2f0bac, 0x1aed619b, + 0x18abdfc2, 0x1969b5f5, 0x1235f2c8, 0x13f798ff, 0x11b126a6, 0x10734c91, 0x153c5a14, 0x14fe3023, 0x16b88e7a, 0x177ae44d, 0x384d46e0, + 0x398f2cd7, 0x3bc9928e, 0x3a0bf8b9, 0x3f44ee3c, 0x3e86840b, 0x3cc03a52, 0x3d025065, 0x365e1758, 0x379c7d6f, 0x35dac336, 0x3418a901, + 0x3157bf84, 0x3095d5b3, 0x32d36bea, 0x331101dd, 0x246be590, 0x25a98fa7, 0x27ef31fe, 0x262d5bc9, 0x23624d4c, 0x22a0277b, 0x20e69922, + 0x2124f315, 0x2a78b428, 0x2bbade1f, 0x29fc6046, 0x283e0a71, 0x2d711cf4, 0x2cb376c3, 0x2ef5c89a, 0x2f37a2ad, 0x709a8dc0, 0x7158e7f7, + 0x731e59ae, 0x72dc3399, 0x7793251c, 0x76514f2b, 0x7417f172, 0x75d59b45, 0x7e89dc78, 0x7f4bb64f, 0x7d0d0816, 0x7ccf6221, 0x798074a4, + 0x78421e93, 0x7a04a0ca, 0x7bc6cafd, 0x6cbc2eb0, 0x6d7e4487, 0x6f38fade, 0x6efa90e9, 0x6bb5866c, 0x6a77ec5b, 0x68315202, 0x69f33835, + 0x62af7f08, 0x636d153f, 0x612bab66, 0x60e9c151, 0x65a6d7d4, 0x6464bde3, 0x662203ba, 0x67e0698d, 0x48d7cb20, 0x4915a117, 0x4b531f4e, + 0x4a917579, 0x4fde63fc, 0x4e1c09cb, 0x4c5ab792, 0x4d98dda5, 0x46c49a98, 0x4706f0af, 0x45404ef6, 0x448224c1, 0x41cd3244, 0x400f5873, + 0x4249e62a, 0x438b8c1d, 0x54f16850, 0x55330267, 0x5775bc3e, 0x56b7d609, 0x53f8c08c, 0x523aaabb, 0x507c14e2, 0x51be7ed5, 0x5ae239e8, + 0x5b2053df, 0x5966ed86, 0x58a487b1, 0x5deb9134, 0x5c29fb03, 0x5e6f455a, 0x5fad2f6d, 0xe1351b80, 0xe0f771b7, 0xe2b1cfee, 0xe373a5d9, + 0xe63cb35c, 0xe7fed96b, 0xe5b86732, 0xe47a0d05, 0xef264a38, 0xeee4200f, 0xeca29e56, 0xed60f461, 0xe82fe2e4, 0xe9ed88d3, 0xebab368a, + 0xea695cbd, 0xfd13b8f0, 0xfcd1d2c7, 0xfe976c9e, 0xff5506a9, 0xfa1a102c, 0xfbd87a1b, 0xf99ec442, 0xf85cae75, 0xf300e948, 0xf2c2837f, + 0xf0843d26, 0xf1465711, 0xf4094194, 0xf5cb2ba3, 0xf78d95fa, 0xf64fffcd, 0xd9785d60, 0xd8ba3757, 0xdafc890e, 0xdb3ee339, 0xde71f5bc, + 0xdfb39f8b, 0xddf521d2, 0xdc374be5, 0xd76b0cd8, 0xd6a966ef, 0xd4efd8b6, 0xd52db281, 0xd062a404, 0xd1a0ce33, 0xd3e6706a, 0xd2241a5d, + 0xc55efe10, 0xc49c9427, 0xc6da2a7e, 0xc7184049, 0xc25756cc, 0xc3953cfb, 0xc1d382a2, 0xc011e895, 0xcb4dafa8, 0xca8fc59f, 0xc8c97bc6, + 0xc90b11f1, 0xcc440774, 0xcd866d43, 0xcfc0d31a, 0xce02b92d, 0x91af9640, 0x906dfc77, 0x922b422e, 0x93e92819, 0x96a63e9c, 0x976454ab, + 0x9522eaf2, 0x94e080c5, 0x9fbcc7f8, 0x9e7eadcf, 0x9c381396, 0x9dfa79a1, 0x98b56f24, 0x99770513, 0x9b31bb4a, 0x9af3d17d, 0x8d893530, + 0x8c4b5f07, 0x8e0de15e, 0x8fcf8b69, 0x8a809dec, 0x8b42f7db, 0x89044982, 0x88c623b5, 0x839a6488, 0x82580ebf, 0x801eb0e6, 0x81dcdad1, + 0x8493cc54, 0x8551a663, 0x8717183a, 0x86d5720d, 0xa9e2d0a0, 0xa820ba97, 0xaa6604ce, 0xaba46ef9, 0xaeeb787c, 0xaf29124b, 0xad6fac12, + 0xacadc625, 0xa7f18118, 0xa633eb2f, 0xa4755576, 0xa5b73f41, 0xa0f829c4, 0xa13a43f3, 0xa37cfdaa, 0xa2be979d, 0xb5c473d0, 0xb40619e7, + 0xb640a7be, 0xb782cd89, 0xb2cddb0c, 0xb30fb13b, 0xb1490f62, 0xb08b6555, 0xbbd72268, 0xba15485f, 0xb853f606, 0xb9919c31, 0xbcde8ab4, + 0xbd1ce083, 0xbf5a5eda, 0xbe9834ed}, + {0x00000000, 0xb8bc6765, 0xaa09c88b, 0x12b5afee, 0x8f629757, 0x37def032, 0x256b5fdc, 0x9dd738b9, 0xc5b428ef, 0x7d084f8a, 0x6fbde064, + 0xd7018701, 0x4ad6bfb8, 0xf26ad8dd, 0xe0df7733, 0x58631056, 0x5019579f, 0xe8a530fa, 0xfa109f14, 0x42acf871, 0xdf7bc0c8, 0x67c7a7ad, + 0x75720843, 0xcdce6f26, 0x95ad7f70, 0x2d111815, 0x3fa4b7fb, 0x8718d09e, 0x1acfe827, 0xa2738f42, 0xb0c620ac, 0x087a47c9, 0xa032af3e, + 0x188ec85b, 0x0a3b67b5, 0xb28700d0, 0x2f503869, 0x97ec5f0c, 0x8559f0e2, 0x3de59787, 0x658687d1, 0xdd3ae0b4, 0xcf8f4f5a, 0x7733283f, + 0xeae41086, 0x525877e3, 0x40edd80d, 0xf851bf68, 0xf02bf8a1, 0x48979fc4, 0x5a22302a, 0xe29e574f, 0x7f496ff6, 0xc7f50893, 0xd540a77d, + 0x6dfcc018, 0x359fd04e, 0x8d23b72b, 0x9f9618c5, 0x272a7fa0, 0xbafd4719, 0x0241207c, 0x10f48f92, 0xa848e8f7, 0x9b14583d, 0x23a83f58, + 0x311d90b6, 0x89a1f7d3, 0x1476cf6a, 0xaccaa80f, 0xbe7f07e1, 0x06c36084, 0x5ea070d2, 0xe61c17b7, 0xf4a9b859, 0x4c15df3c, 0xd1c2e785, + 0x697e80e0, 0x7bcb2f0e, 0xc377486b, 0xcb0d0fa2, 0x73b168c7, 0x6104c729, 0xd9b8a04c, 0x446f98f5, 0xfcd3ff90, 0xee66507e, 0x56da371b, + 0x0eb9274d, 0xb6054028, 0xa4b0efc6, 0x1c0c88a3, 0x81dbb01a, 0x3967d77f, 0x2bd27891, 0x936e1ff4, 0x3b26f703, 0x839a9066, 0x912f3f88, + 0x299358ed, 0xb4446054, 0x0cf80731, 0x1e4da8df, 0xa6f1cfba, 0xfe92dfec, 0x462eb889, 0x549b1767, 0xec277002, 0x71f048bb, 0xc94c2fde, + 0xdbf98030, 0x6345e755, 0x6b3fa09c, 0xd383c7f9, 0xc1366817, 0x798a0f72, 0xe45d37cb, 0x5ce150ae, 0x4e54ff40, 0xf6e89825, 0xae8b8873, + 0x1637ef16, 0x048240f8, 0xbc3e279d, 0x21e91f24, 0x99557841, 0x8be0d7af, 0x335cb0ca, 0xed59b63b, 0x55e5d15e, 0x47507eb0, 0xffec19d5, + 0x623b216c, 0xda874609, 0xc832e9e7, 0x708e8e82, 0x28ed9ed4, 0x9051f9b1, 0x82e4565f, 0x3a58313a, 0xa78f0983, 0x1f336ee6, 0x0d86c108, + 0xb53aa66d, 0xbd40e1a4, 0x05fc86c1, 0x1749292f, 0xaff54e4a, 0x322276f3, 0x8a9e1196, 0x982bbe78, 0x2097d91d, 0x78f4c94b, 0xc048ae2e, + 0xd2fd01c0, 0x6a4166a5, 0xf7965e1c, 0x4f2a3979, 0x5d9f9697, 0xe523f1f2, 0x4d6b1905, 0xf5d77e60, 0xe762d18e, 0x5fdeb6eb, 0xc2098e52, + 0x7ab5e937, 0x680046d9, 0xd0bc21bc, 0x88df31ea, 0x3063568f, 0x22d6f961, 0x9a6a9e04, 0x07bda6bd, 0xbf01c1d8, 0xadb46e36, 0x15080953, + 0x1d724e9a, 0xa5ce29ff, 0xb77b8611, 0x0fc7e174, 0x9210d9cd, 0x2aacbea8, 0x38191146, 0x80a57623, 0xd8c66675, 0x607a0110, 0x72cfaefe, + 0xca73c99b, 0x57a4f122, 0xef189647, 0xfdad39a9, 0x45115ecc, 0x764dee06, 0xcef18963, 0xdc44268d, 0x64f841e8, 0xf92f7951, 0x41931e34, + 0x5326b1da, 0xeb9ad6bf, 0xb3f9c6e9, 0x0b45a18c, 0x19f00e62, 0xa14c6907, 0x3c9b51be, 0x842736db, 0x96929935, 0x2e2efe50, 0x2654b999, + 0x9ee8defc, 0x8c5d7112, 0x34e11677, 0xa9362ece, 0x118a49ab, 0x033fe645, 0xbb838120, 0xe3e09176, 0x5b5cf613, 0x49e959fd, 0xf1553e98, + 0x6c820621, 0xd43e6144, 0xc68bceaa, 0x7e37a9cf, 0xd67f4138, 0x6ec3265d, 0x7c7689b3, 0xc4caeed6, 0x591dd66f, 0xe1a1b10a, 0xf3141ee4, + 0x4ba87981, 0x13cb69d7, 0xab770eb2, 0xb9c2a15c, 0x017ec639, 0x9ca9fe80, 0x241599e5, 0x36a0360b, 0x8e1c516e, 0x866616a7, 0x3eda71c2, + 0x2c6fde2c, 0x94d3b949, 0x090481f0, 0xb1b8e695, 0xa30d497b, 0x1bb12e1e, 0x43d23e48, 0xfb6e592d, 0xe9dbf6c3, 0x516791a6, 0xccb0a91f, + 0x740cce7a, 0x66b96194, 0xde0506f1}, + {0x00000000, 0x3d6029b0, 0x7ac05360, 0x47a07ad0, 0xf580a6c0, 0xc8e08f70, 0x8f40f5a0, 0xb220dc10, 0x30704bc1, 0x0d106271, 0x4ab018a1, + 0x77d03111, 0xc5f0ed01, 0xf890c4b1, 0xbf30be61, 0x825097d1, 0x60e09782, 0x5d80be32, 0x1a20c4e2, 0x2740ed52, 0x95603142, 0xa80018f2, + 0xefa06222, 0xd2c04b92, 0x5090dc43, 0x6df0f5f3, 0x2a508f23, 0x1730a693, 0xa5107a83, 0x98705333, 0xdfd029e3, 0xe2b00053, 0xc1c12f04, + 0xfca106b4, 0xbb017c64, 0x866155d4, 0x344189c4, 0x0921a074, 0x4e81daa4, 0x73e1f314, 0xf1b164c5, 0xccd14d75, 0x8b7137a5, 0xb6111e15, + 0x0431c205, 0x3951ebb5, 0x7ef19165, 0x4391b8d5, 0xa121b886, 0x9c419136, 0xdbe1ebe6, 0xe681c256, 0x54a11e46, 0x69c137f6, 0x2e614d26, + 0x13016496, 0x9151f347, 0xac31daf7, 0xeb91a027, 0xd6f18997, 0x64d15587, 0x59b17c37, 0x1e1106e7, 0x23712f57, 0x58f35849, 0x659371f9, + 0x22330b29, 0x1f532299, 0xad73fe89, 0x9013d739, 0xd7b3ade9, 0xead38459, 0x68831388, 0x55e33a38, 0x124340e8, 0x2f236958, 0x9d03b548, + 0xa0639cf8, 0xe7c3e628, 0xdaa3cf98, 0x3813cfcb, 0x0573e67b, 0x42d39cab, 0x7fb3b51b, 0xcd93690b, 0xf0f340bb, 0xb7533a6b, 0x8a3313db, + 0x0863840a, 0x3503adba, 0x72a3d76a, 0x4fc3feda, 0xfde322ca, 0xc0830b7a, 0x872371aa, 0xba43581a, 0x9932774d, 0xa4525efd, 0xe3f2242d, + 0xde920d9d, 0x6cb2d18d, 0x51d2f83d, 0x167282ed, 0x2b12ab5d, 0xa9423c8c, 0x9422153c, 0xd3826fec, 0xeee2465c, 0x5cc29a4c, 0x61a2b3fc, + 0x2602c92c, 0x1b62e09c, 0xf9d2e0cf, 0xc4b2c97f, 0x8312b3af, 0xbe729a1f, 0x0c52460f, 0x31326fbf, 0x7692156f, 0x4bf23cdf, 0xc9a2ab0e, + 0xf4c282be, 0xb362f86e, 0x8e02d1de, 0x3c220dce, 0x0142247e, 0x46e25eae, 0x7b82771e, 0xb1e6b092, 0x8c869922, 0xcb26e3f2, 0xf646ca42, + 0x44661652, 0x79063fe2, 0x3ea64532, 0x03c66c82, 0x8196fb53, 0xbcf6d2e3, 0xfb56a833, 0xc6368183, 0x74165d93, 0x49767423, 0x0ed60ef3, + 0x33b62743, 0xd1062710, 0xec660ea0, 0xabc67470, 0x96a65dc0, 0x248681d0, 0x19e6a860, 0x5e46d2b0, 0x6326fb00, 0xe1766cd1, 0xdc164561, + 0x9bb63fb1, 0xa6d61601, 0x14f6ca11, 0x2996e3a1, 0x6e369971, 0x5356b0c1, 0x70279f96, 0x4d47b626, 0x0ae7ccf6, 0x3787e546, 0x85a73956, + 0xb8c710e6, 0xff676a36, 0xc2074386, 0x4057d457, 0x7d37fde7, 0x3a978737, 0x07f7ae87, 0xb5d77297, 0x88b75b27, 0xcf1721f7, 0xf2770847, + 0x10c70814, 0x2da721a4, 0x6a075b74, 0x576772c4, 0xe547aed4, 0xd8278764, 0x9f87fdb4, 0xa2e7d404, 0x20b743d5, 0x1dd76a65, 0x5a7710b5, + 0x67173905, 0xd537e515, 0xe857cca5, 0xaff7b675, 0x92979fc5, 0xe915e8db, 0xd475c16b, 0x93d5bbbb, 0xaeb5920b, 0x1c954e1b, 0x21f567ab, + 0x66551d7b, 0x5b3534cb, 0xd965a31a, 0xe4058aaa, 0xa3a5f07a, 0x9ec5d9ca, 0x2ce505da, 0x11852c6a, 0x562556ba, 0x6b457f0a, 0x89f57f59, + 0xb49556e9, 0xf3352c39, 0xce550589, 0x7c75d999, 0x4115f029, 0x06b58af9, 0x3bd5a349, 0xb9853498, 0x84e51d28, 0xc34567f8, 0xfe254e48, + 0x4c059258, 0x7165bbe8, 0x36c5c138, 0x0ba5e888, 0x28d4c7df, 0x15b4ee6f, 0x521494bf, 0x6f74bd0f, 0xdd54611f, 0xe03448af, 0xa794327f, + 0x9af41bcf, 0x18a48c1e, 0x25c4a5ae, 0x6264df7e, 0x5f04f6ce, 0xed242ade, 0xd044036e, 0x97e479be, 0xaa84500e, 0x4834505d, 0x755479ed, + 0x32f4033d, 0x0f942a8d, 0xbdb4f69d, 0x80d4df2d, 0xc774a5fd, 0xfa148c4d, 0x78441b9c, 0x4524322c, 0x028448fc, 0x3fe4614c, 0x8dc4bd5c, + 0xb0a494ec, 0xf704ee3c, 0xca64c78c}, + {0x00000000, 0xcb5cd3a5, 0x4dc8a10b, 0x869472ae, 0x9b914216, 0x50cd91b3, 0xd659e31d, 0x1d0530b8, 0xec53826d, 0x270f51c8, 0xa19b2366, + 0x6ac7f0c3, 0x77c2c07b, 0xbc9e13de, 0x3a0a6170, 0xf156b2d5, 0x03d6029b, 0xc88ad13e, 0x4e1ea390, 0x85427035, 0x9847408d, 0x531b9328, + 0xd58fe186, 0x1ed33223, 0xef8580f6, 0x24d95353, 0xa24d21fd, 0x6911f258, 0x7414c2e0, 0xbf481145, 0x39dc63eb, 0xf280b04e, 0x07ac0536, + 0xccf0d693, 0x4a64a43d, 0x81387798, 0x9c3d4720, 0x57619485, 0xd1f5e62b, 0x1aa9358e, 0xebff875b, 0x20a354fe, 0xa6372650, 0x6d6bf5f5, + 0x706ec54d, 0xbb3216e8, 0x3da66446, 0xf6fab7e3, 0x047a07ad, 0xcf26d408, 0x49b2a6a6, 0x82ee7503, 0x9feb45bb, 0x54b7961e, 0xd223e4b0, + 0x197f3715, 0xe82985c0, 0x23755665, 0xa5e124cb, 0x6ebdf76e, 0x73b8c7d6, 0xb8e41473, 0x3e7066dd, 0xf52cb578, 0x0f580a6c, 0xc404d9c9, + 0x4290ab67, 0x89cc78c2, 0x94c9487a, 0x5f959bdf, 0xd901e971, 0x125d3ad4, 0xe30b8801, 0x28575ba4, 0xaec3290a, 0x659ffaaf, 0x789aca17, + 0xb3c619b2, 0x35526b1c, 0xfe0eb8b9, 0x0c8e08f7, 0xc7d2db52, 0x4146a9fc, 0x8a1a7a59, 0x971f4ae1, 0x5c439944, 0xdad7ebea, 0x118b384f, + 0xe0dd8a9a, 0x2b81593f, 0xad152b91, 0x6649f834, 0x7b4cc88c, 0xb0101b29, 0x36846987, 0xfdd8ba22, 0x08f40f5a, 0xc3a8dcff, 0x453cae51, + 0x8e607df4, 0x93654d4c, 0x58399ee9, 0xdeadec47, 0x15f13fe2, 0xe4a78d37, 0x2ffb5e92, 0xa96f2c3c, 0x6233ff99, 0x7f36cf21, 0xb46a1c84, + 0x32fe6e2a, 0xf9a2bd8f, 0x0b220dc1, 0xc07ede64, 0x46eaacca, 0x8db67f6f, 0x90b34fd7, 0x5bef9c72, 0xdd7beedc, 0x16273d79, 0xe7718fac, + 0x2c2d5c09, 0xaab92ea7, 0x61e5fd02, 0x7ce0cdba, 0xb7bc1e1f, 0x31286cb1, 0xfa74bf14, 0x1eb014d8, 0xd5ecc77d, 0x5378b5d3, 0x98246676, + 0x852156ce, 0x4e7d856b, 0xc8e9f7c5, 0x03b52460, 0xf2e396b5, 0x39bf4510, 0xbf2b37be, 0x7477e41b, 0x6972d4a3, 0xa22e0706, 0x24ba75a8, + 0xefe6a60d, 0x1d661643, 0xd63ac5e6, 0x50aeb748, 0x9bf264ed, 0x86f75455, 0x4dab87f0, 0xcb3ff55e, 0x006326fb, 0xf135942e, 0x3a69478b, + 0xbcfd3525, 0x77a1e680, 0x6aa4d638, 0xa1f8059d, 0x276c7733, 0xec30a496, 0x191c11ee, 0xd240c24b, 0x54d4b0e5, 0x9f886340, 0x828d53f8, + 0x49d1805d, 0xcf45f2f3, 0x04192156, 0xf54f9383, 0x3e134026, 0xb8873288, 0x73dbe12d, 0x6eded195, 0xa5820230, 0x2316709e, 0xe84aa33b, + 0x1aca1375, 0xd196c0d0, 0x5702b27e, 0x9c5e61db, 0x815b5163, 0x4a0782c6, 0xcc93f068, 0x07cf23cd, 0xf6999118, 0x3dc542bd, 0xbb513013, + 0x700de3b6, 0x6d08d30e, 0xa65400ab, 0x20c07205, 0xeb9ca1a0, 0x11e81eb4, 0xdab4cd11, 0x5c20bfbf, 0x977c6c1a, 0x8a795ca2, 0x41258f07, + 0xc7b1fda9, 0x0ced2e0c, 0xfdbb9cd9, 0x36e74f7c, 0xb0733dd2, 0x7b2fee77, 0x662adecf, 0xad760d6a, 0x2be27fc4, 0xe0beac61, 0x123e1c2f, + 0xd962cf8a, 0x5ff6bd24, 0x94aa6e81, 0x89af5e39, 0x42f38d9c, 0xc467ff32, 0x0f3b2c97, 0xfe6d9e42, 0x35314de7, 0xb3a53f49, 0x78f9ecec, + 0x65fcdc54, 0xaea00ff1, 0x28347d5f, 0xe368aefa, 0x16441b82, 0xdd18c827, 0x5b8cba89, 0x90d0692c, 0x8dd55994, 0x46898a31, 0xc01df89f, + 0x0b412b3a, 0xfa1799ef, 0x314b4a4a, 0xb7df38e4, 0x7c83eb41, 0x6186dbf9, 0xaada085c, 0x2c4e7af2, 0xe712a957, 0x15921919, 0xdececabc, + 0x585ab812, 0x93066bb7, 0x8e035b0f, 0x455f88aa, 0xc3cbfa04, 0x089729a1, 0xf9c19b74, 0x329d48d1, 0xb4093a7f, 0x7f55e9da, 0x6250d962, + 0xa90c0ac7, 0x2f987869, 0xe4c4abcc}, + {0x00000000, 0xa6770bb4, 0x979f1129, 0x31e81a9d, 0xf44f2413, 0x52382fa7, 0x63d0353a, 0xc5a73e8e, 0x33ef4e67, 0x959845d3, 0xa4705f4e, + 0x020754fa, 0xc7a06a74, 0x61d761c0, 0x503f7b5d, 0xf64870e9, 0x67de9cce, 0xc1a9977a, 0xf0418de7, 0x56368653, 0x9391b8dd, 0x35e6b369, + 0x040ea9f4, 0xa279a240, 0x5431d2a9, 0xf246d91d, 0xc3aec380, 0x65d9c834, 0xa07ef6ba, 0x0609fd0e, 0x37e1e793, 0x9196ec27, 0xcfbd399c, + 0x69ca3228, 0x582228b5, 0xfe552301, 0x3bf21d8f, 0x9d85163b, 0xac6d0ca6, 0x0a1a0712, 0xfc5277fb, 0x5a257c4f, 0x6bcd66d2, 0xcdba6d66, + 0x081d53e8, 0xae6a585c, 0x9f8242c1, 0x39f54975, 0xa863a552, 0x0e14aee6, 0x3ffcb47b, 0x998bbfcf, 0x5c2c8141, 0xfa5b8af5, 0xcbb39068, + 0x6dc49bdc, 0x9b8ceb35, 0x3dfbe081, 0x0c13fa1c, 0xaa64f1a8, 0x6fc3cf26, 0xc9b4c492, 0xf85cde0f, 0x5e2bd5bb, 0x440b7579, 0xe27c7ecd, + 0xd3946450, 0x75e36fe4, 0xb044516a, 0x16335ade, 0x27db4043, 0x81ac4bf7, 0x77e43b1e, 0xd19330aa, 0xe07b2a37, 0x460c2183, 0x83ab1f0d, + 0x25dc14b9, 0x14340e24, 0xb2430590, 0x23d5e9b7, 0x85a2e203, 0xb44af89e, 0x123df32a, 0xd79acda4, 0x71edc610, 0x4005dc8d, 0xe672d739, + 0x103aa7d0, 0xb64dac64, 0x87a5b6f9, 0x21d2bd4d, 0xe47583c3, 0x42028877, 0x73ea92ea, 0xd59d995e, 0x8bb64ce5, 0x2dc14751, 0x1c295dcc, + 0xba5e5678, 0x7ff968f6, 0xd98e6342, 0xe86679df, 0x4e11726b, 0xb8590282, 0x1e2e0936, 0x2fc613ab, 0x89b1181f, 0x4c162691, 0xea612d25, + 0xdb8937b8, 0x7dfe3c0c, 0xec68d02b, 0x4a1fdb9f, 0x7bf7c102, 0xdd80cab6, 0x1827f438, 0xbe50ff8c, 0x8fb8e511, 0x29cfeea5, 0xdf879e4c, + 0x79f095f8, 0x48188f65, 0xee6f84d1, 0x2bc8ba5f, 0x8dbfb1eb, 0xbc57ab76, 0x1a20a0c2, 0x8816eaf2, 0x2e61e146, 0x1f89fbdb, 0xb9fef06f, + 0x7c59cee1, 0xda2ec555, 0xebc6dfc8, 0x4db1d47c, 0xbbf9a495, 0x1d8eaf21, 0x2c66b5bc, 0x8a11be08, 0x4fb68086, 0xe9c18b32, 0xd82991af, + 0x7e5e9a1b, 0xefc8763c, 0x49bf7d88, 0x78576715, 0xde206ca1, 0x1b87522f, 0xbdf0599b, 0x8c184306, 0x2a6f48b2, 0xdc27385b, 0x7a5033ef, + 0x4bb82972, 0xedcf22c6, 0x28681c48, 0x8e1f17fc, 0xbff70d61, 0x198006d5, 0x47abd36e, 0xe1dcd8da, 0xd034c247, 0x7643c9f3, 0xb3e4f77d, + 0x1593fcc9, 0x247be654, 0x820cede0, 0x74449d09, 0xd23396bd, 0xe3db8c20, 0x45ac8794, 0x800bb91a, 0x267cb2ae, 0x1794a833, 0xb1e3a387, + 0x20754fa0, 0x86024414, 0xb7ea5e89, 0x119d553d, 0xd43a6bb3, 0x724d6007, 0x43a57a9a, 0xe5d2712e, 0x139a01c7, 0xb5ed0a73, 0x840510ee, + 0x22721b5a, 0xe7d525d4, 0x41a22e60, 0x704a34fd, 0xd63d3f49, 0xcc1d9f8b, 0x6a6a943f, 0x5b828ea2, 0xfdf58516, 0x3852bb98, 0x9e25b02c, + 0xafcdaab1, 0x09baa105, 0xfff2d1ec, 0x5985da58, 0x686dc0c5, 0xce1acb71, 0x0bbdf5ff, 0xadcafe4b, 0x9c22e4d6, 0x3a55ef62, 0xabc30345, + 0x0db408f1, 0x3c5c126c, 0x9a2b19d8, 0x5f8c2756, 0xf9fb2ce2, 0xc813367f, 0x6e643dcb, 0x982c4d22, 0x3e5b4696, 0x0fb35c0b, 0xa9c457bf, + 0x6c636931, 0xca146285, 0xfbfc7818, 0x5d8b73ac, 0x03a0a617, 0xa5d7ada3, 0x943fb73e, 0x3248bc8a, 0xf7ef8204, 0x519889b0, 0x6070932d, + 0xc6079899, 0x304fe870, 0x9638e3c4, 0xa7d0f959, 0x01a7f2ed, 0xc400cc63, 0x6277c7d7, 0x539fdd4a, 0xf5e8d6fe, 0x647e3ad9, 0xc209316d, + 0xf3e12bf0, 0x55962044, 0x90311eca, 0x3646157e, 0x07ae0fe3, 0xa1d90457, 0x579174be, 0xf1e67f0a, 0xc00e6597, 0x66796e23, 0xa3de50ad, + 0x05a95b19, 0x34414184, 0x92364a30}, + {0x00000000, 0xccaa009e, 0x4225077d, 0x8e8f07e3, 0x844a0efa, 0x48e00e64, 0xc66f0987, 0x0ac50919, 0xd3e51bb5, 0x1f4f1b2b, 0x91c01cc8, + 0x5d6a1c56, 0x57af154f, 0x9b0515d1, 0x158a1232, 0xd92012ac, 0x7cbb312b, 0xb01131b5, 0x3e9e3656, 0xf23436c8, 0xf8f13fd1, 0x345b3f4f, + 0xbad438ac, 0x767e3832, 0xaf5e2a9e, 0x63f42a00, 0xed7b2de3, 0x21d12d7d, 0x2b142464, 0xe7be24fa, 0x69312319, 0xa59b2387, 0xf9766256, + 0x35dc62c8, 0xbb53652b, 0x77f965b5, 0x7d3c6cac, 0xb1966c32, 0x3f196bd1, 0xf3b36b4f, 0x2a9379e3, 0xe639797d, 0x68b67e9e, 0xa41c7e00, + 0xaed97719, 0x62737787, 0xecfc7064, 0x205670fa, 0x85cd537d, 0x496753e3, 0xc7e85400, 0x0b42549e, 0x01875d87, 0xcd2d5d19, 0x43a25afa, + 0x8f085a64, 0x562848c8, 0x9a824856, 0x140d4fb5, 0xd8a74f2b, 0xd2624632, 0x1ec846ac, 0x9047414f, 0x5ced41d1, 0x299dc2ed, 0xe537c273, + 0x6bb8c590, 0xa712c50e, 0xadd7cc17, 0x617dcc89, 0xeff2cb6a, 0x2358cbf4, 0xfa78d958, 0x36d2d9c6, 0xb85dde25, 0x74f7debb, 0x7e32d7a2, + 0xb298d73c, 0x3c17d0df, 0xf0bdd041, 0x5526f3c6, 0x998cf358, 0x1703f4bb, 0xdba9f425, 0xd16cfd3c, 0x1dc6fda2, 0x9349fa41, 0x5fe3fadf, + 0x86c3e873, 0x4a69e8ed, 0xc4e6ef0e, 0x084cef90, 0x0289e689, 0xce23e617, 0x40ace1f4, 0x8c06e16a, 0xd0eba0bb, 0x1c41a025, 0x92cea7c6, + 0x5e64a758, 0x54a1ae41, 0x980baedf, 0x1684a93c, 0xda2ea9a2, 0x030ebb0e, 0xcfa4bb90, 0x412bbc73, 0x8d81bced, 0x8744b5f4, 0x4beeb56a, + 0xc561b289, 0x09cbb217, 0xac509190, 0x60fa910e, 0xee7596ed, 0x22df9673, 0x281a9f6a, 0xe4b09ff4, 0x6a3f9817, 0xa6959889, 0x7fb58a25, + 0xb31f8abb, 0x3d908d58, 0xf13a8dc6, 0xfbff84df, 0x37558441, 0xb9da83a2, 0x7570833c, 0x533b85da, 0x9f918544, 0x111e82a7, 0xddb48239, + 0xd7718b20, 0x1bdb8bbe, 0x95548c5d, 0x59fe8cc3, 0x80de9e6f, 0x4c749ef1, 0xc2fb9912, 0x0e51998c, 0x04949095, 0xc83e900b, 0x46b197e8, + 0x8a1b9776, 0x2f80b4f1, 0xe32ab46f, 0x6da5b38c, 0xa10fb312, 0xabcaba0b, 0x6760ba95, 0xe9efbd76, 0x2545bde8, 0xfc65af44, 0x30cfafda, + 0xbe40a839, 0x72eaa8a7, 0x782fa1be, 0xb485a120, 0x3a0aa6c3, 0xf6a0a65d, 0xaa4de78c, 0x66e7e712, 0xe868e0f1, 0x24c2e06f, 0x2e07e976, + 0xe2ade9e8, 0x6c22ee0b, 0xa088ee95, 0x79a8fc39, 0xb502fca7, 0x3b8dfb44, 0xf727fbda, 0xfde2f2c3, 0x3148f25d, 0xbfc7f5be, 0x736df520, + 0xd6f6d6a7, 0x1a5cd639, 0x94d3d1da, 0x5879d144, 0x52bcd85d, 0x9e16d8c3, 0x1099df20, 0xdc33dfbe, 0x0513cd12, 0xc9b9cd8c, 0x4736ca6f, + 0x8b9ccaf1, 0x8159c3e8, 0x4df3c376, 0xc37cc495, 0x0fd6c40b, 0x7aa64737, 0xb60c47a9, 0x3883404a, 0xf42940d4, 0xfeec49cd, 0x32464953, + 0xbcc94eb0, 0x70634e2e, 0xa9435c82, 0x65e95c1c, 0xeb665bff, 0x27cc5b61, 0x2d095278, 0xe1a352e6, 0x6f2c5505, 0xa386559b, 0x061d761c, + 0xcab77682, 0x44387161, 0x889271ff, 0x825778e6, 0x4efd7878, 0xc0727f9b, 0x0cd87f05, 0xd5f86da9, 0x19526d37, 0x97dd6ad4, 0x5b776a4a, + 0x51b26353, 0x9d1863cd, 0x1397642e, 0xdf3d64b0, 0x83d02561, 0x4f7a25ff, 0xc1f5221c, 0x0d5f2282, 0x079a2b9b, 0xcb302b05, 0x45bf2ce6, + 0x89152c78, 0x50353ed4, 0x9c9f3e4a, 0x121039a9, 0xdeba3937, 0xd47f302e, 0x18d530b0, 0x965a3753, 0x5af037cd, 0xff6b144a, 0x33c114d4, + 0xbd4e1337, 0x71e413a9, 0x7b211ab0, 0xb78b1a2e, 0x39041dcd, 0xf5ae1d53, 0x2c8e0fff, 0xe0240f61, 0x6eab0882, 0xa201081c, 0xa8c40105, + 0x646e019b, 0xeae10678, 0x264b06e6}}; + +#define BYTESWAP_ORDER32(x) (((x) >> 24) + (((x) >> 8) & 0xff00) + (((x) << 8) & 0xff0000) + ((x) << 24)) +#define UE_PTRDIFF_TO_UINT32(argument) static_cast<uint32_t>(argument) + +template<typename T> +constexpr T +Align(T Val, uint64_t Alignment) +{ + return (T)(((uint64_t)Val + Alignment - 1) & ~(Alignment - 1)); +} + +} // namespace CRC32 + +namespace zen { + +uint32_t +StrCrc_Deprecated(const char* Data) +{ + using namespace CRC32; + + uint32_t CRC = 0xFFFFFFFF; + while (*Data) + { + char16_t C = *Data++; + int32_t CL = (C & 255); + CRC = (CRC << 8) ^ CRCTable_DEPRECATED[(CRC >> 24) ^ CL]; + int32_t CH = (C >> 8) & 255; + CRC = (CRC << 8) ^ CRCTable_DEPRECATED[(CRC >> 24) ^ CH]; + } + return ~CRC; +} + +uint32_t +MemCrc32(const void* InData, size_t Length, uint32_t CRC /*=0 */) +{ + using namespace CRC32; + + // Based on the Slicing-by-8 implementation found here: + // http://slicing-by-8.sourceforge.net/ + + CRC = ~CRC; + + const uint8_t* __restrict Data = (uint8_t*)InData; + + // First we need to align to 32-bits + uint32_t InitBytes = UE_PTRDIFF_TO_UINT32(Align(Data, 4) - Data); + + if (Length > InitBytes) + { + Length -= InitBytes; + + for (; InitBytes; --InitBytes) + { + CRC = (CRC >> 8) ^ CRCTablesSB8[0][(CRC & 0xFF) ^ *Data++]; + } + + auto Data4 = (const uint32_t*)Data; + for (size_t Repeat = Length / 8; Repeat; --Repeat) + { + uint32_t V1 = *Data4++ ^ CRC; + uint32_t V2 = *Data4++; + CRC = CRCTablesSB8[7][V1 & 0xFF] ^ CRCTablesSB8[6][(V1 >> 8) & 0xFF] ^ CRCTablesSB8[5][(V1 >> 16) & 0xFF] ^ + CRCTablesSB8[4][V1 >> 24] ^ CRCTablesSB8[3][V2 & 0xFF] ^ CRCTablesSB8[2][(V2 >> 8) & 0xFF] ^ + CRCTablesSB8[1][(V2 >> 16) & 0xFF] ^ CRCTablesSB8[0][V2 >> 24]; + } + Data = (const uint8_t*)Data4; + + Length %= 8; + } + + for (; Length; --Length) + { + CRC = (CRC >> 8) ^ CRCTablesSB8[0][(CRC & 0xFF) ^ *Data++]; + } + + return ~CRC; +} + +uint32_t +MemCrc32_Deprecated(const void* InData, size_t Length, uint32_t CRC) +{ + using namespace CRC32; + + // Based on the Slicing-by-8 implementation found here: + // http://slicing-by-8.sourceforge.net/ + + CRC = ~BYTESWAP_ORDER32(CRC); + + const uint8_t* __restrict Data = (uint8_t*)InData; + + // First we need to align to 32-bits + uint32_t InitBytes = UE_PTRDIFF_TO_UINT32(Align(Data, 4) - Data); + + if (Length > InitBytes) + { + Length -= InitBytes; + + for (; InitBytes; --InitBytes) + { + CRC = (CRC >> 8) ^ CRCTablesSB8_DEPRECATED[0][(CRC & 0xFF) ^ *Data++]; + } + + auto Data4 = (const uint32_t*)Data; + for (size_t Repeat = Length / 8; Repeat; --Repeat) + { + uint32_t V1 = *Data4++ ^ CRC; + uint32_t V2 = *Data4++; + CRC = CRCTablesSB8_DEPRECATED[7][V1 & 0xFF] ^ CRCTablesSB8_DEPRECATED[6][(V1 >> 8) & 0xFF] ^ + CRCTablesSB8_DEPRECATED[5][(V1 >> 16) & 0xFF] ^ CRCTablesSB8_DEPRECATED[4][V1 >> 24] ^ + CRCTablesSB8_DEPRECATED[3][V2 & 0xFF] ^ CRCTablesSB8_DEPRECATED[2][(V2 >> 8) & 0xFF] ^ + CRCTablesSB8_DEPRECATED[1][(V2 >> 16) & 0xFF] ^ CRCTablesSB8_DEPRECATED[0][V2 >> 24]; + } + Data = (const uint8_t*)Data4; + + Length %= 8; + } + + for (; Length; --Length) + { + CRC = (CRC >> 8) ^ CRCTablesSB8_DEPRECATED[0][(CRC & 0xFF) ^ *Data++]; + } + + return BYTESWAP_ORDER32(~CRC); +} + +} // namespace zen diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp new file mode 100644 index 000000000..448fd36fa --- /dev/null +++ b/src/zencore/crypto.cpp @@ -0,0 +1,208 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/crypto.h> +#include <zencore/intmath.h> +#include <zencore/testing.h> + +#include <string> +#include <string_view> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <openssl/conf.h> +#include <openssl/err.h> +#include <openssl/evp.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "crypt32.lib") +# pragma comment(lib, "ws2_32.lib") +#endif + +namespace zen { + +using namespace std::literals; + +namespace crypto { + + class EvpContext + { + public: + EvpContext() : m_Ctx(EVP_CIPHER_CTX_new()) {} + ~EvpContext() { EVP_CIPHER_CTX_free(m_Ctx); } + + operator EVP_CIPHER_CTX*() { return m_Ctx; } + + private: + EVP_CIPHER_CTX* m_Ctx; + }; + + enum class TransformMode : uint32_t + { + Decrypt, + Encrypt + }; + + MemoryView Transform(const EVP_CIPHER* Cipher, + TransformMode Mode, + MemoryView Key, + MemoryView IV, + MemoryView In, + MutableMemoryView Out, + std::optional<std::string>& Reason) + { + ZEN_ASSERT(Cipher != nullptr); + + EvpContext Ctx; + + int Err = EVP_CipherInit_ex(Ctx, + Cipher, + nullptr, + reinterpret_cast<const unsigned char*>(Key.GetData()), + reinterpret_cast<const unsigned char*>(IV.GetData()), + static_cast<int>(Mode)); + + if (Err != 1) + { + Reason = fmt::format("failed to initialize cipher, error code '{}'", Err); + + return MemoryView(); + } + + int EncryptedBytes = 0; + int TotalEncryptedBytes = 0; + + Err = EVP_CipherUpdate(Ctx, + reinterpret_cast<unsigned char*>(Out.GetData()), + &EncryptedBytes, + reinterpret_cast<const unsigned char*>(In.GetData()), + static_cast<int>(In.GetSize())); + + if (Err != 1) + { + Reason = fmt::format("update crypto transform failed, error code '{}'", Err); + + return MemoryView(); + } + + TotalEncryptedBytes = EncryptedBytes; + MutableMemoryView Remaining = Out.RightChop(EncryptedBytes); + + EncryptedBytes = static_cast<int>(Remaining.GetSize()); + + Err = EVP_CipherFinal(Ctx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedBytes); + + if (Err != 1) + { + Reason = fmt::format("finalize crypto transform failed, error code '{}'", Err); + + return MemoryView(); + } + + TotalEncryptedBytes += EncryptedBytes; + + return Out.Left(TotalEncryptedBytes); + } + + bool ValidateKeyAndIV(const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional<std::string>& Reason) + { + if (Key.IsValid() == false) + { + Reason = "invalid key"sv; + + return false; + } + + if (IV.IsValid() == false) + { + Reason = "invalid initialization vector"sv; + + return false; + } + + return true; + } + +} // namespace crypto + +MemoryView +Aes::Encrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason) +{ + if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) + { + return MemoryView(); + } + + return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Encrypt, Key.GetView(), IV.GetView(), In, Out, Reason); +} + +MemoryView +Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason) +{ + if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false) + { + return MemoryView(); + } + + return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason); +} + +#if ZEN_WITH_TESTS + +void +crypto_forcelink() +{ +} + +TEST_CASE("crypto.bits") +{ + using CryptoBits256Bit = CryptoBits<256>; + + CryptoBits256Bit Bits; + + CHECK(Bits.IsNull()); + CHECK(Bits.IsValid() == false); + + CHECK(Bits.GetBitCount() == 256); + CHECK(Bits.GetSize() == 32); + + Bits = CryptoBits256Bit::FromString("Addff"sv); + CHECK(Bits.IsValid() == false); + + Bits = CryptoBits256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + CHECK(Bits.IsValid()); + + auto SmallerBits = CryptoBits<128>::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + CHECK(SmallerBits.IsValid() == false); +} + +TEST_CASE("crypto.aes") +{ + SUBCASE("basic") + { + const uint8_t InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv); + const AesIV128Bit IV = AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector)); + + std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv; + + std::vector<uint8_t> EncryptionBuffer; + std::vector<uint8_t> DecryptionBuffer; + std::optional<std::string> Reason; + + EncryptionBuffer.resize(PlainText.size() + Aes::BlockSize); + DecryptionBuffer.resize(PlainText.size() + Aes::BlockSize); + + MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason); + MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedView, MakeMutableMemoryView(DecryptionBuffer), Reason); + + std::string_view EncryptedDecryptedText = + std::string_view(reinterpret_cast<const char*>(DecryptedView.GetData()), DecryptedView.GetSize()); + + CHECK(EncryptedDecryptedText == PlainText); + } +} + +#endif + +} // namespace zen diff --git a/src/zencore/except.cpp b/src/zencore/except.cpp new file mode 100644 index 000000000..2749d1984 --- /dev/null +++ b/src/zencore/except.cpp @@ -0,0 +1,93 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <fmt/format.h> +#include <zencore/except.h> + +namespace zen { + +#if ZEN_PLATFORM_WINDOWS + +class WindowsException : public std::exception +{ +public: + WindowsException(std::string_view Message) + { + m_hResult = HRESULT_FROM_WIN32(GetLastError()); + m_Message = Message; + } + + WindowsException(HRESULT hRes, std::string_view Message) + { + m_hResult = hRes; + m_Message = Message; + } + + WindowsException(HRESULT hRes, const char* Message, const char* Detail) + { + m_hResult = hRes; + + ExtendableStringBuilder<128> msg; + msg.Append(Message); + msg.Append(" (detail: '"); + msg.Append(Detail); + msg.Append("')"); + + m_Message = msg.c_str(); + } + + virtual const char* what() const override { return m_Message.c_str(); } + +private: + std::string m_Message; + HRESULT m_hResult; +}; + +void +ThrowSystemException([[maybe_unused]] HRESULT hRes, [[maybe_unused]] std::string_view Message) +{ + if (HRESULT_FACILITY(hRes) == FACILITY_WIN32) + { + throw std::system_error(std::error_code(hRes & 0xffff, std::system_category()), std::string(Message)); + } + else + { + throw WindowsException(hRes, Message); + } +} + +#endif // ZEN_PLATFORM_WINDOWS + +void +ThrowSystemError(uint32_t ErrorCode, std::string_view Message) +{ + throw std::system_error(std::error_code(ErrorCode, std::system_category()), std::string(Message)); +} + +std::string +GetLastErrorAsString() +{ + return GetSystemErrorAsString(zen::GetLastError()); +} + +std::string +GetSystemErrorAsString(uint32_t ErrorCode) +{ + return std::error_code(ErrorCode, std::system_category()).message(); +} + +#if defined(__cpp_lib_source_location) +void +ThrowLastErrorImpl(std::string_view Message, const std::source_location& Location) +{ + throw std::system_error(std::error_code(zen::GetLastError(), std::system_category()), + fmt::format("{}({}): {}", Location.file_name(), Location.line(), Message)); +} +#else +void +ThrowLastError(std::string_view Message) +{ + throw std::system_error(std::error_code(zen::GetLastError(), std::system_category()), std::string(Message)); +} +#endif + +} // namespace zen diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp new file mode 100644 index 000000000..a17773024 --- /dev/null +++ b/src/zencore/filesystem.cpp @@ -0,0 +1,1304 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/filesystem.h> + +#include <zencore/except.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +#if ZEN_PLATFORM_WINDOWS +# include <atlbase.h> +# include <atlfile.h> +# include <winioctl.h> +# include <winnt.h> +#endif + +#if ZEN_PLATFORM_LINUX +# include <dirent.h> +# include <fcntl.h> +# include <sys/resource.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +#if ZEN_PLATFORM_MAC +# include <dirent.h> +# include <fcntl.h> +# include <libproc.h> +# include <sys/resource.h> +# include <sys/stat.h> +# include <sys/syslimits.h> +# include <unistd.h> +#endif + +#include <filesystem> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +using namespace std::literals; + +#if ZEN_PLATFORM_WINDOWS + +static bool +DeleteReparsePoint(const wchar_t* Path, DWORD dwReparseTag) +{ + CHandle hDir(CreateFileW(Path, + GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT, + nullptr)); + + if (hDir != INVALID_HANDLE_VALUE) + { + REPARSE_GUID_DATA_BUFFER Rgdb = {}; + Rgdb.ReparseTag = dwReparseTag; + + DWORD dwBytes; + const BOOL bOK = + DeviceIoControl(hDir, FSCTL_DELETE_REPARSE_POINT, &Rgdb, REPARSE_GUID_DATA_BUFFER_HEADER_SIZE, nullptr, 0, &dwBytes, nullptr); + + return bOK == TRUE; + } + + return false; +} + +bool +CreateDirectories(const wchar_t* Dir) +{ + // This may be suboptimal, in that it appears to try and create directories + // from the root on up instead of from some directory which is known to + // be present + // + // We should implement a smarter version at some point since this can be + // pretty expensive in aggregate + + return std::filesystem::create_directories(Dir); +} + +// Erase all files and directories in a given directory, leaving an empty directory +// behind + +static bool +WipeDirectory(const wchar_t* DirPath) +{ + ExtendableWideStringBuilder<128> Pattern; + Pattern.Append(DirPath); + Pattern.Append(L"\\*"); + + WIN32_FIND_DATAW FindData; + HANDLE hFind = FindFirstFileW(Pattern.c_str(), &FindData); + + if (hFind != nullptr) + { + do + { + bool IsRegular = true; + + if (FindData.cFileName[0] == L'.') + { + if (FindData.cFileName[1] == L'.') + { + if (FindData.cFileName[2] == L'\0') + { + IsRegular = false; + } + } + else if (FindData.cFileName[1] == L'\0') + { + IsRegular = false; + } + } + + if (IsRegular) + { + ExtendableWideStringBuilder<128> Path; + Path.Append(DirPath); + Path.Append(L'\\'); + Path.Append(FindData.cFileName); + + // if (fd.dwFileAttributes & FILE_ATTRIBUTE_RECALL_ON_OPEN) + // deleteReparsePoint(path.c_str(), fd.dwReserved0); + + if (FindData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) + { + if (FindData.dwFileAttributes & FILE_ATTRIBUTE_RECALL_ON_OPEN) + { + DeleteReparsePoint(Path.c_str(), FindData.dwReserved0); + } + + if (FindData.dwFileAttributes & FILE_ATTRIBUTE_RECALL_ON_DATA_ACCESS) + { + DeleteReparsePoint(Path.c_str(), FindData.dwReserved0); + } + + bool Success = DeleteDirectories(Path.c_str()); + + if (!Success) + { + if (FindData.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT) + { + DeleteReparsePoint(Path.c_str(), FindData.dwReserved0); + } + } + } + else + { + DeleteFileW(Path.c_str()); + } + } + } while (FindNextFileW(hFind, &FindData) == TRUE); + + FindClose(hFind); + } + + return true; +} + +bool +DeleteDirectories(const wchar_t* DirPath) +{ + return WipeDirectory(DirPath) && RemoveDirectoryW(DirPath) == TRUE; +} + +bool +CleanDirectory(const wchar_t* DirPath) +{ + if (std::filesystem::exists(DirPath)) + { + return WipeDirectory(DirPath); + } + + return CreateDirectories(DirPath); +} + +#endif // ZEN_PLATFORM_WINDOWS + +bool +CreateDirectories(const std::filesystem::path& Dir) +{ + while (!std::filesystem::is_directory(Dir)) + { + if (Dir.has_parent_path()) + { + CreateDirectories(Dir.parent_path()); + } + std::error_code ErrorCode; + std::filesystem::create_directory(Dir, ErrorCode); + if (ErrorCode) + { + throw std::system_error(ErrorCode, fmt::format("Failed to create directories for '{}'", Dir.string())); + } + return true; + } + return false; +} + +bool +DeleteDirectories(const std::filesystem::path& Dir) +{ +#if ZEN_PLATFORM_WINDOWS + return DeleteDirectories(Dir.c_str()); +#else + std::error_code ErrorCode; + return std::filesystem::remove_all(Dir, ErrorCode); +#endif +} + +bool +CleanDirectory(const std::filesystem::path& Dir) +{ +#if ZEN_PLATFORM_WINDOWS + return CleanDirectory(Dir.c_str()); +#else + if (std::filesystem::exists(Dir)) + { + bool Success = true; + + std::error_code ErrorCode; + for (const auto& Item : std::filesystem::directory_iterator(Dir)) + { + Success &= std::filesystem::remove_all(Item, ErrorCode); + } + + return Success; + } + + return CreateDirectories(Dir); +#endif +} + +////////////////////////////////////////////////////////////////////////// + +bool +SupportsBlockRefCounting(std::filesystem::path Path) +{ +#if ZEN_PLATFORM_WINDOWS + ATL::CHandle Handle(CreateFileW(Path.c_str(), + GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, + nullptr)); + + if (Handle == INVALID_HANDLE_VALUE) + { + Handle.Detach(); + return false; + } + + ULONG FileSystemFlags = 0; + if (!GetVolumeInformationByHandleW(Handle, nullptr, 0, nullptr, nullptr, /* lpFileSystemFlags */ &FileSystemFlags, nullptr, 0)) + { + return false; + } + + if (!(FileSystemFlags & FILE_SUPPORTS_BLOCK_REFCOUNTING)) + { + return false; + } + + return true; +#else + ZEN_UNUSED(Path); + return false; +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +CloneFile(std::filesystem::path FromPath, std::filesystem::path ToPath) +{ +#if ZEN_PLATFORM_WINDOWS + ATL::CHandle FromFile(CreateFileW(FromPath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr)); + if (FromFile == INVALID_HANDLE_VALUE) + { + FromFile.Detach(); + return false; + } + + ULONG FileSystemFlags; + if (!GetVolumeInformationByHandleW(FromFile, nullptr, 0, nullptr, nullptr, /* lpFileSystemFlags */ &FileSystemFlags, nullptr, 0)) + { + return false; + } + if (!(FileSystemFlags & FILE_SUPPORTS_BLOCK_REFCOUNTING)) + { + SetLastError(ERROR_NOT_CAPABLE); + return false; + } + + FILE_END_OF_FILE_INFO FileSize; + if (!GetFileSizeEx(FromFile, &FileSize.EndOfFile)) + { + return false; + } + + FILE_BASIC_INFO BasicInfo; + if (!GetFileInformationByHandleEx(FromFile, FileBasicInfo, &BasicInfo, sizeof BasicInfo)) + { + return false; + } + + DWORD dwBytesReturned = 0; + FSCTL_GET_INTEGRITY_INFORMATION_BUFFER GetIntegrityInfoBuffer; + if (!DeviceIoControl(FromFile, + FSCTL_GET_INTEGRITY_INFORMATION, + nullptr, + 0, + &GetIntegrityInfoBuffer, + sizeof GetIntegrityInfoBuffer, + &dwBytesReturned, + nullptr)) + { + return false; + } + + SetFileAttributesW(ToPath.c_str(), FILE_ATTRIBUTE_NORMAL); + + ATL::CHandle TargetFile(CreateFileW(ToPath.c_str(), + GENERIC_READ | GENERIC_WRITE | DELETE, + /* no sharing */ FILE_SHARE_READ, + nullptr, + OPEN_ALWAYS, + 0, + /* hTemplateFile */ FromFile)); + + if (TargetFile == INVALID_HANDLE_VALUE) + { + TargetFile.Detach(); + return false; + } + + // Delete target file when handle is closed (we only reset this if the copy succeeds) + FILE_DISPOSITION_INFO FileDisposition = {TRUE}; + if (!SetFileInformationByHandle(TargetFile, FileDispositionInfo, &FileDisposition, sizeof FileDisposition)) + { + TargetFile.Close(); + DeleteFileW(ToPath.c_str()); + return false; + } + + // Make file sparse so we don't end up allocating space when we change the file size + if (!DeviceIoControl(TargetFile, FSCTL_SET_SPARSE, nullptr, 0, nullptr, 0, &dwBytesReturned, nullptr)) + { + return false; + } + + // Copy integrity checking information + FSCTL_SET_INTEGRITY_INFORMATION_BUFFER IntegritySet = {GetIntegrityInfoBuffer.ChecksumAlgorithm, + GetIntegrityInfoBuffer.Reserved, + GetIntegrityInfoBuffer.Flags}; + if (!DeviceIoControl(TargetFile, FSCTL_SET_INTEGRITY_INFORMATION, &IntegritySet, sizeof IntegritySet, nullptr, 0, nullptr, nullptr)) + { + return false; + } + + // Resize file - note that the file is sparse at this point so no additional data will be written + if (!SetFileInformationByHandle(TargetFile, FileEndOfFileInfo, &FileSize, sizeof FileSize)) + { + return false; + } + + constexpr auto RoundToClusterSize = [](LONG64 FileSize, ULONG ClusterSize) -> LONG64 { + return (FileSize + ClusterSize - 1) / ClusterSize * ClusterSize; + }; + static_assert(RoundToClusterSize(5678, 4 * 1024) == 8 * 1024); + + // Loop for cloning file contents. This is necessary as the API has a 32-bit size + // limit for some reason + + const LONG64 SplitThreshold = (1LL << 32) - GetIntegrityInfoBuffer.ClusterSizeInBytes; + + DUPLICATE_EXTENTS_DATA DuplicateExtentsData{.FileHandle = FromFile}; + + for (LONG64 CurrentByteOffset = 0, + RemainingBytes = RoundToClusterSize(FileSize.EndOfFile.QuadPart, GetIntegrityInfoBuffer.ClusterSizeInBytes); + RemainingBytes > 0; + CurrentByteOffset += SplitThreshold, RemainingBytes -= SplitThreshold) + { + DuplicateExtentsData.SourceFileOffset.QuadPart = CurrentByteOffset; + DuplicateExtentsData.TargetFileOffset.QuadPart = CurrentByteOffset; + DuplicateExtentsData.ByteCount.QuadPart = std::min(SplitThreshold, RemainingBytes); + + if (!DeviceIoControl(TargetFile, + FSCTL_DUPLICATE_EXTENTS_TO_FILE, + &DuplicateExtentsData, + sizeof DuplicateExtentsData, + nullptr, + 0, + &dwBytesReturned, + nullptr)) + { + return false; + } + } + + // Make the file not sparse again now that we have populated the contents + if (!(BasicInfo.FileAttributes & FILE_ATTRIBUTE_SPARSE_FILE)) + { + FILE_SET_SPARSE_BUFFER SetSparse = {FALSE}; + + if (!DeviceIoControl(TargetFile, FSCTL_SET_SPARSE, &SetSparse, sizeof SetSparse, nullptr, 0, &dwBytesReturned, nullptr)) + { + return false; + } + } + + // Update timestamps (but don't lie about the creation time) + BasicInfo.CreationTime.QuadPart = 0; + if (!SetFileInformationByHandle(TargetFile, FileBasicInfo, &BasicInfo, sizeof BasicInfo)) + { + return false; + } + + if (!FlushFileBuffers(TargetFile)) + { + return false; + } + + // Finally now everything is done - make sure the file is not deleted on close! + + FileDisposition = {FALSE}; + + const bool AllOk = (TRUE == SetFileInformationByHandle(TargetFile, FileDispositionInfo, &FileDisposition, sizeof FileDisposition)); + + return AllOk; +#elif ZEN_PLATFORM_LINUX +# if 0 + struct ScopedFd + { + ~ScopedFd() { close(Fd); } + int Fd; + }; + + // The 'from' file + int FromFd = open(FromPath.c_str(), O_RDONLY|O_CLOEXEC); + if (FromFd < 0) + { + return false; + } + ScopedFd $From = { FromFd }; + + // The 'to' file + int ToFd = open(ToPath.c_str(), O_WRONLY|O_CREAT|O_EXCL|O_CLOEXEC, 0666); + if (ToFd < 0) + { + return false; + } + fchmod(ToFd, 0666); + ScopedFd $To = { FromFd }; + + ioctl(ToFd, FICLONE, FromFd); + + return false; +# endif // 0 + ZEN_UNUSED(FromPath, ToPath); + ZEN_ERROR("CloneFile() is not implemented on this platform"); + return false; +#elif ZEN_PLATFORM_MAC + /* clonefile() syscall if APFS */ + ZEN_UNUSED(FromPath, ToPath); + ZEN_ERROR("CloneFile() is not implemented on this platform"); + return false; +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options) +{ + bool Success = false; + + if (Options.EnableClone) + { + Success = CloneFile(FromPath.native(), ToPath.native()); + + if (Success) + { + return true; + } + } + + if (Options.MustClone) + { + return false; + } + +#if ZEN_PLATFORM_WINDOWS + BOOL CancelFlag = FALSE; + Success = !!::CopyFileExW(FromPath.c_str(), + ToPath.c_str(), + /* lpProgressRoutine */ nullptr, + /* lpData */ nullptr, + &CancelFlag, + /* dwCopyFlags */ 0); +#else + struct ScopedFd + { + ~ScopedFd() { close(Fd); } + int Fd; + }; + + // From file + int FromFd = open(FromPath.c_str(), O_RDONLY | O_CLOEXEC); + if (FromFd < 0) + { + ThrowLastError(fmt::format("failed to open file {}", FromPath)); + } + ScopedFd $From = {FromFd}; + + // To file + int ToFd = open(ToPath.c_str(), O_WRONLY | O_CREAT | O_EXCL | O_CLOEXEC, 0666); + if (ToFd < 0) + { + ThrowLastError(fmt::format("failed to create file {}", ToPath)); + } + fchmod(ToFd, 0666); + ScopedFd $To = {ToFd}; + + // Copy impl + static const size_t BufferSize = 64 << 10; + void* Buffer = malloc(BufferSize); + while (true) + { + int BytesRead = read(FromFd, Buffer, BufferSize); + if (BytesRead <= 0) + { + Success = (BytesRead == 0); + break; + } + + if (write(ToFd, Buffer, BytesRead) != BufferSize) + { + Success = false; + break; + } + } + free(Buffer); +#endif // ZEN_PLATFORM_WINDOWS + + if (!Success) + { + ThrowLastError("file copy failed"sv); + } + + return true; +} + +void +WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t BufferCount) +{ +#if ZEN_PLATFORM_WINDOWS + CAtlFile Outfile; + HRESULT hRes = Outfile.Create(Path.c_str(), GENERIC_WRITE, FILE_SHARE_READ, CREATE_ALWAYS); + if (hRes == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) + { + CreateDirectories(Path.parent_path()); + + hRes = Outfile.Create(Path.c_str(), GENERIC_WRITE, FILE_SHARE_READ, CREATE_ALWAYS); + } + + if (FAILED(hRes)) + { + ThrowSystemException(hRes, fmt::format("File open failed for '{}'", Path).c_str()); + } + +#else + int OpenFlags = O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC; + int Fd = open(Path.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + zen::CreateDirectories(Path.parent_path()); + Fd = open(Path.c_str(), OpenFlags, 0666); + } + + if (Fd < 0) + { + ThrowLastError(fmt::format("File open failed for '{}'", Path)); + } + + fchmod(Fd, 0666); +#endif + + // TODO: this should be block-enlightened + + for (size_t i = 0; i < BufferCount; ++i) + { + uint64_t WriteSize = Data[i]->Size(); + const void* DataPtr = Data[i]->Data(); + + while (WriteSize) + { + const uint64_t ChunkSize = Min<uint64_t>(WriteSize, uint64_t(2) * 1024 * 1024 * 1024); + +#if ZEN_PLATFORM_WINDOWS + hRes = Outfile.Write(DataPtr, gsl::narrow_cast<uint32_t>(WriteSize)); + if (FAILED(hRes)) + { + ThrowSystemException(hRes, fmt::format("File write failed for '{}'", Path).c_str()); + } +#else + if (write(Fd, DataPtr, WriteSize) != int64_t(WriteSize)) + { + ThrowLastError(fmt::format("File write failed for '{}'", Path)); + } +#endif // ZEN_PLATFORM_WINDOWS + + WriteSize -= ChunkSize; + DataPtr = reinterpret_cast<const uint8_t*>(DataPtr) + ChunkSize; + } + } + +#if !ZEN_PLATFORM_WINDOWS + close(Fd); +#endif +} + +void +WriteFile(std::filesystem::path Path, IoBuffer Data) +{ + const IoBuffer* const DataPtr = &Data; + + WriteFile(Path, &DataPtr, 1); +} + +IoBuffer +FileContents::Flatten() +{ + if (Data.size() == 1) + { + return Data[0]; + } + else if (Data.empty()) + { + return {}; + } + else + { + ZEN_NOT_IMPLEMENTED(); + } +} + +FileContents +ReadStdIn() +{ + BinaryWriter Writer; + + do + { + uint8_t ReadBuffer[1024]; + + size_t BytesRead = fread(ReadBuffer, 1, sizeof ReadBuffer, stdin); + Writer.Write(ReadBuffer, BytesRead); + } while (!feof(stdin)); + + FileContents Contents; + Contents.Data.emplace_back(IoBuffer(IoBuffer::Clone, Writer.GetData(), Writer.GetSize())); + + return Contents; +} + +FileContents +ReadFile(std::filesystem::path Path) +{ + uint64_t FileSizeBytes; + void* Handle; + +#if ZEN_PLATFORM_WINDOWS + ATL::CHandle FromFile(CreateFileW(Path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr)); + if (FromFile == INVALID_HANDLE_VALUE) + { + FromFile.Detach(); + return FileContents{.ErrorCode = std::error_code(::GetLastError(), std::system_category())}; + } + + FILE_END_OF_FILE_INFO FileSize; + if (!GetFileSizeEx(FromFile, &FileSize.EndOfFile)) + { + return FileContents{.ErrorCode = std::error_code(::GetLastError(), std::system_category())}; + } + + FileSizeBytes = FileSize.EndOfFile.QuadPart; + Handle = FromFile.Detach(); +#else + int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC); + if (Fd < 0) + { + FileContents Ret; + Ret.ErrorCode = std::error_code(zen::GetLastError(), std::system_category()); + return Ret; + } + + static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); + struct stat Stat; + fstat(Fd, &Stat); + + FileSizeBytes = Stat.st_size; + Handle = (void*)uintptr_t(Fd); +#endif + + FileContents Contents; + Contents.Data.emplace_back(IoBuffer(IoBuffer::File, Handle, 0, FileSizeBytes)); + return Contents; +} + +bool +ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<void(const void* Data, size_t Size)>&& ProcessFunc) +{ +#if ZEN_PLATFORM_WINDOWS + ATL::CHandle FromFile(CreateFileW(Path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr)); + if (FromFile == INVALID_HANDLE_VALUE) + { + FromFile.Detach(); + return false; + } + + std::vector<uint8_t> ReadBuffer(ChunkSize); + + for (;;) + { + DWORD dwBytesRead = 0; + BOOL Success = ::ReadFile(FromFile, ReadBuffer.data(), (DWORD)ReadBuffer.size(), &dwBytesRead, nullptr); + + if (!Success) + { + throw std::system_error(std::error_code(::GetLastError(), std::system_category()), "file scan failed"); + } + + if (dwBytesRead == 0) + break; + + ProcessFunc(ReadBuffer.data(), dwBytesRead); + } +#else + int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC); + if (Fd < 0) + { + return false; + } + + bool Success = true; + + void* Buffer = malloc(ChunkSize); + while (true) + { + int BytesRead = read(Fd, Buffer, ChunkSize); + if (BytesRead < 0) + { + Success = false; + break; + } + + if (BytesRead == 0) + { + break; + } + + ProcessFunc(Buffer, BytesRead); + } + + free(Buffer); + close(Fd); + + if (!Success) + { + ThrowLastError("file scan failed"); + } +#endif // ZEN_PLATFORM_WINDOWS + + return true; +} + +void +PathToUtf8(const std::filesystem::path& Path, StringBuilderBase& Out) +{ +#if ZEN_PLATFORM_WINDOWS + WideToUtf8(Path.native().c_str(), Out); +#else + Out << Path.c_str(); +#endif +} + +std::string +PathToUtf8(const std::filesystem::path& Path) +{ +#if ZEN_PLATFORM_WINDOWS + return WideToUtf8(Path.native().c_str()); +#else + return Path.string(); +#endif +} + +DiskSpace +DiskSpaceInfo(std::filesystem::path Directory, std::error_code& Error) +{ + using namespace std::filesystem; + + space_info SpaceInfo = space(Directory, Error); + if (Error) + { + return {}; + } + + return { + .Free = uint64_t(SpaceInfo.available), + .Total = uint64_t(SpaceInfo.capacity), + }; +} + +void +FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, TreeVisitor& Visitor) +{ +#if ZEN_PLATFORM_WINDOWS + uint64_t FileInfoBuffer[8 * 1024]; + + FILE_INFO_BY_HANDLE_CLASS FibClass = FileIdBothDirectoryRestartInfo; + bool Continue = true; + + CAtlFile RootDirHandle; + HRESULT hRes = + RootDirHandle.Create(RootDir.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_WRITE, OPEN_EXISTING, FILE_FLAG_BACKUP_SEMANTICS); + + if (FAILED(hRes)) + { + ThrowSystemException(hRes, "Failed to open handle to volume root"); + } + + while (Continue) + { + BOOL Success = GetFileInformationByHandleEx(RootDirHandle, FibClass, FileInfoBuffer, sizeof FileInfoBuffer); + FibClass = FileIdBothDirectoryInfo; // Set up for next iteration + + uint64_t EntryOffset = 0; + + if (!Success) + { + DWORD LastError = GetLastError(); + + if (LastError == ERROR_NO_MORE_FILES) + { + break; + } + + throw std::system_error(std::error_code(LastError, std::system_category()), "file system traversal error"); + } + + for (;;) + { + const FILE_ID_BOTH_DIR_INFO* DirInfo = + reinterpret_cast<const FILE_ID_BOTH_DIR_INFO*>(reinterpret_cast<const uint8_t*>(FileInfoBuffer) + EntryOffset); + + std::wstring_view FileName(DirInfo->FileName, DirInfo->FileNameLength / sizeof(wchar_t)); + + if (DirInfo->FileAttributes & FILE_ATTRIBUTE_DIRECTORY) + { + if (FileName == L"."sv || FileName == L".."sv) + { + // Not very interesting + } + else + { + const bool ShouldDescend = Visitor.VisitDirectory(RootDir, FileName); + + if (ShouldDescend) + { + // Note that this recursion combined with the buffer could + // blow the stack, we should consider a different strategy + + std::filesystem::path FullPath = RootDir / FileName; + + TraverseFileSystem(FullPath, Visitor); + } + } + } + else if (DirInfo->FileAttributes & FILE_ATTRIBUTE_DEVICE) + { + ZEN_WARN("encountered device node during file system traversal: '{}' found in '{}'", WideToUtf8(FileName), RootDir); + } + else + { + Visitor.VisitFile(RootDir, FileName, DirInfo->EndOfFile.QuadPart); + } + + const uint64_t NextOffset = DirInfo->NextEntryOffset; + + if (NextOffset == 0) + { + break; + } + + EntryOffset += NextOffset; + } + } +#else + /* Could also implement this using Linux's getdents() syscall */ + + DIR* Dir = opendir(RootDir.c_str()); + if (Dir == nullptr) + { + ThrowLastError(fmt::format("Failed to open directory for traversal: {}", RootDir.c_str())); + } + + for (struct dirent* Entry; (Entry = readdir(Dir));) + { + const char* FileName = Entry->d_name; + + struct stat Stat; + std::filesystem::path FullPath = RootDir / FileName; + stat(FullPath.c_str(), &Stat); + + if (S_ISDIR(Stat.st_mode)) + { + if (strcmp(FileName, ".") == 0 || strcmp(FileName, "..") == 0) + { + /* nop */ + } + else if (Visitor.VisitDirectory(RootDir, FileName)) + { + TraverseFileSystem(FullPath, Visitor); + } + } + else if (S_ISREG(Stat.st_mode)) + { + Visitor.VisitFile(RootDir, FileName, Stat.st_size); + } + else + { + ZEN_WARN("encountered non-regular file during file system traversal ({}): {} found in {}", + Stat.st_mode, + FileName, + RootDir.c_str()); + } + } + + closedir(Dir); +#endif // ZEN_PLATFORM_WINDOWS +} + +std::filesystem::path +PathFromHandle(void* NativeHandle) +{ +#if ZEN_PLATFORM_WINDOWS + if (NativeHandle == nullptr || NativeHandle == INVALID_HANDLE_VALUE) + { + return std::filesystem::path(); + } + + auto GetFinalPathNameByHandleWRetry = [](HANDLE hFile, LPWSTR lpszFilePath, DWORD cchFilePath, DWORD dwFlags) -> DWORD { + while (true) + { + DWORD Res = GetFinalPathNameByHandleW(hFile, lpszFilePath, cchFilePath, dwFlags); + if (Res == 0) + { + DWORD LastError = zen::GetLastError(); + // Under heavy concurrent loads we might get access denied on a file handle while trying to get path name. + // Retry if that is the case. + if (LastError != ERROR_ACCESS_DENIED) + { + ThrowSystemError(LastError, fmt::format("failed to get path from file handle {}", hFile)); + } + // Retry + continue; + } + ZEN_ASSERT(Res != 1); // We don't accept empty path names + return Res; + } + }; + + static const DWORD PathDataSize = 512; + wchar_t PathData[PathDataSize]; + DWORD RequiredLengthIncludingNul = GetFinalPathNameByHandleWRetry(NativeHandle, PathData, PathDataSize, FILE_NAME_OPENED); + if (RequiredLengthIncludingNul == 0) + { + ThrowLastError(fmt::format("failed to get path from file handle {}", NativeHandle)); + } + + if (RequiredLengthIncludingNul < PathDataSize) + { + std::wstring FullPath(PathData, gsl::narrow<size_t>(RequiredLengthIncludingNul)); + return FullPath; + } + + std::wstring FullPath; + FullPath.resize(RequiredLengthIncludingNul - 1); + + const DWORD FinalLength = GetFinalPathNameByHandleWRetry(NativeHandle, FullPath.data(), RequiredLengthIncludingNul, FILE_NAME_OPENED); + ZEN_UNUSED(FinalLength); + return FullPath; + +#elif ZEN_PLATFORM_LINUX + char Link[PATH_MAX]; + char Path[64]; + + sprintf(Path, "/proc/self/fd/%d", int(uintptr_t(NativeHandle))); + ssize_t BytesRead = readlink(Path, Link, sizeof(Link) - 1); + if (BytesRead <= 0) + { + return std::filesystem::path(); + } + + Link[BytesRead] = '\0'; + return Link; +#elif ZEN_PLATFORM_MAC + int Fd = int(uintptr_t(NativeHandle)); + char Path[MAXPATHLEN]; + if (fcntl(Fd, F_GETPATH, Path) < 0) + { + return std::filesystem::path(); + } + + return Path; +#endif // ZEN_PLATFORM_WINDOWS +} + +std::filesystem::path +GetRunningExecutablePath() +{ +#if ZEN_PLATFORM_WINDOWS + TCHAR ExePath[MAX_PATH]; + DWORD PathLength = GetModuleFileName(NULL, ExePath, ZEN_ARRAY_COUNT(ExePath)); + + return {std::wstring_view(ExePath, PathLength)}; +#elif ZEN_PLATFORM_LINUX + char Link[256]; + ssize_t BytesRead = readlink("/proc/self/exe", Link, sizeof(Link) - 1); + if (BytesRead < 0) + return {}; + + Link[BytesRead] = '\0'; + return Link; +#elif ZEN_PLATFORM_MAC + char Buffer[PROC_PIDPATHINFO_MAXSIZE]; + + int SelfPid = GetCurrentProcessId(); + if (proc_pidpath(SelfPid, Buffer, sizeof(Buffer)) <= 0) + return {}; + + return Buffer; +#endif // ZEN_PLATFORM_WINDOWS +} + +void +MaximizeOpenFileCount() +{ +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + struct rlimit Limit; + int Error = getrlimit(RLIMIT_NOFILE, &Limit); + if (Error) + { + ZEN_WARN("failed getting rlimit RLIMIT_NOFILE, reason '{}'", zen::MakeErrorCode(Error).message()); + } + else + { + struct rlimit NewLimit = Limit; + NewLimit.rlim_cur = NewLimit.rlim_max; + ZEN_INFO("changing RLIMIT_NOFILE from rlim_cur = {}, rlim_max {} to rlim_cur = {}, rlim_max {}", + Limit.rlim_cur, + Limit.rlim_max, + NewLimit.rlim_cur, + NewLimit.rlim_max); + + Error = setrlimit(RLIMIT_NOFILE, &NewLimit); + if (Error != 0) + { + ZEN_WARN("failed to set RLIMIT_NOFILE limits from rlim_cur = {}, rlim_max {} to rlim_cur = {}, rlim_max {}, reason '{}'", + Limit.rlim_cur, + Limit.rlim_max, + NewLimit.rlim_cur, + NewLimit.rlim_max, + zen::MakeErrorCode(Error).message()); + } + } +#endif +} + +void +GetDirectoryContent(const std::filesystem::path& RootDir, uint8_t Flags, DirectoryContent& OutContent) +{ + FileSystemTraversal Traversal; + struct Visitor : public FileSystemTraversal::TreeVisitor + { + Visitor(uint8_t Flags, DirectoryContent& OutContent) : Flags(Flags), Content(OutContent) {} + + virtual void VisitFile([[maybe_unused]] const std::filesystem::path& Parent, + [[maybe_unused]] const path_view& File, + [[maybe_unused]] uint64_t FileSize) override + { + if (Flags & DirectoryContent::IncludeFilesFlag) + { + Content.Files.push_back(Parent / File); + } + } + + virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent, const path_view& DirectoryName) override + { + if (Flags & DirectoryContent::IncludeDirsFlag) + { + Content.Directories.push_back(Parent / DirectoryName); + } + return (Flags & DirectoryContent::RecursiveFlag) != 0; + } + + const uint8_t Flags; + DirectoryContent& Content; + } Visit(Flags, OutContent); + + Traversal.TraverseFileSystem(RootDir, Visit); +} + +std::string +GetEnvVariable(std::string_view VariableName) +{ + ZEN_ASSERT(!VariableName.empty()); +#if ZEN_PLATFORM_WINDOWS + + CHAR EnvVariableBuffer[1023 + 1]; + DWORD RESULT = GetEnvironmentVariableA(std::string(VariableName).c_str(), EnvVariableBuffer, sizeof(EnvVariableBuffer)); + if (RESULT > 0 && RESULT < sizeof(EnvVariableBuffer)) + { + return std::string(EnvVariableBuffer); + } +#endif +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + char* EnvVariable = getenv(std::string(VariableName).c_str()); + if (EnvVariable) + { + return std::string(EnvVariable); + } +#endif + return ""; +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +void +filesystem_forcelink() +{ +} + +TEST_CASE("filesystem") +{ + using namespace std::filesystem; + + // GetExePath -- this is not a great test as it's so dependent on where the this code gets linked in + path BinPath = GetRunningExecutablePath(); + const bool ExpectedExe = PathToUtf8(BinPath.stem().native()).ends_with("-test"sv) || BinPath.stem() == "zenserver"; + CHECK(ExpectedExe); + CHECK(is_regular_file(BinPath)); + + // PathFromHandle + void* Handle; +# if ZEN_PLATFORM_WINDOWS + Handle = CreateFileW(BinPath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr); + CHECK(Handle != INVALID_HANDLE_VALUE); +# else + int Fd = open(BinPath.c_str(), O_RDONLY | O_CLOEXEC); + CHECK(Fd >= 0); + Handle = (void*)uintptr_t(Fd); +# endif + + auto FromHandle = PathFromHandle((void*)uintptr_t(Handle)); + CHECK(equivalent(FromHandle, BinPath)); + +# if ZEN_PLATFORM_WINDOWS + CloseHandle(Handle); +# else + close(int(uintptr_t(Handle))); +# endif + + // Traversal + struct : public FileSystemTraversal::TreeVisitor + { + virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t) override + { + bFoundExpected |= std::filesystem::equivalent(Parent / File, Expected); + } + + virtual bool VisitDirectory(const std::filesystem::path&, const path_view&) override { return true; } + + bool bFoundExpected = false; + std::filesystem::path Expected; + } Visitor; + Visitor.Expected = BinPath; + + FileSystemTraversal().TraverseFileSystem(BinPath.parent_path().parent_path(), Visitor); + CHECK(Visitor.bFoundExpected); + + // Scan/read file + FileContents BinRead = ReadFile(BinPath); + std::vector<uint8_t> BinScan; + ScanFile(BinPath, 16 << 10, [&](const void* Data, size_t Size) { + const auto* Ptr = (uint8_t*)Data; + BinScan.insert(BinScan.end(), Ptr, Ptr + Size); + }); + CHECK_EQ(BinRead.Data.size(), 1); + CHECK_EQ(BinScan.size(), BinRead.Data[0].GetSize()); +} + +TEST_CASE("WriteFile") +{ + std::filesystem::path TempFile = GetRunningExecutablePath().parent_path(); + TempFile /= "write_file_test"; + + uint64_t Magics[] = { + 0x0'a9e'a9e'a9e'a9e'a9e, + 0x0'493'493'493'493'493, + }; + + struct + { + const void* Data; + size_t Size; + } MagicTests[] = { + { + Magics, + sizeof(Magics), + }, + { + Magics + 1, + sizeof(Magics[0]), + }, + }; + for (auto& MagicTest : MagicTests) + { + WriteFile(TempFile, IoBuffer(IoBuffer::Wrap, MagicTest.Data, MagicTest.Size)); + + FileContents MagicsReadback = ReadFile(TempFile); + CHECK_EQ(MagicsReadback.Data.size(), 1); + CHECK_EQ(MagicsReadback.Data[0].GetSize(), MagicTest.Size); + CHECK_EQ(memcmp(MagicTest.Data, MagicsReadback.Data[0].Data(), MagicTest.Size), 0); + } + + std::filesystem::remove(TempFile); +} + +TEST_CASE("DiskSpaceInfo") +{ + std::filesystem::path BinPath = GetRunningExecutablePath(); + + DiskSpace Space = {}; + + std::error_code Error; + Space = DiskSpaceInfo(BinPath, Error); + CHECK(!Error); + + bool Okay = DiskSpaceInfo(BinPath, Space); + CHECK(Okay); + + CHECK(int64_t(Space.Total) > 0); + CHECK(int64_t(Space.Free) > 0); // Hopefully there's at least one byte free +} + +TEST_CASE("PathBuilder") +{ +# if ZEN_PLATFORM_WINDOWS + const char* foo_bar = "/foo\\bar"; +# else + const char* foo_bar = "/foo/bar"; +# endif + + ExtendablePathBuilder<32> Path; + for (const char* Prefix : {"/foo", "/foo/"}) + { + Path.Reset(); + Path.Append(Prefix); + Path /= "bar"; + CHECK(Path.ToPath() == foo_bar); + } + + using fspath = std::filesystem::path; + + Path.Reset(); + Path.Append(fspath("/foo/")); + Path /= (fspath("bar")); + CHECK(Path.ToPath() == foo_bar); + +# if ZEN_PLATFORM_WINDOWS + Path.Reset(); + Path.Append(fspath(L"/\u0119oo/")); + Path /= L"bar"; + printf("%ls\n", Path.ToPath().c_str()); + CHECK(Path.ToView() == L"/\u0119oo/bar"); + CHECK(Path.ToPath() == L"\\\u0119oo\\bar"); +# endif +} + +#endif + +} // namespace zen diff --git a/src/zencore/include/zencore/atomic.h b/src/zencore/include/zencore/atomic.h new file mode 100644 index 000000000..bf549e21d --- /dev/null +++ b/src/zencore/include/zencore/atomic.h @@ -0,0 +1,74 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if ZEN_COMPILER_MSC +# include <intrin.h> +#else +# include <atomic> +#endif + +#include <cinttypes> + +namespace zen { + +inline uint32_t +AtomicIncrement(volatile uint32_t& value) +{ +#if ZEN_COMPILER_MSC + return _InterlockedIncrement((long volatile*)&value); +#else + return ((std::atomic<uint32_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1; +#endif +} +inline uint32_t +AtomicDecrement(volatile uint32_t& value) +{ +#if ZEN_COMPILER_MSC + return _InterlockedDecrement((long volatile*)&value); +#else + return ((std::atomic<uint32_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1; +#endif +} + +inline uint64_t +AtomicIncrement(volatile uint64_t& value) +{ +#if ZEN_COMPILER_MSC + return _InterlockedIncrement64((__int64 volatile*)&value); +#else + return ((std::atomic<uint64_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1; +#endif +} +inline uint64_t +AtomicDecrement(volatile uint64_t& value) +{ +#if ZEN_COMPILER_MSC + return _InterlockedDecrement64((__int64 volatile*)&value); +#else + return ((std::atomic<uint64_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1; +#endif +} + +inline uint32_t +AtomicAdd(volatile uint32_t& value, uint32_t amount) +{ +#if ZEN_COMPILER_MSC + return _InterlockedExchangeAdd((long volatile*)&value, amount); +#else + return ((std::atomic<uint32_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst); +#endif +} +inline uint64_t +AtomicAdd(volatile uint64_t& value, uint64_t amount) +{ +#if ZEN_COMPILER_MSC + return _InterlockedExchangeAdd64((__int64 volatile*)&value, amount); +#else + return ((std::atomic<uint64_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst); +#endif +} + +} // namespace zen diff --git a/src/zencore/include/zencore/base64.h b/src/zencore/include/zencore/base64.h new file mode 100644 index 000000000..4d78b085f --- /dev/null +++ b/src/zencore/include/zencore/base64.h @@ -0,0 +1,17 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +namespace zen { + +struct Base64 +{ + template<typename CharType> + static uint32_t Encode(const uint8_t* Source, uint32_t Length, CharType* Dest); + + static inline constexpr int32_t GetEncodedDataSize(uint32_t Size) { return ((Size + 2) / 3) * 4; } +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/blake3.h b/src/zencore/include/zencore/blake3.h new file mode 100644 index 000000000..b31b710a7 --- /dev/null +++ b/src/zencore/include/zencore/blake3.h @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cinttypes> +#include <compare> +#include <cstring> + +#include <zencore/memory.h> + +namespace zen { + +class CompositeBuffer; +class StringBuilderBase; + +/** + * BLAKE3 hash - 256 bits + */ +struct BLAKE3 +{ + uint8_t Hash[32]; + + inline auto operator<=>(const BLAKE3& Rhs) const = default; + + static BLAKE3 HashBuffer(const CompositeBuffer& Buffer); + static BLAKE3 HashMemory(const void* Data, size_t ByteCount); + static BLAKE3 FromHexString(const char* String); + const char* ToHexString(char* OutString /* 40 characters + NUL terminator */) const; + StringBuilderBase& ToHexString(StringBuilderBase& OutBuilder) const; + + static const int StringLength = 64; + typedef char String_t[StringLength + 1]; + + static BLAKE3 Zero; // Initialized to all zeroes + + struct Hasher + { + size_t operator()(const BLAKE3& v) const + { + size_t h; + memcpy(&h, v.Hash, sizeof h); + return h; + } + }; +}; + +struct BLAKE3Stream +{ + BLAKE3Stream(); + + void Reset(); // Begin streaming hash compute (not needed on freshly constructed instance) + BLAKE3Stream& Append(const void* data, size_t byteCount); // Append another chunk + BLAKE3Stream& Append(MemoryView DataView) { return Append(DataView.GetData(), DataView.GetSize()); } // Append another chunk + BLAKE3 GetHash(); // Obtain final hash. If you wish to reuse the instance call reset() + +private: + alignas(16) uint8_t m_HashState[2048]; +}; + +void blake3_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/blockingqueue.h b/src/zencore/include/zencore/blockingqueue.h new file mode 100644 index 000000000..f92df5a54 --- /dev/null +++ b/src/zencore/include/zencore/blockingqueue.h @@ -0,0 +1,76 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <atomic> +#include <condition_variable> +#include <deque> +#include <mutex> + +namespace zen { + +template<typename T> +class BlockingQueue +{ +public: + BlockingQueue() = default; + + ~BlockingQueue() { CompleteAdding(); } + + void Enqueue(T&& Item) + { + { + std::lock_guard Lock(m_Lock); + m_Queue.emplace_back(std::move(Item)); + m_Size++; + } + + m_NewItemSignal.notify_one(); + } + + bool WaitAndDequeue(T& Item) + { + if (m_CompleteAdding.load()) + { + return false; + } + + std::unique_lock Lock(m_Lock); + m_NewItemSignal.wait(Lock, [this]() { return !m_Queue.empty() || m_CompleteAdding.load(); }); + + if (!m_Queue.empty()) + { + Item = std::move(m_Queue.front()); + m_Queue.pop_front(); + m_Size--; + + return true; + } + + return false; + } + + void CompleteAdding() + { + if (!m_CompleteAdding.load()) + { + m_CompleteAdding.store(true); + m_NewItemSignal.notify_all(); + } + } + + std::size_t Size() const + { + std::unique_lock Lock(m_Lock); + return m_Queue.size(); + } + +private: + mutable std::mutex m_Lock; + std::condition_variable m_NewItemSignal; + std::deque<T> m_Queue; + std::atomic_bool m_CompleteAdding{false}; + std::atomic_uint32_t m_Size; +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/compactbinary.h b/src/zencore/include/zencore/compactbinary.h new file mode 100644 index 000000000..b546f97aa --- /dev/null +++ b/src/zencore/include/zencore/compactbinary.h @@ -0,0 +1,1475 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/enumflags.h> +#include <zencore/intmath.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/memory.h> +#include <zencore/meta.h> +#include <zencore/sharedbuffer.h> +#include <zencore/uid.h> +#include <zencore/varint.h> + +#include <functional> +#include <memory> +#include <string> +#include <string_view> +#include <type_traits> +#include <vector> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +class CbObjectView; +class CbArrayView; +class BinaryReader; +class BinaryWriter; +class CompressedBuffer; +class CbValue; + +class DateTime +{ +public: + explicit DateTime(uint64_t InTicks) : Ticks(InTicks) {} + inline DateTime(int Year, int Month, int Day, int Hours = 0, int Minutes = 0, int Seconds = 0, int MilliSeconds = 0) + { + Set(Year, Month, Day, Hours, Minutes, Seconds, MilliSeconds); + } + + inline uint64_t GetTicks() const { return Ticks; } + + static uint64_t NowTicks(); + static DateTime Now(); + + int GetYear() const; + int GetMonth() const; + int GetDay() const; + int GetHour() const; + int GetHour12() const; + int GetMinute() const; + int GetSecond() const; + int GetMillisecond() const; + void GetDate(int& Year, int& Month, int& Day) const; + + inline bool operator==(const DateTime& Rhs) const { return Ticks == Rhs.Ticks; } + inline auto operator<=>(const DateTime& Rhs) const { return Ticks - Rhs.Ticks; } + + std::string ToString(const char* Format) const; + std::string ToIso8601() const; + +private: + void Set(int Year, int Month, int Day, int Hours, int Minutes, int Seconds, int MilliSecond); + uint64_t Ticks; // 1 tick == 0.1us == 100ns, epoch == Jan 1st 0001 +}; + +class TimeSpan +{ +public: + explicit TimeSpan(uint64_t InTicks) : Ticks(InTicks) {} + inline TimeSpan(int Hours, int Minutes, int Seconds) { Set(0, Hours, Minutes, Seconds, 0); } + inline TimeSpan(int Days, int Hours, int Minutes, int Seconds) { Set(Days, Hours, Minutes, Seconds, 0); } + inline TimeSpan(int Days, int Hours, int Minutes, int Seconds, int Nanos) { Set(Days, Hours, Minutes, Seconds, Nanos); } + + inline uint64_t GetTicks() const { return Ticks; } + inline bool operator==(const TimeSpan& Rhs) const { return Ticks == Rhs.Ticks; } + inline auto operator<=>(const TimeSpan& Rhs) const { return Ticks - Rhs.Ticks; } + + /** + * Time span related constants. + */ + + /** The maximum number of ticks that can be represented in FTimespan. */ + static constexpr int64_t MaxTicks = 9223372036854775807; + + /** The minimum number of ticks that can be represented in FTimespan. */ + static constexpr int64_t MinTicks = -9223372036854775807 - 1; + + /** The number of nanoseconds per tick. */ + static constexpr int64_t NanosecondsPerTick = 100; + + /** The number of timespan ticks per day. */ + static constexpr int64_t TicksPerDay = 864000000000; + + /** The number of timespan ticks per hour. */ + static constexpr int64_t TicksPerHour = 36000000000; + + /** The number of timespan ticks per microsecond. */ + static constexpr int64_t TicksPerMicrosecond = 10; + + /** The number of timespan ticks per millisecond. */ + static constexpr int64_t TicksPerMillisecond = 10000; + + /** The number of timespan ticks per minute. */ + static constexpr int64_t TicksPerMinute = 600000000; + + /** The number of timespan ticks per second. */ + static constexpr int64_t TicksPerSecond = 10000000; + + /** The number of timespan ticks per week. */ + static constexpr int64_t TicksPerWeek = 6048000000000; + + /** The number of timespan ticks per year (365 days, not accounting for leap years). */ + static constexpr int64_t TicksPerYear = 365 * TicksPerDay; + + int GetFractionTicks() const { return (int)(Ticks % TicksPerSecond); } + + int GetFractionMicro() const { return (int)((Ticks % TicksPerSecond) / TicksPerMicrosecond); } + + int GetFractionMilli() const { return (int)((Ticks % TicksPerSecond) / TicksPerMillisecond); } + + int GetFractionNano() const { return (int)((Ticks % TicksPerSecond) * NanosecondsPerTick); } + + int GetDays() const { return (int)(Ticks / TicksPerDay); } + + int GetHours() const { return (int)((Ticks / TicksPerHour) % 24); } + + int GetMinutes() const { return (int)((Ticks / TicksPerMinute) % 60); } + + int GetSeconds() const { return (int)((Ticks / TicksPerSecond) % 60); } + + ZENCORE_API std::string ToString(const char* Format) const; + ZENCORE_API std::string ToString() const; + +private: + void Set(int Days, int Hours, int Minutes, int Seconds, int FractionNano); + + uint64_t Ticks; +}; + +struct Guid +{ + uint32_t A, B, C, D; + + StringBuilderBase& ToString(StringBuilderBase& OutString) const; +}; + +////////////////////////////////////////////////////////////////////////// + +/** + * Field types and flags for CbField. + * + * This is a private type and is only declared here to enable inline use below. + * + * DO NOT CHANGE THE VALUE OF ANY MEMBERS OF THIS ENUM! + * BACKWARD COMPATIBILITY REQUIRES THAT THESE VALUES BE FIXED! + * SERIALIZATION USES HARD-CODED CONSTANTS BASED ON THESE VALUES! + */ +enum class CbFieldType : uint8_t +{ + /** A field type that does not occur in a valid object. */ + None = 0x00, + + /** Null. Payload is empty. */ + Null = 0x01, + + /** + * Object is an array of fields with unique non-empty names. + * + * Payload is a VarUInt byte count for the encoded fields followed by the fields. + */ + Object = 0x02, + /** + * UniformObject is an array of fields with the same field types and unique non-empty names. + * + * Payload is a VarUInt byte count for the encoded fields followed by the fields. + */ + UniformObject = 0x03, + + /** + * Array is an array of fields with no name that may be of different types. + * + * Payload is a VarUInt byte count, followed by a VarUInt item count, followed by the fields. + */ + Array = 0x04, + /** + * UniformArray is an array of fields with no name and with the same field type. + * + * Payload is a VarUInt byte count, followed by a VarUInt item count, followed by field type, + * followed by the fields without their field type. + */ + UniformArray = 0x05, + + /** Binary. Payload is a VarUInt byte count followed by the data. */ + Binary = 0x06, + + /** String in UTF-8. Payload is a VarUInt byte count then an unterminated UTF-8 string. */ + String = 0x07, + + /** + * Non-negative integer with the range of a 64-bit unsigned integer. + * + * Payload is the value encoded as a VarUInt. + */ + IntegerPositive = 0x08, + /** + * Negative integer with the range of a 64-bit signed integer. + * + * Payload is the ones' complement of the value encoded as a VarUInt. + */ + IntegerNegative = 0x09, + + /** Single precision float. Payload is one big endian IEEE 754 binary32 float. */ + Float32 = 0x0a, + /** Double precision float. Payload is one big endian IEEE 754 binary64 float. */ + Float64 = 0x0b, + + /** Boolean false value. Payload is empty. */ + BoolFalse = 0x0c, + /** Boolean true value. Payload is empty. */ + BoolTrue = 0x0d, + + /** + * ObjectAttachment is a reference to a compact binary attachment stored externally. + * + * Payload is a 160-bit hash digest of the referenced compact binary data. + */ + ObjectAttachment = 0x0e, + /** + * BinaryAttachment is a reference to a binary attachment stored externally. + * + * Payload is a 160-bit hash digest of the referenced binary data. + */ + BinaryAttachment = 0x0f, + + /** Hash. Payload is a 160-bit hash digest. */ + Hash = 0x10, + /** UUID/GUID. Payload is a 128-bit UUID as defined by RFC 4122. */ + Uuid = 0x11, + + /** + * Date and time between 0001-01-01 00:00:00.0000000 and 9999-12-31 23:59:59.9999999. + * + * Payload is a big endian int64 count of 100ns ticks since 0001-01-01 00:00:00.0000000. + */ + DateTime = 0x12, + /** + * Difference between two date/time values. + * + * Payload is a big endian int64 count of 100ns ticks in the span, and may be negative. + */ + TimeSpan = 0x13, + + /** + * Object ID + * + * Payload is a 12-byte opaque identifier + */ + ObjectId = 0x14, + + /** + * CustomById identifies the sub-type of its payload by an integer identifier. + * + * Payload is a VarUInt byte count of the sub-type identifier and the sub-type payload, followed + * by a VarUInt of the sub-type identifier then the payload of the sub-type. + */ + CustomById = 0x1e, + /** + * CustomByType identifies the sub-type of its payload by a string identifier. + * + * Payload is a VarUInt byte count of the sub-type identifier and the sub-type payload, followed + * by a VarUInt byte count of the unterminated sub-type identifier, then the sub-type identifier + * without termination, then the payload of the sub-type. + */ + CustomByName = 0x1f, + + /** Reserved for future use as a flag. Do not add types in this range. */ + Reserved = 0x20, + + /** + * A transient flag which indicates that the object or array containing this field has stored + * the field type before the payload and name. Non-uniform objects and fields will set this. + * + * Note: Since the flag must never be serialized, this bit may be repurposed in the future. + */ + HasFieldType = 0x40, + + /** A persisted flag which indicates that the field has a name stored before the payload. */ + HasFieldName = 0x80, +}; + +ENUM_CLASS_FLAGS(CbFieldType); + +/** Functions that operate on CbFieldType. */ +class CbFieldTypeOps +{ + static constexpr CbFieldType SerializedTypeMask = CbFieldType(0b1011'1111); + static constexpr CbFieldType TypeMask = CbFieldType(0b0011'1111); + static constexpr CbFieldType ObjectMask = CbFieldType(0b0011'1110); + static constexpr CbFieldType ObjectBase = CbFieldType(0b0000'0010); + static constexpr CbFieldType ArrayMask = CbFieldType(0b0011'1110); + static constexpr CbFieldType ArrayBase = CbFieldType(0b0000'0100); + static constexpr CbFieldType IntegerMask = CbFieldType(0b0011'1110); + static constexpr CbFieldType IntegerBase = CbFieldType(0b0000'1000); + static constexpr CbFieldType FloatMask = CbFieldType(0b0011'1100); + static constexpr CbFieldType FloatBase = CbFieldType(0b0000'1000); + static constexpr CbFieldType BoolMask = CbFieldType(0b0011'1110); + static constexpr CbFieldType BoolBase = CbFieldType(0b0000'1100); + static constexpr CbFieldType AttachmentMask = CbFieldType(0b0011'1110); + static constexpr CbFieldType AttachmentBase = CbFieldType(0b0000'1110); + + static void StaticAssertTypeConstants(); + +public: + /** The type with flags removed. */ + static constexpr inline CbFieldType GetType(CbFieldType Type) { return Type & TypeMask; } + /** The type with transient flags removed. */ + static constexpr inline CbFieldType GetSerializedType(CbFieldType Type) { return Type & SerializedTypeMask; } + + static constexpr inline bool HasFieldType(CbFieldType Type) { return EnumHasAnyFlags(Type, CbFieldType::HasFieldType); } + static constexpr inline bool HasFieldName(CbFieldType Type) { return EnumHasAnyFlags(Type, CbFieldType::HasFieldName); } + + static constexpr inline bool IsNone(CbFieldType Type) { return GetType(Type) == CbFieldType::None; } + static constexpr inline bool IsNull(CbFieldType Type) { return GetType(Type) == CbFieldType::Null; } + + static constexpr inline bool IsObject(CbFieldType Type) { return (Type & ObjectMask) == ObjectBase; } + static constexpr inline bool IsArray(CbFieldType Type) { return (Type & ArrayMask) == ArrayBase; } + + static constexpr inline bool IsBinary(CbFieldType Type) { return GetType(Type) == CbFieldType::Binary; } + static constexpr inline bool IsString(CbFieldType Type) { return GetType(Type) == CbFieldType::String; } + + static constexpr inline bool IsInteger(CbFieldType Type) { return (Type & IntegerMask) == IntegerBase; } + /** Whether the field is a float, or integer due to implicit conversion. */ + static constexpr inline bool IsFloat(CbFieldType Type) { return (Type & FloatMask) == FloatBase; } + static constexpr inline bool IsBool(CbFieldType Type) { return (Type & BoolMask) == BoolBase; } + + static constexpr inline bool IsObjectAttachment(CbFieldType Type) { return GetType(Type) == CbFieldType::ObjectAttachment; } + static constexpr inline bool IsBinaryAttachment(CbFieldType Type) { return GetType(Type) == CbFieldType::BinaryAttachment; } + static constexpr inline bool IsAttachment(CbFieldType Type) { return (Type & AttachmentMask) == AttachmentBase; } + + static constexpr inline bool IsHash(CbFieldType Type) + { + switch (GetType(Type)) + { + case CbFieldType::Hash: + case CbFieldType::BinaryAttachment: + case CbFieldType::ObjectAttachment: + return true; + default: + return false; + } + } + + static constexpr inline bool IsUuid(CbFieldType Type) { return GetType(Type) == CbFieldType::Uuid; } + static constexpr inline bool IsObjectId(CbFieldType Type) { return GetType(Type) == CbFieldType::ObjectId; } + + static constexpr inline bool IsCustomById(CbFieldType Type) { return GetType(Type) == CbFieldType::CustomById; } + static constexpr inline bool IsCustomByName(CbFieldType Type) { return GetType(Type) == CbFieldType::CustomByName; } + + static constexpr inline bool IsDateTime(CbFieldType Type) { return GetType(Type) == CbFieldType::DateTime; } + static constexpr inline bool IsTimeSpan(CbFieldType Type) { return GetType(Type) == CbFieldType::TimeSpan; } + + /** Whether the type is or may contain fields of any attachment type. */ + static constexpr inline bool MayContainAttachments(CbFieldType Type) + { + return int(IsObject(Type) == true) | int(IsArray(Type) == true) | int(IsAttachment(Type) == true); + } +}; + +/** Errors that can occur when accessing a field. */ +enum class CbFieldError : uint8_t +{ + /** The field is not in an error state. */ + None, + /** The value type does not match the requested type. */ + TypeError, + /** The value is out of range for the requested type. */ + RangeError, +}; + +class ICbVisitor +{ +public: + virtual void SetName(std::string_view Name) = 0; + virtual void BeginObject() = 0; + virtual void EndObject() = 0; + virtual void BeginArray() = 0; + virtual void EndArray() = 0; + virtual void VisitNull() = 0; + virtual void VisitBinary(SharedBuffer Value) = 0; + virtual void VisitString(std::string_view Value) = 0; + virtual void VisitInteger(int64_t Value) = 0; + virtual void VisitInteger(uint64_t Value) = 0; + virtual void VisitFloat(float Value) = 0; + virtual void VisitDouble(double Value) = 0; + virtual void VisitBool(bool value) = 0; + virtual void VisitCbAttachment(const IoHash& Value) = 0; + virtual void VisitBinaryAttachment(const IoHash& Value) = 0; + virtual void VisitHash(const IoHash& Value) = 0; + virtual void VisitUuid(const Guid& Value) = 0; + virtual void VisitObjectId(const Oid& Value) = 0; + virtual void VisitDateTime(DateTime Value) = 0; + virtual void VisitTimeSpan(TimeSpan Value) = 0; +}; + +/** A custom compact binary field type with an integer identifier. */ +struct CbCustomById +{ + /** An identifier for the sub-type of the field. */ + uint64_t Id = 0; + /** A view of the value. Lifetime is tied to the field that the value is associated with. */ + MemoryView Data; +}; + +/** A custom compact binary field type with a string identifier. */ +struct CbCustomByName +{ + /** An identifier for the sub-type of the field. Lifetime is tied to the field that the name is associated with. */ + std::u8string_view Name; + /** A view of the value. Lifetime is tied to the field that the value is associated with. */ + MemoryView Data; +}; + +namespace CompactBinaryPrivate { + /** Parameters for converting to an integer. */ + struct IntegerParams + { + /** Whether the output type has a sign bit. */ + uint32_t IsSigned : 1; + /** Bits of magnitude. (7 for int8) */ + uint32_t MagnitudeBits : 31; + }; + + /** Make integer params for the given integer type. */ + template<typename IntType> + static constexpr inline IntegerParams MakeIntegerParams() + { + IntegerParams Params; + Params.IsSigned = IntType(-1) < IntType(0); + Params.MagnitudeBits = 8 * sizeof(IntType) - Params.IsSigned; + return Params; + } + +} // namespace CompactBinaryPrivate + +/** + * An atom of data in the compact binary format. + * + * Accessing the value of a field is always a safe operation, even if accessed as the wrong type. + * An invalid access will return a default value for the requested type, and set an error code on + * the field that can be checked with GetLastError and HasLastError. A valid access will clear an + * error from a previous invalid access. + * + * A field is encoded in one or more bytes, depending on its type and the type of object or array + * that contains it. A field of an object or array which is non-uniform encodes its field type in + * the first byte, and includes the HasFieldName flag for a field in an object. The field name is + * encoded in a variable-length unsigned integer of its size in bytes, for named fields, followed + * by that many bytes of the UTF-8 encoding of the name with no null terminator. The remainder of + * the field is the payload and is described in the field type enum. Every field must be uniquely + * addressable when encoded, which means a zero-byte field is not permitted, and only arises in a + * uniform array of fields with no payload, where the answer is to encode as a non-uniform array. + * + * This type only provides a view into memory and does not perform any memory management itself. + * Use CbFieldRef to hold a reference to the underlying memory when necessary. + */ + +class CbFieldView +{ +public: + CbFieldView() = default; + + ZENCORE_API CbFieldView(const void* DataPointer, CbFieldType FieldType = CbFieldType::HasFieldType); + + /** Construct a field from a value, without access to the name. */ + inline explicit CbFieldView(const CbValue& Value); + + /** Returns the name of the field if it has a name, otherwise an empty view. */ + constexpr inline std::string_view GetName() const { return std::string_view(static_cast<const char*>(Payload) - NameLen, NameLen); } + /** Returns the name of the field if it has a name, otherwise an empty view. */ + constexpr inline std::u8string_view GetU8Name() const + { + return std::u8string_view(static_cast<const char8_t*>(Payload) - NameLen, NameLen); + } + + /** Returns the value for unchecked access. Prefer the typed accessors below. */ + inline CbValue GetValue() const; + + ZENCORE_API MemoryView AsBinaryView(MemoryView Default = MemoryView()); + ZENCORE_API CbObjectView AsObjectView(); + ZENCORE_API CbArrayView AsArrayView(); + ZENCORE_API std::string_view AsString(std::string_view Default = std::string_view()); + ZENCORE_API std::u8string_view AsU8String(std::u8string_view Default = std::u8string_view()); + + ZENCORE_API void IterateAttachments(std::function<void(CbFieldView)> Visitor) const; + + /** Access the field as an int8. Returns the provided default on error. */ + inline int8_t AsInt8(int8_t Default = 0) { return AsInteger<int8_t>(Default); } + /** Access the field as an int16. Returns the provided default on error. */ + inline int16_t AsInt16(int16_t Default = 0) { return AsInteger<int16_t>(Default); } + /** Access the field as an int32. Returns the provided default on error. */ + inline int32_t AsInt32(int32_t Default = 0) { return AsInteger<int32_t>(Default); } + /** Access the field as an int64. Returns the provided default on error. */ + inline int64_t AsInt64(int64_t Default = 0) { return AsInteger<int64_t>(Default); } + /** Access the field as a uint8. Returns the provided default on error. */ + inline uint8_t AsUInt8(uint8_t Default = 0) { return AsInteger<uint8_t>(Default); } + /** Access the field as a uint16. Returns the provided default on error. */ + inline uint16_t AsUInt16(uint16_t Default = 0) { return AsInteger<uint16_t>(Default); } + /** Access the field as a uint32. Returns the provided default on error. */ + inline uint32_t AsUInt32(uint32_t Default = 0) { return AsInteger<uint32_t>(Default); } + /** Access the field as a uint64. Returns the provided default on error. */ + inline uint64_t AsUInt64(uint64_t Default = 0) { return AsInteger<uint64_t>(Default); } + + /** Access the field as a float. Returns the provided default on error. */ + ZENCORE_API float AsFloat(float Default = 0.0f); + /** Access the field as a double. Returns the provided default on error. */ + ZENCORE_API double AsDouble(double Default = 0.0); + + /** Access the field as a bool. Returns the provided default on error. */ + ZENCORE_API bool AsBool(bool bDefault = false); + + /** Access the field as a hash referencing a compact binary attachment. Returns the provided default on error. */ + ZENCORE_API IoHash AsObjectAttachment(const IoHash& Default = IoHash()); + /** Access the field as a hash referencing a binary attachment. Returns the provided default on error. */ + ZENCORE_API IoHash AsBinaryAttachment(const IoHash& Default = IoHash()); + /** Access the field as a hash referencing an attachment. Returns the provided default on error. */ + ZENCORE_API IoHash AsAttachment(const IoHash& Default = IoHash()); + + /** Access the field as a hash. Returns the provided default on error. */ + ZENCORE_API IoHash AsHash(const IoHash& Default = IoHash()); + + /** Access the field as a UUID. Returns a nil UUID on error. */ + ZENCORE_API Guid AsUuid(); + /** Access the field as a UUID. Returns the provided default on error. */ + ZENCORE_API Guid AsUuid(const Guid& Default); + + /** Access the field as an OID. Returns a nil OID on error. */ + ZENCORE_API Oid AsObjectId(); + /** Access the field as a OID. Returns the provided default on error. */ + ZENCORE_API Oid AsObjectId(const Oid& Default); + + /** Access the field as a custom sub-type with an integer identifier. Returns the provided default on error. */ + ZENCORE_API CbCustomById AsCustomById(CbCustomById Default); + /** Access the field as a custom sub-type with a string identifier. Returns the provided default on error. */ + ZENCORE_API CbCustomByName AsCustomByName(CbCustomByName Default); + + /** Access the field as a date/time tick count. Returns the provided default on error. */ + ZENCORE_API int64_t AsDateTimeTicks(int64_t Default = 0); + + /** Access the field as a date/time. Returns a date/time at the epoch on error. */ + ZENCORE_API DateTime AsDateTime(); + /** Access the field as a date/time. Returns the provided default on error. */ + ZENCORE_API DateTime AsDateTime(DateTime Default); + + /** Access the field as a timespan tick count. Returns the provided default on error. */ + ZENCORE_API int64_t AsTimeSpanTicks(int64_t Default = 0); + + /** Access the field as a timespan. Returns an empty timespan on error. */ + ZENCORE_API TimeSpan AsTimeSpan(); + /** Access the field as a timespan. Returns the provided default on error. */ + ZENCORE_API TimeSpan AsTimeSpan(TimeSpan Default); + + /** True if the field has a name. */ + constexpr inline bool HasName() const { return CbFieldTypeOps::HasFieldName(Type); } + + constexpr inline bool IsNull() const { return CbFieldTypeOps::IsNull(Type); } + + constexpr inline bool IsObject() const { return CbFieldTypeOps::IsObject(Type); } + constexpr inline bool IsArray() const { return CbFieldTypeOps::IsArray(Type); } + + constexpr inline bool IsBinary() const { return CbFieldTypeOps::IsBinary(Type); } + constexpr inline bool IsString() const { return CbFieldTypeOps::IsString(Type); } + + /** Whether the field is an integer of unspecified range and sign. */ + constexpr inline bool IsInteger() const { return CbFieldTypeOps::IsInteger(Type); } + /** Whether the field is a float, or integer that supports implicit conversion. */ + constexpr inline bool IsFloat() const { return CbFieldTypeOps::IsFloat(Type); } + constexpr inline bool IsBool() const { return CbFieldTypeOps::IsBool(Type); } + + constexpr inline bool IsObjectAttachment() const { return CbFieldTypeOps::IsObjectAttachment(Type); } + constexpr inline bool IsBinaryAttachment() const { return CbFieldTypeOps::IsBinaryAttachment(Type); } + constexpr inline bool IsAttachment() const { return CbFieldTypeOps::IsAttachment(Type); } + + constexpr inline bool IsHash() const { return CbFieldTypeOps::IsHash(Type); } + constexpr inline bool IsUuid() const { return CbFieldTypeOps::IsUuid(Type); } + constexpr inline bool IsObjectId() const { return CbFieldTypeOps::IsObjectId(Type); } + + constexpr inline bool IsDateTime() const { return CbFieldTypeOps::IsDateTime(Type); } + constexpr inline bool IsTimeSpan() const { return CbFieldTypeOps::IsTimeSpan(Type); } + + /** Whether the field has a value. */ + constexpr inline explicit operator bool() const { return HasValue(); } + + /** + * Whether the field has a value. + * + * All fields in a valid object or array have a value. A field with no value is returned when + * finding a field by name fails or when accessing an iterator past the end. + */ + constexpr inline bool HasValue() const { return !CbFieldTypeOps::IsNone(Type); }; + + /** Whether the last field access encountered an error. */ + constexpr inline bool HasError() const { return Error != CbFieldError::None; } + + /** The type of error that occurred on the last field access, or None. */ + constexpr inline CbFieldError GetError() const { return Error; } + + /** Returns the size of the field in bytes, including the type and name. */ + ZENCORE_API uint64_t GetSize() const; + + /** Calculate the hash of the field, including the type and name. */ + ZENCORE_API IoHash GetHash() const; + + ZENCORE_API void GetHash(IoHashStream& HashStream) const; + + /** Feed the field (including type and name) to the stream function */ + inline void WriteToStream(auto Hash) const + { + const CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(Type); + Hash(&SerializedType, sizeof(SerializedType)); + auto View = GetViewNoType(); + Hash(View.GetData(), View.GetSize()); + } + + /** Copy the field into a buffer of exactly GetSize() bytes, including the type and name. */ + ZENCORE_API void CopyTo(MutableMemoryView Buffer) const; + + /** Copy the field into an archive, including its type and name. */ + ZENCORE_API void CopyTo(BinaryWriter& Ar) const; + + /** + * Whether this field is identical to the other field. + * + * Performs a deep comparison of any contained arrays or objects and their fields. Comparison + * assumes that both fields are valid and are written in the canonical format. Fields must be + * written in the same order in arrays and objects, and name comparison is case sensitive. If + * these assumptions do not hold, this may return false for equivalent inputs. Validation can + * be performed with ValidateCompactBinary, except for field order and field name case. + */ + ZENCORE_API bool Equals(const CbFieldView& Other) const; + + /** Returns a view of the field, including the type and name when present. */ + ZENCORE_API MemoryView GetView() const; + + /** + * Try to get a view of the field as it would be serialized, such as by CopyTo. + * + * A serialized view is not available if the field has an externally-provided type. + * Access the serialized form of such fields using CopyTo or FCbFieldRef::Clone. + */ + inline bool TryGetSerializedView(MemoryView& OutView) const + { + if (CbFieldTypeOps::HasFieldType(Type)) + { + OutView = GetView(); + return true; + } + return false; + } + +protected: + /** Returns a view of the name and value payload, which excludes the type. */ + ZENCORE_API MemoryView GetViewNoType() const; + + /** Returns a view of the value payload, which excludes the type and name. */ + inline MemoryView GetPayloadView() const { return MemoryView(Payload, GetPayloadSize()); } + + /** Returns the type of the field including flags. */ + constexpr inline CbFieldType GetType() const { return Type; } + + /** Returns the start of the value payload. */ + constexpr inline const void* GetPayload() const { return Payload; } + + /** Returns the end of the value payload. */ + inline const void* GetPayloadEnd() const { return static_cast<const uint8_t*>(Payload) + GetPayloadSize(); } + + /** Returns the size of the value payload in bytes, which is the field excluding the type and name. */ + ZENCORE_API uint64_t GetPayloadSize() const; + + /** Assign a field from a pointer to its data and an optional externally-provided type. */ + inline void Assign(const void* InData, const CbFieldType InType) + { + static_assert(std::is_trivially_destructible<CbFieldView>::value, + "This optimization requires CbField to be trivially destructible!"); + new (this) CbFieldView(InData, InType); + } + +private: + /** + * Access the field as the given integer type. + * + * Returns the provided default if the value cannot be represented in the output type. + */ + template<typename IntType> + inline IntType AsInteger(IntType Default) + { + return IntType(AsInteger(uint64_t(Default), CompactBinaryPrivate::MakeIntegerParams<IntType>())); + } + + ZENCORE_API uint64_t AsInteger(uint64_t Default, CompactBinaryPrivate::IntegerParams Params); + +private: + /** The field type, with the transient HasFieldType flag if the field contains its type. */ + CbFieldType Type = CbFieldType::None; + /** The error (if any) that occurred on the last field access. */ + CbFieldError Error = CbFieldError::None; + /** The number of bytes for the name stored before the payload. */ + uint32_t NameLen = 0; + /** The value payload, which also points to the end of the name. */ + const void* Payload = nullptr; +}; + +template<typename FieldType> +class TCbFieldIterator : public FieldType +{ +public: + /** Construct an empty field range. */ + constexpr TCbFieldIterator() = default; + + inline TCbFieldIterator& operator++() + { + const void* const PayloadEnd = FieldType::GetPayloadEnd(); + const int64_t AtEndMask = int64_t(PayloadEnd == FieldsEnd) - 1; + const CbFieldType NextType = CbFieldType(int64_t(FieldType::GetType()) & AtEndMask); + const void* const NextField = reinterpret_cast<const void*>(int64_t(PayloadEnd) & AtEndMask); + const void* const NextFieldsEnd = reinterpret_cast<const void*>(int64_t(FieldsEnd) & AtEndMask); + + FieldType::Assign(NextField, NextType); + FieldsEnd = NextFieldsEnd; + return *this; + } + + inline TCbFieldIterator operator++(int) + { + TCbFieldIterator It(*this); + ++*this; + return It; + } + + constexpr inline FieldType& operator*() { return *this; } + constexpr inline FieldType* operator->() { return this; } + + /** Reset this to an empty field range. */ + inline void Reset() { *this = TCbFieldIterator(); } + + /** Returns the size of the fields in the range in bytes. */ + ZENCORE_API uint64_t GetRangeSize() const; + + /** Calculate the hash of every field in the range. */ + ZENCORE_API IoHash GetRangeHash() const; + ZENCORE_API void GetRangeHash(IoHashStream& Hash) const; + + using FieldType::Equals; + + template<typename OtherFieldType> + constexpr inline bool Equals(const TCbFieldIterator<OtherFieldType>& Other) const + { + return FieldType::GetPayload() == Other.OtherFieldType::GetPayload() && FieldsEnd == Other.FieldsEnd; + } + + template<typename OtherFieldType> + constexpr inline bool operator==(const TCbFieldIterator<OtherFieldType>& Other) const + { + return Equals(Other); + } + + template<typename OtherFieldType> + constexpr inline bool operator!=(const TCbFieldIterator<OtherFieldType>& Other) const + { + return !Equals(Other); + } + + /** Copy the field range into a buffer of exactly GetRangeSize() bytes. */ + ZENCORE_API void CopyRangeTo(MutableMemoryView Buffer) const; + + /** Invoke the visitor for every attachment in the field range. */ + ZENCORE_API void IterateRangeAttachments(std::function<void(CbFieldView)> Visitor) const; + + /** Create a view of every field in the range. */ + inline MemoryView GetRangeView() const { return MemoryView(FieldType::GetView().GetData(), FieldsEnd); } + + /** + * Try to get a view of every field in the range as they would be serialized. + * + * A serialized view is not available if the underlying fields have an externally-provided type. + * Access the serialized form of such ranges using CbFieldRefIterator::CloneRange. + */ + inline bool TryGetSerializedRangeView(MemoryView& OutView) const + { + if (CbFieldTypeOps::HasFieldType(FieldType::GetType())) + { + OutView = GetRangeView(); + return true; + } + return false; + } + +protected: + /** Construct a field range that contains exactly one field. */ + constexpr inline explicit TCbFieldIterator(FieldType InField) : FieldType(std::move(InField)), FieldsEnd(FieldType::GetPayloadEnd()) {} + + /** + * Construct a field range from the first field and a pointer to the end of the last field. + * + * @param InField The first field, or the default field if there are no fields. + * @param InFieldsEnd A pointer to the end of the payload of the last field, or null. + */ + constexpr inline TCbFieldIterator(FieldType&& InField, const void* InFieldsEnd) : FieldType(std::move(InField)), FieldsEnd(InFieldsEnd) + { + } + + /** Returns the end of the last field, or null for an iterator at the end. */ + template<typename OtherFieldType> + static inline const void* GetFieldsEnd(const TCbFieldIterator<OtherFieldType>& It) + { + return It.FieldsEnd; + } + +private: + friend inline TCbFieldIterator begin(const TCbFieldIterator& Iterator) { return Iterator; } + friend inline TCbFieldIterator end(const TCbFieldIterator&) { return TCbFieldIterator(); } + +private: + template<typename OtherType> + friend class TCbFieldIterator; + + friend class CbFieldViewIterator; + + friend class CbFieldIterator; + + /** Pointer to the first byte past the end of the last field. Set to null at the end. */ + const void* FieldsEnd = nullptr; +}; + +/** + * Iterator for CbField. + * + * @see CbFieldIterator + */ +class CbFieldViewIterator : public TCbFieldIterator<CbFieldView> +{ +public: + constexpr CbFieldViewIterator() = default; + + /** Construct a field range that contains exactly one field. */ + static inline CbFieldViewIterator MakeSingle(const CbFieldView& Field) { return CbFieldViewIterator(Field); } + + /** + * Construct a field range from a buffer containing zero or more valid fields. + * + * @param View A buffer containing zero or more valid fields. + * @param Type HasFieldType means that View contains the type. Otherwise, use the given type. + */ + static inline CbFieldViewIterator MakeRange(MemoryView View, CbFieldType Type = CbFieldType::HasFieldType) + { + return !View.IsEmpty() ? TCbFieldIterator(CbFieldView(View.GetData(), Type), View.GetDataEnd()) : CbFieldViewIterator(); + } + + /** Construct an iterator from another iterator. */ + template<typename OtherFieldType> + inline CbFieldViewIterator(const TCbFieldIterator<OtherFieldType>& It) + : TCbFieldIterator(ImplicitConv<CbFieldView>(It), GetFieldsEnd(It)) + { + } + +private: + using TCbFieldIterator::TCbFieldIterator; +}; + +/** + * Serialize a compact binary array to JSON. + */ +ZENCORE_API void CompactBinaryToJson(const CbArrayView& Object, StringBuilderBase& Builder); + +/** + * Array of CbField that have no names. + * + * Accessing a field of the array requires iteration. Access by index is not provided because the + * cost of accessing an item by index scales linearly with the index. + * + * This type only provides a view into memory and does not perform any memory management itself. + * Use CbArrayRef to hold a reference to the underlying memory when necessary. + */ +class CbArrayView : protected CbFieldView +{ + friend class CbFieldView; + +public: + /** @see CbField::CbField */ + using CbFieldView::CbFieldView; + + /** Construct an array with no fields. */ + ZENCORE_API CbArrayView(); + + /** Returns the number of items in the array. */ + ZENCORE_API uint64_t Num() const; + + /** Create an iterator for the fields of this array. */ + ZENCORE_API CbFieldViewIterator CreateViewIterator() const; + + /** Visit the fields of this array. */ + ZENCORE_API void VisitFields(ICbVisitor& Visitor); + + /** Access the array as an array field. */ + inline CbFieldView AsFieldView() const { return static_cast<const CbFieldView&>(*this); } + + /** Construct an array from an array field. No type check is performed! */ + static inline CbArrayView FromFieldView(const CbFieldView& Field) { return CbArrayView(Field); } + + /** Whether the array has any fields. */ + inline explicit operator bool() const { return Num() > 0; } + + /** Returns the size of the array in bytes if serialized by itself with no name. */ + ZENCORE_API uint64_t GetSize() const; + + /** Calculate the hash of the array if serialized by itself with no name. */ + ZENCORE_API IoHash GetHash() const; + + ZENCORE_API void GetHash(IoHashStream& Stream) const; + + /** + * Whether this array is identical to the other array. + * + * Performs a deep comparison of any contained arrays or objects and their fields. Comparison + * assumes that both fields are valid and are written in the canonical format. Fields must be + * written in the same order in arrays and objects, and name comparison is case sensitive. If + * these assumptions do not hold, this may return false for equivalent inputs. Validation can + * be done with the All mode to check these assumptions about the format of the inputs. + */ + ZENCORE_API bool Equals(const CbArrayView& Other) const; + + /** Copy the array into a buffer of exactly GetSize() bytes, with no name. */ + ZENCORE_API void CopyTo(MutableMemoryView Buffer) const; + + /** Copy the array into an archive, including its type and name. */ + ZENCORE_API void CopyTo(BinaryWriter& Ar) const; + + ///** Invoke the visitor for every attachment in the array. */ + inline void IterateAttachments(std::function<void(CbFieldView)> Visitor) const + { + CreateViewIterator().IterateRangeAttachments(Visitor); + } + + /** Returns a view of the array, including the type and name when present. */ + using CbFieldView::GetView; + + StringBuilderBase& ToJson(StringBuilderBase& Builder) const + { + CompactBinaryToJson(*this, Builder); + return Builder; + } + +private: + friend inline CbFieldViewIterator begin(const CbArrayView& Array) { return Array.CreateViewIterator(); } + friend inline CbFieldViewIterator end(const CbArrayView&) { return CbFieldViewIterator(); } + + /** Construct an array from an array field. No type check is performed! Use via FromField. */ + inline explicit CbArrayView(const CbFieldView& Field) : CbFieldView(Field) {} +}; + +/** + * Serialize a compact binary object to JSON. + */ +ZENCORE_API void CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder); + +class CbObjectView : protected CbFieldView +{ + friend class CbFieldView; + +public: + /** @see CbField::CbField */ + using CbFieldView::CbFieldView; + + using CbFieldView::TryGetSerializedView; + + /** Construct an object with no fields. */ + ZENCORE_API CbObjectView(); + + /** Create an iterator for the fields of this object. */ + ZENCORE_API CbFieldViewIterator CreateViewIterator() const; + + /** Visit the fields of this object. */ + ZENCORE_API void VisitFields(ICbVisitor& Visitor); + + /** + * Find a field by case-sensitive name comparison. + * + * The cost of this operation scales linearly with the number of fields in the object. Prefer + * to iterate over the fields only once when consuming an object. + * + * @param Name The name of the field. + * @return The matching field if found, otherwise a field with no value. + */ + ZENCORE_API CbFieldView FindView(std::string_view Name) const; + + /** Find a field by case-insensitive name comparison. */ + ZENCORE_API CbFieldView FindViewIgnoreCase(std::string_view Name) const; + + /** Find a field by case-sensitive name comparison. */ + inline CbFieldView operator[](std::string_view Name) const { return FindView(Name); } + + /** Access the object as an object field. */ + inline CbFieldView AsFieldView() const { return static_cast<const CbFieldView&>(*this); } + + /** Construct an object from an object field. No type check is performed! */ + static inline CbObjectView FromFieldView(const CbFieldView& Field) { return CbObjectView(Field); } + + /** Whether the object has any fields. */ + ZENCORE_API explicit operator bool() const; + + /** Returns the size of the object in bytes if serialized by itself with no name. */ + ZENCORE_API uint64_t GetSize() const; + + /** Calculate the hash of the object if serialized by itself with no name. */ + ZENCORE_API IoHash GetHash() const; + + ZENCORE_API void GetHash(IoHashStream& HashStream) const; + + /** + * Whether this object is identical to the other object. + * + * Performs a deep comparison of any contained arrays or objects and their fields. Comparison + * assumes that both fields are valid and are written in the canonical format. Fields must be + * written in the same order in arrays and objects, and name comparison is case sensitive. If + * these assumptions do not hold, this may return false for equivalent inputs. Validation can + * be done with the All mode to check these assumptions about the format of the inputs. + */ + ZENCORE_API bool Equals(const CbObjectView& Other) const; + + /** Copy the object into a buffer of exactly GetSize() bytes, with no name. */ + ZENCORE_API void CopyTo(MutableMemoryView Buffer) const; + + /** Copy the field into an archive, including its type and name. */ + ZENCORE_API void CopyTo(BinaryWriter& Ar) const; + + ///** Invoke the visitor for every attachment in the object. */ + inline void IterateAttachments(std::function<void(CbFieldView)> Visitor) const + { + CreateViewIterator().IterateRangeAttachments(Visitor); + } + + /** Returns a view of the object, including the type and name when present. */ + using CbFieldView::GetView; + + /** Whether the field has a value. */ + using CbFieldView::operator bool; + + StringBuilderBase& ToJson(StringBuilderBase& Builder) const + { + CompactBinaryToJson(*this, Builder); + return Builder; + } + +private: + friend inline CbFieldViewIterator begin(const CbObjectView& Object) { return Object.CreateViewIterator(); } + friend inline CbFieldViewIterator end(const CbObjectView&) { return CbFieldViewIterator(); } + + /** Construct an object from an object field. No type check is performed! Use via FromField. */ + inline explicit CbObjectView(const CbFieldView& Field) : CbFieldView(Field) {} +}; + +////////////////////////////////////////////////////////////////////////// + +/** A reference to a function that is used to allocate buffers for compact binary data. */ +using BufferAllocator = std::function<UniqueBuffer(uint64_t Size)>; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** A wrapper that holds a reference to the buffer that contains its compact binary value. */ +template<typename BaseType> +class CbBuffer : public BaseType +{ +public: + /** Construct a default value. */ + CbBuffer() = default; + + /** + * Construct a value from a pointer to its data and an optional externally-provided type. + * + * @param ValueBuffer A buffer that exactly contains the value. + * @param Type HasFieldType means that ValueBuffer contains the type. Otherwise, use the given type. + */ + inline explicit CbBuffer(SharedBuffer ValueBuffer, CbFieldType Type = CbFieldType::HasFieldType) + { + if (ValueBuffer) + { + BaseType::operator=(BaseType(ValueBuffer.GetData(), Type)); + ZEN_ASSERT(ValueBuffer.GetView().Contains(BaseType::GetView())); + Buffer = std::move(ValueBuffer); + } + } + + /** Construct a value that holds a reference to the buffer that contains it. */ + inline CbBuffer(const BaseType& Value, SharedBuffer OuterBuffer) : BaseType(Value) + { + if (OuterBuffer) + { + ZEN_ASSERT(OuterBuffer.GetView().Contains(BaseType::GetView())); + Buffer = std::move(OuterBuffer); + } + } + + /** Construct a value that holds a reference to the buffer of the outer that contains it. */ + template<typename OtherBaseType> + inline CbBuffer(const BaseType& Value, CbBuffer<OtherBaseType> OuterRef) : CbBuffer(Value, std::move(OuterRef.Buffer)) + { + } + + /** Reset this to a default value and null buffer. */ + inline void Reset() { *this = CbBuffer(); } + + /** Whether this reference has ownership of the memory in its buffer. */ + inline bool IsOwned() const { return Buffer && Buffer.IsOwned(); } + + /** Clone the value, if necessary, to a buffer that this reference has ownership of. */ + inline void MakeOwned() + { + if (!IsOwned()) + { + UniqueBuffer MutableBuffer = UniqueBuffer::Alloc(BaseType::GetSize()); + BaseType::CopyTo(MutableBuffer); + BaseType::operator=(BaseType(MutableBuffer.GetData())); + Buffer = std::move(MutableBuffer); + } + } + + /** Returns a buffer that exactly contains this value. */ + inline SharedBuffer GetBuffer() const + { + const MemoryView View = BaseType::GetView(); + const SharedBuffer& OuterBuffer = GetOuterBuffer(); + return View == OuterBuffer.GetView() ? OuterBuffer : SharedBuffer::MakeView(View, OuterBuffer); + } + + /** Returns the outer buffer (if any) that contains this value. */ + inline const SharedBuffer& GetOuterBuffer() const& { return Buffer; } + inline SharedBuffer GetOuterBuffer() && { return std::move(Buffer); } + +private: + template<typename OtherType> + friend class CbBuffer; + + SharedBuffer Buffer; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * Factory functions for types derived from CbBuffer. + * + * This uses the curiously recurring template pattern to construct the correct type of reference. + * The derived type inherits from CbBufferRef and this type to expose the factory functions. + */ +template<typename RefType, typename BaseType> +class CbBufferFactory +{ +public: + /** Construct a value from an owned clone of its memory. */ + static inline RefType Clone(const void* const Data) { return Clone(BaseType(Data)); } + + /** Construct a value from an owned clone of its memory. */ + static inline RefType Clone(const BaseType& Value) + { + RefType Ref = MakeView(Value); + Ref.MakeOwned(); + return Ref; + } + + /** Construct a value from a read-only view of its memory and its optional outer buffer. */ + static inline RefType MakeView(const void* const Data, SharedBuffer OuterBuffer = SharedBuffer()) + { + return MakeView(BaseType(Data), std::move(OuterBuffer)); + } + + /** Construct a value from a read-only view of its memory and its optional outer buffer. */ + static inline RefType MakeView(const BaseType& Value, SharedBuffer OuterBuffer = SharedBuffer()) + { + return RefType(Value, std::move(OuterBuffer)); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class CbArray; +class CbObject; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * A field that can hold a reference to the memory that contains it. + * + * @see CbBufferRef + */ +class CbField : public CbBuffer<CbFieldView>, public CbBufferFactory<CbField, CbFieldView> +{ +public: + using CbBuffer::CbBuffer; + + /** Access the field as an object. Defaults to an empty object on error. */ + inline CbObject AsObject() &; + inline CbObject AsObject() &&; + + /** Access the field as an array. Defaults to an empty array on error. */ + inline CbArray AsArray() &; + inline CbArray AsArray() &&; + + /** Access the field as binary. Returns the provided default on error. */ + inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &; + inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &&; +}; + +/** + * Iterator for CbFieldRef. + * + * @see CbFieldIterator + */ +class CbFieldIterator : public TCbFieldIterator<CbField> +{ +public: + /** Construct a field range from an owned clone of a range. */ + ZENCORE_API static CbFieldIterator CloneRange(const CbFieldViewIterator& It); + + /** Construct a field range from an owned clone of a range. */ + static inline CbFieldIterator CloneRange(const CbFieldIterator& It) { return CloneRange(CbFieldViewIterator(It)); } + + /** Construct a field range that contains exactly one field. */ + static inline CbFieldIterator MakeSingle(CbField Field) { return CbFieldIterator(std::move(Field)); } + + /** + * Construct a field range from a buffer containing zero or more valid fields. + * + * @param Buffer A buffer containing zero or more valid fields. + * @param Type HasFieldType means that Buffer contains the type. Otherwise, use the given type. + */ + static inline CbFieldIterator MakeRange(SharedBuffer Buffer, CbFieldType Type = CbFieldType::HasFieldType) + { + if (Buffer.GetSize()) + { + const void* const DataEnd = Buffer.GetView().GetDataEnd(); + return CbFieldIterator(CbField(std::move(Buffer), Type), DataEnd); + } + return CbFieldIterator(); + } + + /** Construct a field range from an iterator and its optional outer buffer. */ + static inline CbFieldIterator MakeRangeView(const CbFieldViewIterator& It, SharedBuffer OuterBuffer = SharedBuffer()) + { + return CbFieldIterator(CbField(It, std::move(OuterBuffer)), GetFieldsEnd(It)); + } + + /** Construct an empty field range. */ + constexpr CbFieldIterator() = default; + + /** Clone the range, if necessary, to a buffer that this reference has ownership of. */ + inline void MakeRangeOwned() + { + if (!IsOwned()) + { + *this = CloneRange(*this); + } + } + + /** Returns a buffer that exactly contains the field range. */ + ZENCORE_API SharedBuffer GetRangeBuffer() const; + +private: + using TCbFieldIterator::TCbFieldIterator; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * An array that can hold a reference to the memory that contains it. + * + * @see CbBuffer + */ +class CbArray : public CbBuffer<CbArrayView>, public CbBufferFactory<CbArray, CbArrayView> +{ +public: + using CbBuffer::CbBuffer; + + /** Create an iterator for the fields of this array. */ + inline CbFieldIterator CreateIterator() const { return CbFieldIterator::MakeRangeView(CreateViewIterator(), GetOuterBuffer()); } + + /** Access the array as an array field. */ + inline CbField AsField() const& { return CbField(CbArrayView::AsFieldView(), *this); } + + /** Access the array as an array field. */ + inline CbField AsField() && { return CbField(CbArrayView::AsFieldView(), std::move(*this)); } + +private: + friend inline CbFieldIterator begin(const CbArray& Array) { return Array.CreateIterator(); } + friend inline CbFieldIterator end(const CbArray&) { return CbFieldIterator(); } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * An object that can hold a reference to the memory that contains it. + * + * @see CbBuffer + */ +class CbObject : public CbBuffer<CbObjectView>, public CbBufferFactory<CbObject, CbObjectView> +{ +public: + using CbBuffer::CbBuffer; + + /** Create an iterator for the fields of this object. */ + inline CbFieldIterator CreateIterator() const { return CbFieldIterator::MakeRangeView(CreateViewIterator(), GetOuterBuffer()); } + + /** Find a field by case-sensitive name comparison. */ + inline CbField Find(std::string_view Name) const + { + if (CbFieldView Field = FindView(Name)) + { + return CbField(Field, *this); + } + return CbField(); + } + + /** Find a field by case-insensitive name comparison. */ + inline CbField FindIgnoreCase(std::string_view Name) const + { + if (CbFieldView Field = FindViewIgnoreCase(Name)) + { + return CbField(Field, *this); + } + return CbField(); + } + + /** Find a field by case-sensitive name comparison. */ + inline CbFieldView operator[](std::string_view Name) const { return Find(Name); } + + /** Access the object as an object field. */ + inline CbField AsField() const& { return CbField(CbObjectView::AsFieldView(), *this); } + + /** Access the object as an object field. */ + inline CbField AsField() && { return CbField(CbObjectView::AsFieldView(), std::move(*this)); } + +private: + friend inline CbFieldIterator begin(const CbObject& Object) { return Object.CreateIterator(); } + friend inline CbFieldIterator end(const CbObject&) { return CbFieldIterator(); } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +inline CbObject +CbField::AsObject() & +{ + return IsObject() ? CbObject(AsObjectView(), *this) : CbObject(); +} + +inline CbObject +CbField::AsObject() && +{ + return IsObject() ? CbObject(AsObjectView(), std::move(*this)) : CbObject(); +} + +inline CbArray +CbField::AsArray() & +{ + return IsArray() ? CbArray(AsArrayView(), *this) : CbArray(); +} + +inline CbArray +CbField::AsArray() && +{ + return IsArray() ? CbArray(AsArrayView(), std::move(*this)) : CbArray(); +} + +inline SharedBuffer +CbField::AsBinary(const SharedBuffer& Default) & +{ + const MemoryView View = AsBinaryView(); + return !HasError() ? SharedBuffer::MakeView(View, GetOuterBuffer()) : Default; +} + +inline SharedBuffer +CbField::AsBinary(const SharedBuffer& Default) && +{ + const MemoryView View = AsBinaryView(); + return !HasError() ? SharedBuffer::MakeView(View, std::move(*this).GetOuterBuffer()) : Default; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * Load a compact binary field from an archive. + * + * The field may be an array or an object, which the caller can convert to by using AsArray or + * AsObject as appropriate. The buffer allocator is called to provide the buffer for the field + * to load into once its size has been determined. + * + * @param Ar Archive to read the field from. An error state is set on failure. + * @param Allocator Allocator for the buffer that the field is loaded into. + * @return A field with a reference to the allocated buffer, or a default field on failure. + */ +ZENCORE_API CbField LoadCompactBinary(BinaryReader& Ar, BufferAllocator Allocator); + +ZENCORE_API CbObject LoadCompactBinaryObject(IoBuffer&& Payload); +ZENCORE_API CbObject LoadCompactBinaryObject(const IoBuffer& Payload); +ZENCORE_API CbObject LoadCompactBinaryObject(CompressedBuffer&& Payload); +ZENCORE_API CbObject LoadCompactBinaryObject(const CompressedBuffer& Payload); + +/** + * Load a compact binary from JSON. + */ +ZENCORE_API CbFieldIterator LoadCompactBinaryFromJson(std::string_view Json, std::string& Error); +ZENCORE_API CbFieldIterator LoadCompactBinaryFromJson(std::string_view Json); + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * Determine the size in bytes of the compact binary field at the start of the view. + * + * This may be called on an incomplete or invalid field, in which case the returned size is zero. + * A size can always be extracted from a valid field with no name if a view of at least the first + * 10 bytes is provided, regardless of field size. For fields with names, the size of view needed + * to calculate a size is at most 10 + MaxNameLen + MeasureVarUInt(MaxNameLen). + * + * This function can be used when streaming a field, for example, to determine the size of buffer + * to fill before attempting to construct a field from it. + * + * @param View A memory view that may contain the start of a field. + * @param Type HasFieldType means that View contains the type. Otherwise, use the given type. + */ +ZENCORE_API uint64_t MeasureCompactBinary(MemoryView View, CbFieldType Type = CbFieldType::HasFieldType); + +/** + * Try to determine the type and size of the compact binary field at the start of the view. + * + * This may be called on an incomplete or invalid field, in which case it will return false, with + * OutSize being 0 for invalid fields, otherwise the minimum view size necessary to make progress + * in measuring the field on the next call to this function. + * + * @note A return of true from this function does not indicate that the entire field is valid. + * + * @param InView A memory view that may contain the start of a field. + * @param OutType The type (with flags) of the field. None is written until a value is available. + * @param OutSize The total field size for a return of true, 0 for invalid fields, or the size to + * make progress in measuring the field on the next call to this function. + * @param InType HasFieldType means that InView contains the type. Otherwise, use the given type. + * @return true if the size of the field was determined, otherwise false. + */ +ZENCORE_API bool TryMeasureCompactBinary(MemoryView InView, + CbFieldType& OutType, + uint64_t& OutSize, + CbFieldType InType = CbFieldType::HasFieldType); + +inline CbFieldViewIterator +begin(CbFieldView& View) +{ + if (View.IsArray()) + { + return View.AsArrayView().CreateViewIterator(); + } + else if (View.IsObject()) + { + return View.AsObjectView().CreateViewIterator(); + } + + return CbFieldViewIterator(); +} + +inline CbFieldViewIterator +end(CbFieldView&) +{ + return CbFieldViewIterator(); +} + +void uson_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/compactbinarybuilder.h b/src/zencore/include/zencore/compactbinarybuilder.h new file mode 100644 index 000000000..4be8c2ba5 --- /dev/null +++ b/src/zencore/include/zencore/compactbinarybuilder.h @@ -0,0 +1,661 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/compactbinary.h> + +#include <zencore/enumflags.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/refcount.h> +#include <zencore/sha1.h> + +#include <atomic> +#include <memory> +#include <string> +#include <string_view> +#include <type_traits> +#include <vector> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +class CbAttachment; +class BinaryWriter; + +/** + * A writer for compact binary object, arrays, and fields. + * + * The writer produces a sequence of fields that can be saved to a provided memory buffer or into + * a new owned buffer. The typical use case is to write a single object, which can be accessed by + * calling Save().AsObjectRef() or Save(Buffer).AsObject(). + * + * The writer will assert on most incorrect usage and will always produce valid compact binary if + * provided with valid input. The writer does not check for invalid UTF-8 string encoding, object + * fields with duplicate names, or invalid compact binary being copied from another source. + * + * It is most convenient to use the streaming API for the writer, as demonstrated in the example. + * + * When writing a small amount of compact binary data, TCbWriter can be more efficient as it uses + * a fixed-size stack buffer for storage before spilling onto the heap. + * + * @see TCbWriter + * + * Example: + * + * CbObjectRef WriteObject() + * { + * CbWriter<256> Writer; + * Writer.BeginObject(); + * + * Writer << "Resize" << true; + * Writer << "MaxWidth" << 1024; + * Writer << "MaxHeight" << 1024; + * + * Writer.BeginArray(); + * Writer << "FormatA" << "FormatB" << "FormatC"; + * Writer.EndArray(); + * + * Writer.EndObject(); + * return Writer.Save().AsObjectRef(); + * } + */ +class CbWriter +{ +public: + ZENCORE_API CbWriter(); + ZENCORE_API ~CbWriter(); + + CbWriter(const CbWriter&) = delete; + CbWriter& operator=(const CbWriter&) = delete; + + /** Empty the writer without releasing any allocated memory. */ + ZENCORE_API void Reset(); + + /** + * Serialize the field(s) to an owned buffer and return it as an iterator. + * + * It is not valid to call this function in the middle of writing an object, array, or field. + * The writer remains valid for further use when this function returns. + */ + ZENCORE_API CbFieldIterator Save(); + + /** + * Serialize the field(s) to memory. + * + * It is not valid to call this function in the middle of writing an object, array, or field. + * The writer remains valid for further use when this function returns. + * + * @param Buffer A mutable memory view to write to. Must be exactly GetSaveSize() bytes. + * @return An iterator for the field(s) written to the buffer. + */ + ZENCORE_API CbFieldViewIterator Save(MutableMemoryView Buffer); + + ZENCORE_API void Save(BinaryWriter& Writer); + + /** + * The size of buffer (in bytes) required to serialize the fields that have been written. + * + * It is not valid to call this function in the middle of writing an object, array, or field. + */ + ZENCORE_API uint64_t GetSaveSize() const; + + /** + * Sets the name of the next field to be written. + * + * It is not valid to call this function when writing a field inside an array. + * Names must be valid UTF-8 and must be unique within an object. + */ + ZENCORE_API CbWriter& SetName(std::string_view Name); + + /** Copy the value (not the name) of an existing field. */ + inline void AddField(std::string_view Name, const CbFieldView& Value) + { + SetName(Name); + AddField(Value); + } + + ZENCORE_API void AddField(const CbFieldView& Value); + + /** Copy the value (not the name) of an existing field. Holds a reference if owned. */ + inline void AddField(std::string_view Name, const CbField& Value) + { + SetName(Name); + AddField(Value); + } + ZENCORE_API void AddField(const CbField& Value); + + /** Begin a new object. Must have a matching call to EndObject. */ + inline void BeginObject(std::string_view Name) + { + SetName(Name); + BeginObject(); + } + ZENCORE_API void BeginObject(); + /** End an object after its fields have been written. */ + ZENCORE_API void EndObject(); + + /** Copy the value (not the name) of an existing object. */ + inline void AddObject(std::string_view Name, const CbObjectView& Value) + { + SetName(Name); + AddObject(Value); + } + ZENCORE_API void AddObject(const CbObjectView& Value); + /** Copy the value (not the name) of an existing object. Holds a reference if owned. */ + inline void AddObject(std::string_view Name, const CbObject& Value) + { + SetName(Name); + AddObject(Value); + } + ZENCORE_API void AddObject(const CbObject& Value); + + /** Begin a new array. Must have a matching call to EndArray. */ + inline void BeginArray(std::string_view Name) + { + SetName(Name); + BeginArray(); + } + ZENCORE_API void BeginArray(); + /** End an array after its fields have been written. */ + ZENCORE_API void EndArray(); + + /** Copy the value (not the name) of an existing array. */ + inline void AddArray(std::string_view Name, const CbArrayView& Value) + { + SetName(Name); + AddArray(Value); + } + ZENCORE_API void AddArray(const CbArrayView& Value); + /** Copy the value (not the name) of an existing array. Holds a reference if owned. */ + inline void AddArray(std::string_view Name, const CbArray& Value) + { + SetName(Name); + AddArray(Value); + } + ZENCORE_API void AddArray(const CbArray& Value); + + /** Write a null field. */ + inline void AddNull(std::string_view Name) + { + SetName(Name); + AddNull(); + } + ZENCORE_API void AddNull(); + + /** Write a binary field by copying Size bytes from Value. */ + inline void AddBinary(std::string_view Name, const void* Value, uint64_t Size) + { + SetName(Name); + AddBinary(Value, Size); + } + ZENCORE_API void AddBinary(const void* Value, uint64_t Size); + /** Write a binary field by copying the view. */ + inline void AddBinary(std::string_view Name, MemoryView Value) + { + SetName(Name); + AddBinary(Value); + } + inline void AddBinary(MemoryView Value) { AddBinary(Value.GetData(), Value.GetSize()); } + + /** Write a binary field by copying the buffer. Holds a reference if owned. */ + inline void AddBinary(std::string_view Name, IoBuffer Value) + { + SetName(Name); + AddBinary(std::move(Value)); + } + ZENCORE_API void AddBinary(IoBuffer Value); + ZENCORE_API void AddBinary(SharedBuffer Value); + + inline void AddBinary(std::string_view Name, const CompositeBuffer& Buffer) + { + SetName(Name); + AddBinary(Buffer); + } + ZENCORE_API void AddBinary(const CompositeBuffer& Buffer); + + /** Write a string field by copying the UTF-8 value. */ + inline void AddString(std::string_view Name, std::string_view Value) + { + SetName(Name); + AddString(Value); + } + ZENCORE_API void AddString(std::string_view Value); + /** Write a string field by converting the UTF-16 value to UTF-8. */ + inline void AddString(std::string_view Name, std::wstring_view Value) + { + SetName(Name); + AddString(Value); + } + ZENCORE_API void AddString(std::wstring_view Value); + + /** Write an integer field. */ + inline void AddInteger(std::string_view Name, int32_t Value) + { + SetName(Name); + AddInteger(Value); + } + ZENCORE_API void AddInteger(int32_t Value); + /** Write an integer field. */ + inline void AddInteger(std::string_view Name, int64_t Value) + { + SetName(Name); + AddInteger(Value); + } + ZENCORE_API void AddInteger(int64_t Value); + /** Write an integer field. */ + inline void AddInteger(std::string_view Name, uint32_t Value) + { + SetName(Name); + AddInteger(Value); + } + ZENCORE_API void AddInteger(uint32_t Value); + /** Write an integer field. */ + inline void AddInteger(std::string_view Name, uint64_t Value) + { + SetName(Name); + AddInteger(Value); + } + ZENCORE_API void AddInteger(uint64_t Value); + + /** Write a float field from a 32-bit float value. */ + inline void AddFloat(std::string_view Name, float Value) + { + SetName(Name); + AddFloat(Value); + } + ZENCORE_API void AddFloat(float Value); + + /** Write a float field from a 64-bit float value. */ + inline void AddFloat(std::string_view Name, double Value) + { + SetName(Name); + AddFloat(Value); + } + ZENCORE_API void AddFloat(double Value); + + /** Write a bool field. */ + inline void AddBool(std::string_view Name, bool bValue) + { + SetName(Name); + AddBool(bValue); + } + ZENCORE_API void AddBool(bool bValue); + + /** Write a field referencing a compact binary attachment by its hash. */ + inline void AddObjectAttachment(std::string_view Name, const IoHash& Value) + { + SetName(Name); + AddObjectAttachment(Value); + } + ZENCORE_API void AddObjectAttachment(const IoHash& Value); + + /** Write a field referencing a binary attachment by its hash. */ + inline void AddBinaryAttachment(std::string_view Name, const IoHash& Value) + { + SetName(Name); + AddBinaryAttachment(Value); + } + ZENCORE_API void AddBinaryAttachment(const IoHash& Value); + + /** Write a field referencing the attachment by its hash. */ + inline void AddAttachment(std::string_view Name, const CbAttachment& Attachment) + { + SetName(Name); + AddAttachment(Attachment); + } + ZENCORE_API void AddAttachment(const CbAttachment& Attachment); + + /** Write a hash field. */ + inline void AddHash(std::string_view Name, const IoHash& Value) + { + SetName(Name); + AddHash(Value); + } + ZENCORE_API void AddHash(const IoHash& Value); + + /** Write a UUID field. */ + inline void AddUuid(std::string_view Name, const Guid& Value) + { + SetName(Name); + AddUuid(Value); + } + ZENCORE_API void AddUuid(const Guid& Value); + + /** Write an ObjectId field. */ + inline void AddObjectId(std::string_view Name, const Oid& Value) + { + SetName(Name); + AddObjectId(Value); + } + ZENCORE_API void AddObjectId(const Oid& Value); + + /** Write a date/time field with the specified count of 100ns ticks since the epoch. */ + inline void AddDateTimeTicks(std::string_view Name, int64_t Ticks) + { + SetName(Name); + AddDateTimeTicks(Ticks); + } + ZENCORE_API void AddDateTimeTicks(int64_t Ticks); + + /** Write a date/time field. */ + inline void AddDateTime(std::string_view Name, DateTime Value) + { + SetName(Name); + AddDateTime(Value); + } + ZENCORE_API void AddDateTime(DateTime Value); + + /** Write a time span field with the specified count of 100ns ticks. */ + inline void AddTimeSpanTicks(std::string_view Name, int64_t Ticks) + { + SetName(Name); + AddTimeSpanTicks(Ticks); + } + ZENCORE_API void AddTimeSpanTicks(int64_t Ticks); + + /** Write a time span field. */ + inline void AddTimeSpan(std::string_view Name, TimeSpan Value) + { + SetName(Name); + AddTimeSpan(Value); + } + ZENCORE_API void AddTimeSpan(TimeSpan Value); + + /** Private flags that are public to work with ENUM_CLASS_FLAGS. */ + enum class StateFlags : uint8_t; + +protected: + /** Reserve the specified size up front until the format is optimized. */ + ZENCORE_API explicit CbWriter(int64_t InitialSize); + +private: + friend CbWriter& operator<<(CbWriter& Writer, std::string_view NameOrValue); + + /** Begin writing a field. May be called twice for named fields. */ + void BeginField(); + + /** Finish writing a field by writing its type. */ + void EndField(CbFieldType Type); + + /** Set the field name if valid in this state, otherwise write add a string field. */ + ZENCORE_API void SetNameOrAddString(std::string_view NameOrValue); + + /** Returns a view of the name of the active field, if any, otherwise the empty view. */ + std::string_view GetActiveName() const; + + /** Remove field types after the first to make the sequence uniform. */ + void MakeFieldsUniform(int64_t FieldBeginOffset, int64_t FieldEndOffset); + + /** State of the object, array, or top-level field being written. */ + struct WriterState + { + StateFlags Flags{}; + /** The type of the fields in the sequence if uniform, otherwise None. */ + CbFieldType UniformType{}; + /** The offset of the start of the current field. */ + int64_t Offset{}; + /** The number of fields written in this state. */ + uint64_t Count{}; + }; + +private: + // This is a prototype-quality format for the writer. Using an array of bytes is inefficient, + // and will lead to many unnecessary copies and moves of the data to resize the array, insert + // object and array sizes, and remove field types for uniform objects and uniform arrays. The + // optimized format will be a list of power-of-two blocks and an optional first block that is + // provided externally, such as on the stack. That format will store the offsets that require + // object or array sizes to be inserted and field types to be removed, and will perform those + // operations only when saving to a buffer. + std::vector<uint8_t> Data; + std::vector<WriterState> States; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * A writer for compact binary object, arrays, and fields that uses a fixed-size stack buffer. + * + * @see CbWriter + */ +template<uint32_t InlineBufferSize> +class FixedCbWriter : public CbWriter +{ +public: + inline FixedCbWriter() : CbWriter(InlineBufferSize) {} + + FixedCbWriter(const FixedCbWriter&) = delete; + FixedCbWriter& operator=(const FixedCbWriter&) = delete; + +private: + // Reserve the inline buffer now even though we are unable to use it. This will avoid causing + // new stack overflows when this functionality is properly implemented in the future. + uint8_t Buffer[InlineBufferSize]; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +class CbObjectWriter : public CbWriter +{ +public: + CbObjectWriter() { BeginObject(); } + + ZENCORE_API CbObject Save() + { + Finalize(); + return CbWriter::Save().AsObject(); + } + + ZENCORE_API void Save(BinaryWriter& Writer) + { + Finalize(); + return CbWriter::Save(Writer); + } + + ZENCORE_API CbFieldViewIterator Save(MutableMemoryView Buffer) + { + ZEN_ASSERT(m_Finalized); + return CbWriter::Save(Buffer); + } + + uint64_t GetSaveSize() + { + ZEN_ASSERT(m_Finalized); + return CbWriter::GetSaveSize(); + } + + void Finalize() + { + if (m_Finalized == false) + { + EndObject(); + m_Finalized = true; + } + } + + CbObjectWriter(const CbWriter&) = delete; + CbObjectWriter& operator=(const CbWriter&) = delete; + +private: + bool m_Finalized = false; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** Write the field name if valid in this state, otherwise write the string value. */ +inline CbWriter& +operator<<(CbWriter& Writer, std::string_view NameOrValue) +{ + Writer.SetNameOrAddString(NameOrValue); + return Writer; +} + +/** Write the field name if valid in this state, otherwise write the string value. */ +inline CbWriter& +operator<<(CbWriter& Writer, const char* NameOrValue) +{ + return Writer << std::string_view(NameOrValue); +} + +inline CbWriter& +operator<<(CbWriter& Writer, const CbFieldView& Value) +{ + Writer.AddField(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const CbField& Value) +{ + Writer.AddField(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const CbObjectView& Value) +{ + Writer.AddObject(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const CbObject& Value) +{ + Writer.AddObject(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const CbArrayView& Value) +{ + Writer.AddArray(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const CbArray& Value) +{ + Writer.AddArray(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, std::nullptr_t) +{ + Writer.AddNull(); + return Writer; +} + +#if defined(__clang__) && defined(__APPLE__) +/* Apple Clang has different types for uint64_t and size_t so an override is + needed here. Without it, Clang can't disambiguate integer overloads */ +inline CbWriter& +operator<<(CbWriter& Writer, std::size_t Value) +{ + Writer.AddInteger(uint64_t(Value)); + return Writer; +} +#endif + +inline CbWriter& +operator<<(CbWriter& Writer, std::wstring_view Value) +{ + Writer.AddString(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const wchar_t* Value) +{ + Writer.AddString(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, int32_t Value) +{ + Writer.AddInteger(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, int64_t Value) +{ + Writer.AddInteger(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, uint32_t Value) +{ + Writer.AddInteger(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, uint64_t Value) +{ + Writer.AddInteger(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, float Value) +{ + Writer.AddFloat(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, double Value) +{ + Writer.AddFloat(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, bool Value) +{ + Writer.AddBool(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const CbAttachment& Attachment) +{ + Writer.AddAttachment(Attachment); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const IoHash& Value) +{ + Writer.AddHash(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const Guid& Value) +{ + Writer.AddUuid(Value); + return Writer; +} + +inline CbWriter& +operator<<(CbWriter& Writer, const Oid& Value) +{ + Writer.AddObjectId(Value); + return Writer; +} + +ZENCORE_API CbWriter& operator<<(CbWriter& Writer, DateTime Value); +ZENCORE_API CbWriter& operator<<(CbWriter& Writer, TimeSpan Value); + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void usonbuilder_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h new file mode 100644 index 000000000..16f723edc --- /dev/null +++ b/src/zencore/include/zencore/compactbinarypackage.h @@ -0,0 +1,341 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/compactbinary.h> +#include <zencore/compress.h> +#include <zencore/iohash.h> + +#include <functional> +#include <span> +#include <variant> + +#ifdef GetObject +# error "windows.h pollution" +# undef GetObject +#endif + +namespace zen { + +class CbWriter; +class BinaryReader; +class BinaryWriter; +class IoBuffer; +class CbAttachment; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * An attachment is either binary or compact binary and is identified by its hash. + * + * A compact binary attachment is also a valid binary attachment and may be accessed as binary. + * + * Attachments are serialized as one or two compact binary fields with no name. A Binary field is + * written first with its content. The content hash is omitted when the content size is zero, and + * is otherwise written as a BinaryReference or CompactBinaryReference depending on the type. + */ +class CbAttachment +{ +public: + /** Construct a null attachment. */ + CbAttachment() = default; + + /** Construct a compact binary attachment. Value is cloned if not owned. */ + inline explicit CbAttachment(const CbObject& InValue) : CbAttachment(InValue, nullptr) {} + + /** Construct a compact binary attachment. Value is cloned if not owned. Hash must match Value. */ + inline explicit CbAttachment(const CbObject& InValue, const IoHash& Hash) : CbAttachment(InValue, &Hash) {} + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue); + + /** Construct a raw binary attachment. Value is cloned if not owned. Hash must match Value. */ + ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue, const IoHash& Hash); + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(const CompositeBuffer& InValue); + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue); + + /** Construct a raw binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash); + + /** Construct a compressed binary attachment. Value is cloned if not owned. */ + ZENCORE_API explicit CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash); + ZENCORE_API explicit CbAttachment(CompressedBuffer&& InValue, const IoHash& Hash); + + /** Reset this to a null attachment. */ + inline void Reset() { *this = CbAttachment(); } + + /** Whether the attachment has a value. */ + inline explicit operator bool() const { return !IsNull(); } + + /** Whether the attachment has a value. */ + ZENCORE_API [[nodiscard]] bool IsNull() const; + + /** Access the attachment as binary. Defaults to a null buffer on error. */ + ZENCORE_API [[nodiscard]] SharedBuffer AsBinary() const; + + /** Access the attachment as raw binary. Defaults to a null buffer on error. */ + ZENCORE_API [[nodiscard]] CompositeBuffer AsCompositeBinary() const; + + /** Access the attachment as compressed binary. Defaults to a null buffer if the attachment is null. */ + ZENCORE_API [[nodiscard]] CompressedBuffer AsCompressedBinary() const; + + /** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */ + ZENCORE_API [[nodiscard]] CbObject AsObject() const; + + /** Returns true if the attachment is binary */ + ZENCORE_API [[nodiscard]] bool IsBinary() const; + + /** Returns true if the attachment is compressed binary */ + ZENCORE_API [[nodiscard]] bool IsCompressedBinary() const; + + /** Returns whether the attachment is an object. */ + ZENCORE_API [[nodiscard]] bool IsObject() const; + + /** Returns the hash of the attachment value. */ + ZENCORE_API [[nodiscard]] IoHash GetHash() const; + + /** Compares attachments by their hash. Any discrepancy in type must be handled externally. */ + inline bool operator==(const CbAttachment& Attachment) const { return GetHash() == Attachment.GetHash(); } + inline bool operator!=(const CbAttachment& Attachment) const { return GetHash() != Attachment.GetHash(); } + inline bool operator<(const CbAttachment& Attachment) const { return GetHash() < Attachment.GetHash(); } + + /** + * Load the attachment from compact binary as written by Save. + * + * The attachment references the input iterator if it is owned, and otherwise clones the value. + * + * The iterator is advanced as attachment fields are consumed from it. + */ + ZENCORE_API bool TryLoad(CbFieldIterator& Fields); + + /** + * Load the attachment from compact binary as written by Save. + */ + ZENCORE_API bool TryLoad(BinaryReader& Reader, BufferAllocator Allocator = UniqueBuffer::Alloc); + + /** + * Load the attachment from compact binary as written by Save. + */ + ZENCORE_API bool TryLoad(IoBuffer& Buffer, BufferAllocator Allocator = UniqueBuffer::Alloc); + + /** Save the attachment into the writer as a stream of compact binary fields. */ + ZENCORE_API void Save(CbWriter& Writer) const; + + /** Save the attachment into the writer as a stream of compact binary fields. */ + ZENCORE_API void Save(BinaryWriter& Writer) const; + +private: + ZENCORE_API CbAttachment(const CbObject& Value, const IoHash* Hash); + + IoHash Hash; + std::variant<std::nullptr_t, CbObject, CompositeBuffer, CompressedBuffer> Value; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * A package is a compact binary object with attachments for its external references. + * + * A package is basically a Merkle tree with compact binary as its root and other non-leaf nodes, + * and either binary or compact binary as its leaf nodes. A node references its child nodes using + * BinaryHash or FieldHash fields in its compact binary representation. + * + * It is invalid for a package to include attachments that are not referenced by its object or by + * one of its referenced compact binary attachments. When attachments are added explicitly, it is + * the responsibility of the package creator to follow this requirement. Attachments that are not + * referenced may not survive a round-trip through certain storage systems. + * + * It is valid for a package to exclude referenced attachments, but then it is the responsibility + * of the package consumer to have a mechanism for resolving those references when necessary. + * + * A package is serialized as a sequence of compact binary fields with no name. The object may be + * both preceded and followed by attachments. The object itself is written as an Object field and + * followed by its hash in a CompactBinaryReference field when the object is non-empty. A package + * ends with a Null field. The canonical order of components is the object and its hash, followed + * by the attachments ordered by hash, followed by a Null field. It is valid for the a package to + * have its components serialized in any order, provided there is at most one object and the null + * field is written last. + */ +class CbPackage +{ +public: + /** + * A function that resolves a hash to a buffer containing the data matching that hash. + * + * The resolver may return a null buffer to skip resolving an attachment for the hash. + */ + using AttachmentResolver = std::function<SharedBuffer(const IoHash& Hash)>; + + /** Construct a null package. */ + CbPackage() = default; + + /** + * Construct a package from a root object without gathering attachments. + * + * @param InObject The root object, which will be cloned unless it is owned. + */ + inline explicit CbPackage(CbObject InObject) { SetObject(std::move(InObject)); } + + /** + * Construct a package from a root object and gather attachments using the resolver. + * + * @param InObject The root object, which will be cloned unless it is owned. + * @param InResolver A function that is invoked for every reference and binary reference field. + */ + inline explicit CbPackage(CbObject InObject, AttachmentResolver InResolver) { SetObject(std::move(InObject), InResolver); } + + /** + * Construct a package from a root object without gathering attachments. + * + * @param InObject The root object, which will be cloned unless it is owned. + * @param InObjectHash The hash of the object, which must match to avoid validation errors. + */ + inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash) { SetObject(std::move(InObject), InObjectHash); } + + /** + * Construct a package from a root object and gather attachments using the resolver. + * + * @param InObject The root object, which will be cloned unless it is owned. + * @param InObjectHash The hash of the object, which must match to avoid validation errors. + * @param InResolver A function that is invoked for every reference and binary reference field. + */ + inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash, AttachmentResolver InResolver) + { + SetObject(std::move(InObject), InObjectHash, InResolver); + } + + /** Reset this to a null package. */ + inline void Reset() { *this = CbPackage(); } + + /** Whether the package has a non-empty object or attachments. */ + inline explicit operator bool() const { return !IsNull(); } + + /** Whether the package has an empty object and no attachments. */ + inline bool IsNull() const { return !Object && Attachments.size() == 0; } + + /** Returns the compact binary object for the package. */ + inline const CbObject& GetObject() const { return Object; } + + /** Returns the has of the compact binary object for the package. */ + inline const IoHash& GetObjectHash() const { return ObjectHash; } + + /** + * Set the root object without gathering attachments. + * + * @param InObject The root object, which will be cloned unless it is owned. + */ + inline void SetObject(CbObject InObject) { SetObject(std::move(InObject), nullptr, nullptr); } + + /** + * Set the root object and gather attachments using the resolver. + * + * @param InObject The root object, which will be cloned unless it is owned. + * @param InResolver A function that is invoked for every reference and binary reference field. + */ + inline void SetObject(CbObject InObject, AttachmentResolver InResolver) { SetObject(std::move(InObject), nullptr, &InResolver); } + + /** + * Set the root object without gathering attachments. + * + * @param InObject The root object, which will be cloned unless it is owned. + * @param InObjectHash The hash of the object, which must match to avoid validation errors. + */ + inline void SetObject(CbObject InObject, const IoHash& InObjectHash) { SetObject(std::move(InObject), &InObjectHash, nullptr); } + + /** + * Set the root object and gather attachments using the resolver. + * + * @param InObject The root object, which will be cloned unless it is owned. + * @param InObjectHash The hash of the object, which must match to avoid validation errors. + * @param InResolver A function that is invoked for every reference and binary reference field. + */ + inline void SetObject(CbObject InObject, const IoHash& InObjectHash, AttachmentResolver InResolver) + { + SetObject(std::move(InObject), &InObjectHash, &InResolver); + } + + /** Returns the attachments in this package. */ + inline std::span<const CbAttachment> GetAttachments() const { return Attachments; } + + /** + * Find an attachment by its hash. + * + * @return The attachment, or null if the attachment is not found. + * @note The returned pointer is only valid until the attachments on this package are modified. + */ + ZENCORE_API const CbAttachment* FindAttachment(const IoHash& Hash) const; + + /** Find an attachment if it exists in the package. */ + inline const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); } + + /** Add the attachment to this package. */ + inline void AddAttachment(const CbAttachment& Attachment) { AddAttachment(Attachment, nullptr); } + + /** Add the attachment to this package, along with any references that can be resolved. */ + inline void AddAttachment(const CbAttachment& Attachment, AttachmentResolver Resolver) { AddAttachment(Attachment, &Resolver); } + + void AddAttachments(std::span<const CbAttachment> Attachments); + + /** + * Remove an attachment by hash. + * + * @return Number of attachments removed, which will be either 0 or 1. + */ + ZENCORE_API int32_t RemoveAttachment(const IoHash& Hash); + inline int32_t RemoveAttachment(const CbAttachment& Attachment) { return RemoveAttachment(Attachment.GetHash()); } + + /** Compares packages by their object and attachment hashes. */ + ZENCORE_API bool Equals(const CbPackage& Package) const; + inline bool operator==(const CbPackage& Package) const { return Equals(Package); } + inline bool operator!=(const CbPackage& Package) const { return !Equals(Package); } + + /** + * Load the object and attachments from compact binary as written by Save. + * + * The object and attachments reference the input iterator, if it is owned, and otherwise clones + * the object and attachments individually to make owned copies. + * + * The iterator is advanced as object and attachment fields are consumed from it. + */ + ZENCORE_API bool TryLoad(CbFieldIterator& Fields); + ZENCORE_API bool TryLoad(IoBuffer Buffer, BufferAllocator Allocator = UniqueBuffer::Alloc, AttachmentResolver* Mapper = nullptr); + ZENCORE_API bool TryLoad(BinaryReader& Reader, BufferAllocator Allocator = UniqueBuffer::Alloc, AttachmentResolver* Mapper = nullptr); + + /** Save the object and attachments into the writer as a stream of compact binary fields. */ + ZENCORE_API void Save(CbWriter& Writer) const; + + /** Save the object and attachments into the writer as a stream of compact binary fields. */ + ZENCORE_API void Save(BinaryWriter& Writer) const; + +private: + ZENCORE_API void SetObject(CbObject Object, const IoHash* Hash, AttachmentResolver* Resolver); + ZENCORE_API void AddAttachment(const CbAttachment& Attachment, AttachmentResolver* Resolver); + + void GatherAttachments(const CbObject& Object, AttachmentResolver Resolver); + + /** Attachments ordered by their hash. */ + std::vector<CbAttachment> Attachments; + CbObject Object; + IoHash ObjectHash; +}; + +namespace legacy { + void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer); + void SaveCbPackage(const CbPackage& Package, CbWriter& Writer); + void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar); + bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr); + bool TryLoadCbPackage(CbPackage& Package, + BinaryReader& Reader, + BufferAllocator Allocator, + CbPackage::AttachmentResolver* Mapper = nullptr); +} // namespace legacy + +void usonpackage_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/compactbinaryvalidation.h b/src/zencore/include/zencore/compactbinaryvalidation.h new file mode 100644 index 000000000..b1fab9572 --- /dev/null +++ b/src/zencore/include/zencore/compactbinaryvalidation.h @@ -0,0 +1,197 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/compactbinary.h> +#include <zencore/enumflags.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/refcount.h> +#include <zencore/sha1.h> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +/** Flags for validating compact binary data. */ +enum class CbValidateMode : uint32_t +{ + /** Skip validation if no other validation modes are enabled. */ + None = 0, + + /** + * Validate that the value can be read and stays inside the bounds of the memory view. + * + * This is the minimum level of validation required to be able to safely read a field, array, + * or object without the risk of crashing or reading out of bounds. + */ + Default = 1 << 0, + + /** + * Validate that object fields have unique non-empty names and array fields have no names. + * + * Name validation failures typically do not inhibit reading the input, but duplicated fields + * cannot be looked up by name other than the first, and converting to other data formats can + * fail in the presence of naming issues. + */ + Names = 1 << 1, + + /** + * Validate that fields are serialized in the canonical format. + * + * Format validation failures typically do not inhibit reading the input. Values that fail in + * this mode require more memory than in the canonical format, and comparisons of such values + * for equality are not reliable. Examples of failures include uniform arrays or objects that + * were not encoded uniformly, variable-length integers that could be encoded in fewer bytes, + * or 64-bit floats that could be encoded in 32 bits without loss of precision. + */ + Format = 1 << 2, + + /** + * Validate that there is no padding after the value before the end of the memory view. + * + * Padding validation failures have no impact on the ability to read the input, but are using + * more memory than necessary. + */ + Padding = 1 << 3, + + /** + * Validate that a package or attachment has the expected fields. + */ + Package = 1 << 4, + + /** + * Validate that a package or attachment matches its saved hashes. + */ + PackageHash = 1 << 5, + + /** Perform all validation described above. */ + All = Default | Names | Format | Padding | Package | PackageHash, +}; + +ENUM_CLASS_FLAGS(CbValidateMode); + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** Flags for compact binary validation errors. Multiple flags may be combined. */ +enum class CbValidateError : uint32_t +{ + /** The input had no validation errors. */ + None = 0, + + // Mode: Default + + /** The input cannot be read without reading out of bounds. */ + OutOfBounds = 1 << 0, + /** The input has a field with an unrecognized or invalid type. */ + InvalidType = 1 << 1, + + // Mode: Names + + /** An object had more than one field with the same name. */ + DuplicateName = 1 << 2, + /** An object had a field with no name. */ + MissingName = 1 << 3, + /** An array field had a name. */ + ArrayName = 1 << 4, + + // Mode: Format + + /** A name or string payload is not valid UTF-8. */ + InvalidString = 1 << 5, + /** A size or integer payload can be encoded in fewer bytes. */ + InvalidInteger = 1 << 6, + /** A float64 payload can be encoded as a float32 without loss of precision. */ + InvalidFloat = 1 << 7, + /** An object has the same type for every field but is not uniform. */ + NonUniformObject = 1 << 8, + /** An array has the same type for every field and non-empty payloads but is not uniform. */ + NonUniformArray = 1 << 9, + + // Mode: Padding + + /** A value did not use the entire memory view given for validation. */ + Padding = 1 << 10, + + // Mode: Package + + /** The package or attachment had missing fields or fields out of order. */ + InvalidPackageFormat = 1 << 11, + /** The object or an attachment did not match the hash stored for it. */ + InvalidPackageHash = 1 << 12, + /** The package contained more than one copy of the same attachment. */ + DuplicateAttachments = 1 << 13, + /** The package contained more than one object. */ + MultiplePackageObjects = 1 << 14, + /** The package contained an object with no fields. */ + NullPackageObject = 1 << 15, + /** The package contained a null attachment. */ + NullPackageAttachment = 1 << 16, +}; + +ENUM_CLASS_FLAGS(CbValidateError); + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +/** + * Validate the compact binary data for one field in the view as specified by the mode flags. + * + * Only one top-level field is processed from the view, and validation recurses into any array or + * object within that field. To validate multiple consecutive top-level fields, call the function + * once for each top-level field. If the given view might contain multiple top-level fields, then + * either exclude the Padding flag from the Mode or use MeasureCompactBinary to break up the view + * into its constituent fields before validating. + * + * @param View A memory view containing at least one top-level field. + * @param Mode A combination of the flags for the types of validation to perform. + * @param Type HasFieldType means that View contains the type. Otherwise, use the given type. + * @return None on success, otherwise the flags for the types of errors that were detected. + */ +ZENCORE_API CbValidateError ValidateCompactBinary(MemoryView View, CbValidateMode Mode, CbFieldType Type = CbFieldType::HasFieldType); + +/** + * Validate the compact binary data for every field in the view as specified by the mode flags. + * + * This function expects the entire view to contain fields. Any trailing region of the view which + * does not contain a valid field will produce an OutOfBounds or InvalidType error instead of the + * Padding error that would be produced by the single field validation function. + * + * @see ValidateCompactBinary + */ +ZENCORE_API CbValidateError ValidateCompactBinaryRange(MemoryView View, CbValidateMode Mode); + +/** + * Validate the compact binary attachment pointed to by the view as specified by the mode flags. + * + * The attachment is validated with ValidateCompactBinary by using the validation mode specified. + * Include ECbValidateMode::Package to validate the attachment format and hash. + * + * @see ValidateCompactBinary + * + * @param View A memory view containing a package. + * @param Mode A combination of the flags for the types of validation to perform. + * @return None on success, otherwise the flags for the types of errors that were detected. + */ +ZENCORE_API CbValidateError ValidateObjectAttachment(MemoryView View, CbValidateMode Mode); + +/** + * Validate the compact binary package pointed to by the view as specified by the mode flags. + * + * The package, and attachments, are validated with ValidateCompactBinary by using the validation + * mode specified. Include ECbValidateMode::Package to validate the package format and hashes. + * + * @see ValidateCompactBinary + * + * @param View A memory view containing a package. + * @param Mode A combination of the flags for the types of validation to perform. + * @return None on success, otherwise the flags for the types of errors that were detected. + */ +ZENCORE_API CbValidateError ValidateCompactBinaryPackage(MemoryView View, CbValidateMode Mode); + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void usonvalidation_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/compactbinaryvalue.h b/src/zencore/include/zencore/compactbinaryvalue.h new file mode 100644 index 000000000..0124a8983 --- /dev/null +++ b/src/zencore/include/zencore/compactbinaryvalue.h @@ -0,0 +1,290 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/endian.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/memory.h> + +namespace zen { + +namespace CompactBinaryPrivate { + + template<typename T> + static constexpr inline T ReadUnaligned(const void* const Memory) + { +#if ZEN_PLATFORM_SUPPORTS_UNALIGNED_LOADS + return *static_cast<const T*>(Memory); +#else + T Value; + memcpy(&Value, Memory, sizeof(Value)); + return Value; +#endif + } +} // namespace CompactBinaryPrivate +/** + * A type that provides unchecked access to compact binary values. + * + * The main purpose of the type is to efficiently switch on field type. For every other use case, + * prefer to use the field, array, and object types directly. The accessors here do not check the + * type before reading the value, which means they can read out of bounds even on a valid compact + * binary value if the wrong accessor is used. + */ +class CbValue +{ +public: + CbValue(CbFieldType Type, const void* Value); + + CbObjectView AsObjectView() const; + CbArrayView AsArrayView() const; + + MemoryView AsBinary() const; + + /** Access as a string. Checks for range errors and uses the default if OutError is not null. */ + std::string_view AsString(CbFieldError* OutError = nullptr, std::string_view Default = std::string_view()) const; + + /** Access as a string as UTF8. Checks for range errors and uses the default if OutError is not null. */ + std::u8string_view AsU8String(CbFieldError* OutError = nullptr, std::u8string_view Default = std::u8string_view()) const; + + /** + * Access as an integer, with both positive and negative values returned as unsigned. + * + * Checks for range errors and uses the default if OutError is not null. + */ + uint64_t AsInteger(CompactBinaryPrivate::IntegerParams Params, CbFieldError* OutError = nullptr, uint64_t Default = 0) const; + + uint64_t AsIntegerPositive() const; + int64_t AsIntegerNegative() const; + + float AsFloat32() const; + double AsFloat64() const; + + bool AsBool() const; + + inline IoHash AsObjectAttachment() const { return AsHash(); } + inline IoHash AsBinaryAttachment() const { return AsHash(); } + inline IoHash AsAttachment() const { return AsHash(); } + + IoHash AsHash() const; + Guid AsUuid() const; + + int64_t AsDateTimeTicks() const; + int64_t AsTimeSpanTicks() const; + + Oid AsObjectId() const; + + CbCustomById AsCustomById() const; + CbCustomByName AsCustomByName() const; + + inline CbFieldType GetType() const { return Type; } + inline const void* GetData() const { return Data; } + +private: + const void* Data; + CbFieldType Type; +}; + +inline CbFieldView::CbFieldView(const CbValue& InValue) : Type(InValue.GetType()), Payload(InValue.GetData()) +{ +} + +inline CbValue +CbFieldView::GetValue() const +{ + return CbValue(CbFieldTypeOps::GetType(Type), Payload); +} + +inline CbValue::CbValue(CbFieldType InType, const void* InValue) : Data(InValue), Type(InType) +{ +} + +inline CbObjectView +CbValue::AsObjectView() const +{ + return CbObjectView(*this); +} + +inline CbArrayView +CbValue::AsArrayView() const +{ + return CbArrayView(*this); +} + +inline MemoryView +CbValue::AsBinary() const +{ + const uint8_t* const Bytes = static_cast<const uint8_t*>(Data); + uint32_t ValueSizeByteCount; + const uint64_t ValueSize = ReadVarUInt(Bytes, ValueSizeByteCount); + return MakeMemoryView(Bytes + ValueSizeByteCount, ValueSize); +} + +inline std::string_view +CbValue::AsString(CbFieldError* OutError, std::string_view Default) const +{ + const char* const Chars = static_cast<const char*>(Data); + uint32_t ValueSizeByteCount; + const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount); + + if (OutError) + { + if (ValueSize >= (uint64_t(1) << 31)) + { + *OutError = CbFieldError::RangeError; + return Default; + } + *OutError = CbFieldError::None; + } + + return std::string_view(Chars + ValueSizeByteCount, int32_t(ValueSize)); +} + +inline std::u8string_view +CbValue::AsU8String(CbFieldError* OutError, std::u8string_view Default) const +{ + const char8_t* const Chars = static_cast<const char8_t*>(Data); + uint32_t ValueSizeByteCount; + const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount); + + if (OutError) + { + if (ValueSize >= (uint64_t(1) << 31)) + { + *OutError = CbFieldError::RangeError; + return Default; + } + *OutError = CbFieldError::None; + } + + return std::u8string_view(Chars + ValueSizeByteCount, int32_t(ValueSize)); +} + +inline uint64_t +CbValue::AsInteger(CompactBinaryPrivate::IntegerParams Params, CbFieldError* OutError, uint64_t Default) const +{ + // A shift of a 64-bit value by 64 is undefined so shift by one less because magnitude is never zero. + const uint64_t OutOfRangeMask = uint64_t(-2) << (Params.MagnitudeBits - 1); + const uint64_t IsNegative = uint8_t(Type) & 1; + + uint32_t MagnitudeByteCount; + const uint64_t Magnitude = ReadVarUInt(Data, MagnitudeByteCount); + const uint64_t Value = Magnitude ^ -int64_t(IsNegative); + + if (OutError) + { + const uint64_t IsInRange = (!(Magnitude & OutOfRangeMask)) & ((!IsNegative) | Params.IsSigned); + *OutError = IsInRange ? CbFieldError::None : CbFieldError::RangeError; + + const uint64_t UseValueMask = -int64_t(IsInRange); + return (Value & UseValueMask) | (Default & ~UseValueMask); + } + + return Value; +} + +inline uint64_t +CbValue::AsIntegerPositive() const +{ + uint32_t MagnitudeByteCount; + return ReadVarUInt(Data, MagnitudeByteCount); +} + +inline int64_t +CbValue::AsIntegerNegative() const +{ + uint32_t MagnitudeByteCount; + return int64_t(ReadVarUInt(Data, MagnitudeByteCount)) ^ -int64_t(1); +} + +inline float +CbValue::AsFloat32() const +{ + const uint32_t Value = FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<uint32_t>(Data)); + return reinterpret_cast<const float&>(Value); +} + +inline double +CbValue::AsFloat64() const +{ + const uint64_t Value = FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<uint64_t>(Data)); + return reinterpret_cast<const double&>(Value); +} + +inline bool +CbValue::AsBool() const +{ + return uint8_t(Type) & 1; +} + +inline IoHash +CbValue::AsHash() const +{ + return IoHash::MakeFrom(Data); +} + +inline Guid +CbValue::AsUuid() const +{ + Guid Value; + memcpy(&Value, Data, sizeof(Guid)); + Value.A = FromNetworkOrder(Value.A); + Value.B = FromNetworkOrder(Value.B); + Value.C = FromNetworkOrder(Value.C); + Value.D = FromNetworkOrder(Value.D); + return Value; +} + +inline int64_t +CbValue::AsDateTimeTicks() const +{ + return FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<int64_t>(Data)); +} + +inline int64_t +CbValue::AsTimeSpanTicks() const +{ + return FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<int64_t>(Data)); +} + +inline Oid +CbValue::AsObjectId() const +{ + return Oid::FromMemory(Data); +} + +inline CbCustomById +CbValue::AsCustomById() const +{ + const uint8_t* Bytes = static_cast<const uint8_t*>(Data); + uint32_t DataSizeByteCount; + const uint64_t DataSize = ReadVarUInt(Bytes, DataSizeByteCount); + Bytes += DataSizeByteCount; + + CbCustomById Value; + uint32_t TypeIdByteCount; + Value.Id = ReadVarUInt(Bytes, TypeIdByteCount); + Value.Data = MakeMemoryView(Bytes + TypeIdByteCount, DataSize - TypeIdByteCount); + return Value; +} + +inline CbCustomByName +CbValue::AsCustomByName() const +{ + const uint8_t* Bytes = static_cast<const uint8_t*>(Data); + uint32_t DataSizeByteCount; + const uint64_t DataSize = ReadVarUInt(Bytes, DataSizeByteCount); + Bytes += DataSizeByteCount; + + uint32_t TypeNameLenByteCount; + const uint64_t TypeNameLen = ReadVarUInt(Bytes, TypeNameLenByteCount); + Bytes += TypeNameLenByteCount; + + CbCustomByName Value; + Value.Name = std::u8string_view(reinterpret_cast<const char8_t*>(Bytes), static_cast<std::u8string_view::size_type>(TypeNameLen)); + Value.Data = MakeMemoryView(Bytes + TypeNameLen, DataSize - TypeNameLen - TypeNameLenByteCount); + return Value; +} + +} // namespace zen diff --git a/src/zencore/include/zencore/compositebuffer.h b/src/zencore/include/zencore/compositebuffer.h new file mode 100644 index 000000000..4e4b4d002 --- /dev/null +++ b/src/zencore/include/zencore/compositebuffer.h @@ -0,0 +1,142 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/sharedbuffer.h> +#include <zencore/zencore.h> + +#include <functional> +#include <span> +#include <vector> + +namespace zen { + +/** + * CompositeBuffer is a non-contiguous buffer composed of zero or more immutable shared buffers. + * + * A composite buffer is most efficient when its segments are consumed as they are, but it can be + * flattened into a contiguous buffer, when necessary, by calling Flatten(). Ownership of segment + * buffers is not changed on construction, but if ownership of segments is required then that can + * be guaranteed by calling MakeOwned(). + */ + +class CompositeBuffer +{ +public: + /** + * Construct a composite buffer by concatenating the buffers. Does not enforce ownership. + * + * Buffer parameters may be SharedBuffer, CompositeBuffer, or std::vector<SharedBuffer>. + */ + template<typename... BufferTypes> + inline explicit CompositeBuffer(BufferTypes&&... Buffers) + { + if constexpr (sizeof...(Buffers) > 0) + { + m_Segments.reserve((GetBufferCount(std::forward<BufferTypes>(Buffers)) + ...)); + (AppendBuffers(std::forward<BufferTypes>(Buffers)), ...); + std::erase_if(m_Segments, [](const SharedBuffer& It) { return It.IsNull(); }); + } + } + + /** Reset this to null. */ + ZENCORE_API void Reset(); + + /** Returns the total size of the composite buffer in bytes. */ + [[nodiscard]] ZENCORE_API uint64_t GetSize() const; + + /** Returns the segments that the buffer is composed from. */ + [[nodiscard]] inline std::span<const SharedBuffer> GetSegments() const { return std::span<const SharedBuffer>{m_Segments}; } + + /** Returns true if the composite buffer is not null. */ + [[nodiscard]] inline explicit operator bool() const { return !IsNull(); } + + /** Returns true if the composite buffer is null. */ + [[nodiscard]] inline bool IsNull() const { return m_Segments.empty(); } + + /** Returns true if every segment in the composite buffer is owned. */ + [[nodiscard]] ZENCORE_API bool IsOwned() const; + + /** Returns a copy of the buffer where every segment is owned. */ + [[nodiscard]] ZENCORE_API CompositeBuffer MakeOwned() const&; + [[nodiscard]] ZENCORE_API CompositeBuffer MakeOwned() &&; + + /** Returns the concatenation of the segments into a contiguous buffer. */ + [[nodiscard]] ZENCORE_API SharedBuffer Flatten() const&; + [[nodiscard]] ZENCORE_API SharedBuffer Flatten() &&; + + /** Returns the middle part of the buffer by taking the size starting at the offset. */ + [[nodiscard]] ZENCORE_API CompositeBuffer Mid(uint64_t Offset, uint64_t Size = ~uint64_t(0)) const; + + /** + * Returns a view of the range if contained by one segment, otherwise a view of a copy of the range. + * + * @note CopyBuffer is reused if large enough, and otherwise allocated when needed. + * + * @param Offset The byte offset in this buffer that the range starts at. + * @param Size The number of bytes in the range to view or copy. + * @param CopyBuffer The buffer to write the copy into if a copy is required. + */ + [[nodiscard]] ZENCORE_API MemoryView ViewOrCopyRange(uint64_t Offset, uint64_t Size, UniqueBuffer& CopyBuffer) const; + + /** + * Copies a range of the buffer to a contiguous region of memory. + * + * @param Target The view to copy to. Must be no larger than the data available at the offset. + * @param Offset The byte offset in this buffer to start copying from. + */ + ZENCORE_API void CopyTo(MutableMemoryView Target, uint64_t Offset = 0) const; + + /** + * Invokes a visitor with a view of each segment that intersects with a range. + * + * @param Offset The byte offset in this buffer to start visiting from. + * @param Size The number of bytes in the range to visit. + * @param Visitor The visitor to invoke from zero to GetSegments().Num() times. + */ + ZENCORE_API void IterateRange(uint64_t Offset, uint64_t Size, std::function<void(MemoryView View)> Visitor) const; + ZENCORE_API void IterateRange(uint64_t Offset, + uint64_t Size, + std::function<void(MemoryView View, const SharedBuffer& ViewOuter)> Visitor) const; + + struct Iterator + { + size_t SegmentIndex = 0; + uint64_t OffsetInSegment = 0; + }; + ZENCORE_API Iterator GetIterator(uint64_t Offset) const; + ZENCORE_API MemoryView ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& CopyBuffer) const; + ZENCORE_API void CopyTo(MutableMemoryView Target, Iterator& It) const; + + /** A null composite buffer. */ + static const CompositeBuffer Null; + +private: + static inline size_t GetBufferCount(const CompositeBuffer& Buffer) { return Buffer.m_Segments.size(); } + inline void AppendBuffers(const CompositeBuffer& Buffer) + { + m_Segments.insert(m_Segments.end(), begin(Buffer.m_Segments), end(Buffer.m_Segments)); + } + inline void AppendBuffers(CompositeBuffer&& Buffer) + { + // TODO: this operates just like the by-reference version above + m_Segments.insert(m_Segments.end(), begin(Buffer.m_Segments), end(Buffer.m_Segments)); + } + + static inline size_t GetBufferCount(const SharedBuffer&) { return 1; } + inline void AppendBuffers(const SharedBuffer& Buffer) { m_Segments.push_back(Buffer); } + inline void AppendBuffers(SharedBuffer&& Buffer) { m_Segments.push_back(std::move(Buffer)); } + + static inline size_t GetBufferCount(std::vector<SharedBuffer>&& Container) { return Container.size(); } + inline void AppendBuffers(std::vector<SharedBuffer>&& Container) + { + m_Segments.insert(m_Segments.end(), begin(Container), end(Container)); + } + +private: + std::vector<SharedBuffer> m_Segments; +}; + +void compositebuffer_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/compress.h b/src/zencore/include/zencore/compress.h new file mode 100644 index 000000000..99ce20d8a --- /dev/null +++ b/src/zencore/include/zencore/compress.h @@ -0,0 +1,165 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore/zencore.h" + +#include "zencore/blake3.h" +#include "zencore/compositebuffer.h" + +namespace zen { + +enum class OodleCompressor : uint8_t +{ + NotSet = 0, + Selkie = 1, + Mermaid = 2, + Kraken = 3, + Leviathan = 4, +}; + +enum class OodleCompressionLevel : int8_t +{ + HyperFast4 = -4, + HyperFast3 = -3, + HyperFast2 = -2, + HyperFast1 = -1, + None = 0, + SuperFast = 1, + VeryFast = 2, + Fast = 3, + Normal = 4, + Optimal1 = 5, + Optimal2 = 6, + Optimal3 = 7, + Optimal4 = 8, +}; + +/** + * A compressed buffer stores compressed data in a self-contained format. + * + * A buffer is self-contained in the sense that it can be decompressed without external knowledge + * of the compression format or the size of the raw data. + */ +class CompressedBuffer +{ +public: + /** + * Compress the buffer using the specified compressor and compression level. + * + * Data that does not compress will be return uncompressed, as if with level None. + * + * @note Using a level of None will return a buffer that references owned raw data. + * + * @param RawData The raw data to be compressed. + * @param Compressor The compressor to encode with. May use NotSet if level is None. + * @param CompressionLevel The compression level to encode with. + * @param BlockSize The power-of-two block size to encode raw data in. 0 is default. + * @return An owned compressed buffer, or null on error. + */ + [[nodiscard]] ZENCORE_API static CompressedBuffer Compress(const CompositeBuffer& RawData, + OodleCompressor Compressor = OodleCompressor::Mermaid, + OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast, + uint64_t BlockSize = 0); + [[nodiscard]] ZENCORE_API static CompressedBuffer Compress(const SharedBuffer& RawData, + OodleCompressor Compressor = OodleCompressor::Mermaid, + OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast, + uint64_t BlockSize = 0); + + /** + * Construct from a compressed buffer previously created by Compress(). + * + * @return A compressed buffer, or null on error, such as an invalid format or corrupt header. + */ + [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(const CompositeBuffer& CompressedData, + IoHash& OutRawHash, + uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(CompositeBuffer&& CompressedData, + IoHash& OutRawHash, + uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(const SharedBuffer& CompressedData, + IoHash& OutRawHash, + uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(SharedBuffer&& CompressedData, + IoHash& OutRawHash, + uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressedNoValidate(IoBuffer&& CompressedData); + [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressedNoValidate(CompositeBuffer&& CompressedData); + [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(IoBuffer&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize); + [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(const IoBuffer& CompressedData, + IoHash& OutRawHash, + uint64_t& OutRawSize); + + /** Reset this to null. */ + inline void Reset() { CompressedData.Reset(); } + + /** Returns true if the compressed buffer is not null. */ + [[nodiscard]] inline explicit operator bool() const { return !IsNull(); } + + /** Returns true if the compressed buffer is null. */ + [[nodiscard]] inline bool IsNull() const { return CompressedData.IsNull(); } + + /** Returns true if the composite buffer is owned. */ + [[nodiscard]] inline bool IsOwned() const { return CompressedData.IsOwned(); } + + /** Returns a copy of the compressed buffer that owns its underlying memory. */ + [[nodiscard]] inline CompressedBuffer MakeOwned() const& { return FromCompressedNoValidate(CompressedData.MakeOwned()); } + [[nodiscard]] inline CompressedBuffer MakeOwned() && { return FromCompressedNoValidate(std::move(CompressedData).MakeOwned()); } + + /** Returns a composite buffer containing the compressed data. May be null. May not be owned. */ + [[nodiscard]] inline const CompositeBuffer& GetCompressed() const& { return CompressedData; } + [[nodiscard]] inline CompositeBuffer GetCompressed() && { return std::move(CompressedData); } + + /** Returns the size of the compressed data. Zero if this is null. */ + [[nodiscard]] inline uint64_t GetCompressedSize() const { return CompressedData.GetSize(); } + + /** Returns the size of the raw data. Zero on error or if this is empty or null. */ + [[nodiscard]] ZENCORE_API uint64_t DecodeRawSize() const; + + /** Returns the hash of the raw data. Zero on error or if this is null. */ + [[nodiscard]] ZENCORE_API IoHash DecodeRawHash() const; + + [[nodiscard]] ZENCORE_API CompressedBuffer CopyRange(uint64_t RawOffset, uint64_t RawSize = ~uint64_t(0)) const; + + /** + * Returns the compressor and compression level used by this buffer. + * + * The compressor and compression level may differ from those specified when creating the buffer + * because an incompressible buffer is stored with no compression. Parameters cannot be accessed + * if this is null or uses a method other than Oodle, in which case this returns false. + * + * @return True if parameters were written, otherwise false. + */ + [[nodiscard]] ZENCORE_API bool TryGetCompressParameters(OodleCompressor& OutCompressor, + OodleCompressionLevel& OutCompressionLevel, + uint64_t& OutBlockSize) const; + + /** + * Decompress into a memory view that is less or equal GetRawSize() bytes. + */ + [[nodiscard]] ZENCORE_API bool TryDecompressTo(MutableMemoryView RawView, uint64_t RawOffset = 0) const; + + /** + * Decompress into an owned buffer. + * + * @return An owned buffer containing the raw data, or null on error. + */ + [[nodiscard]] ZENCORE_API SharedBuffer Decompress(uint64_t RawOffset = 0, uint64_t RawSize = ~uint64_t(0)) const; + + /** + * Decompress into an owned composite buffer. + * + * @return An owned buffer containing the raw data, or null on error. + */ + [[nodiscard]] ZENCORE_API CompositeBuffer DecompressToComposite() const; + + /** A null compressed buffer. */ + static const CompressedBuffer Null; + +private: + CompositeBuffer CompressedData; +}; + +void compress_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/config.h.in b/src/zencore/include/zencore/config.h.in new file mode 100644 index 000000000..3372eca2a --- /dev/null +++ b/src/zencore/include/zencore/config.h.in @@ -0,0 +1,16 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +// NOTE: Generated from config.h.in + +#define ZEN_CFG_VERSION "${VERSION}" +#define ZEN_CFG_VERSION_MAJOR ${VERSION_MAJOR} +#define ZEN_CFG_VERSION_MINOR ${VERSION_MINOR} +#define ZEN_CFG_VERSION_ALTER ${VERSION_ALTER} +#define ZEN_CFG_VERSION_BUILD ${VERSION_BUILD} +#define ZEN_CFG_VERSION_BRANCH "${GIT_BRANCH}" +#define ZEN_CFG_VERSION_COMMIT "${GIT_COMMIT}" +#define ZEN_CFG_VERSION_BUILD_STRING "${VERSION}-${plat}-${arch}-${mode}" +#define ZEN_CFG_VERSION_BUILD_STRING_FULL "${VERSION}-${VERSION_BUILD}-${plat}-${arch}-${mode}-${GIT_COMMIT}" +#define ZEN_CFG_SCHEMA_VERSION ${ZEN_SCHEMA_VERSION} diff --git a/src/zencore/include/zencore/crc32.h b/src/zencore/include/zencore/crc32.h new file mode 100644 index 000000000..336bda77e --- /dev/null +++ b/src/zencore/include/zencore/crc32.h @@ -0,0 +1,13 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +namespace zen { + +uint32_t MemCrc32(const void* InData, size_t Length, uint32_t Crc = 0); +uint32_t MemCrc32_Deprecated(const void* InData, size_t Length, uint32_t Crc = 0); +uint32_t StrCrc_Deprecated(const char* Data); + +} // namespace zen diff --git a/src/zencore/include/zencore/crypto.h b/src/zencore/include/zencore/crypto.h new file mode 100644 index 000000000..83d416b0f --- /dev/null +++ b/src/zencore/include/zencore/crypto.h @@ -0,0 +1,77 @@ + +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/memory.h> +#include <zencore/zencore.h> + +#include <memory> +#include <optional> + +namespace zen { + +template<size_t BitCount> +struct CryptoBits +{ +public: + static constexpr size_t ByteCount = BitCount / 8; + + CryptoBits() = default; + + bool IsNull() const { return memcmp(&m_Bits, &Zero, ByteCount) == 0; } + bool IsValid() const { return IsNull() == false; } + + size_t GetSize() const { return ByteCount; } + size_t GetBitCount() const { return BitCount; } + + MemoryView GetView() const { return MemoryView(m_Bits, ByteCount); } + + static CryptoBits FromMemoryView(MemoryView Bits) + { + if (Bits.GetSize() != ByteCount) + { + return CryptoBits(); + } + + return CryptoBits(Bits); + } + + static CryptoBits FromString(std::string_view Str) { return FromMemoryView(MakeMemoryView(Str)); } + +private: + CryptoBits(MemoryView Bits) + { + ZEN_ASSERT(Bits.GetSize() == GetSize()); + memcpy(&m_Bits, Bits.GetData(), GetSize()); + } + + static constexpr uint8_t Zero[ByteCount] = {0}; + + uint8_t m_Bits[ByteCount] = {0}; +}; + +using AesKey256Bit = CryptoBits<256>; +using AesIV128Bit = CryptoBits<128>; + +class Aes +{ +public: + static constexpr size_t BlockSize = 16; + + static MemoryView Encrypt(const AesKey256Bit& Key, + const AesIV128Bit& IV, + MemoryView In, + MutableMemoryView Out, + std::optional<std::string>& Reason); + + static MemoryView Decrypt(const AesKey256Bit& Key, + const AesIV128Bit& IV, + MemoryView In, + MutableMemoryView Out, + std::optional<std::string>& Reason); +}; + +void crypto_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/endian.h b/src/zencore/include/zencore/endian.h new file mode 100644 index 000000000..7a9e6b44c --- /dev/null +++ b/src/zencore/include/zencore/endian.h @@ -0,0 +1,113 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <cstdint> + +namespace zen { + +inline uint16_t +ByteSwap(uint16_t x) +{ +#if ZEN_COMPILER_MSC + return _byteswap_ushort(x); +#else + return __builtin_bswap16(x); +#endif +} + +inline uint32_t +ByteSwap(uint32_t x) +{ +#if ZEN_COMPILER_MSC + return _byteswap_ulong(x); +#else + return __builtin_bswap32(x); +#endif +} + +inline uint64_t +ByteSwap(uint64_t x) +{ +#if ZEN_COMPILER_MSC + return _byteswap_uint64(x); +#else + return __builtin_bswap64(x); +#endif +} + +inline uint16_t +FromNetworkOrder(uint16_t x) +{ + return ByteSwap(x); +} + +inline uint32_t +FromNetworkOrder(uint32_t x) +{ + return ByteSwap(x); +} + +inline uint64_t +FromNetworkOrder(uint64_t x) +{ + return ByteSwap(x); +} + +inline uint16_t +FromNetworkOrder(int16_t x) +{ + return ByteSwap(uint16_t(x)); +} + +inline uint32_t +FromNetworkOrder(int32_t x) +{ + return ByteSwap(uint32_t(x)); +} + +inline uint64_t +FromNetworkOrder(int64_t x) +{ + return ByteSwap(uint64_t(x)); +} + +inline uint16_t +ToNetworkOrder(uint16_t x) +{ + return ByteSwap(x); +} + +inline uint32_t +ToNetworkOrder(uint32_t x) +{ + return ByteSwap(x); +} + +inline uint64_t +ToNetworkOrder(uint64_t x) +{ + return ByteSwap(x); +} + +inline uint16_t +ToNetworkOrder(int16_t x) +{ + return ByteSwap(uint16_t(x)); +} + +inline uint32_t +ToNetworkOrder(int32_t x) +{ + return ByteSwap(uint32_t(x)); +} + +inline uint64_t +ToNetworkOrder(int64_t x) +{ + return ByteSwap(uint64_t(x)); +} + +} // namespace zen diff --git a/src/zencore/include/zencore/enumflags.h b/src/zencore/include/zencore/enumflags.h new file mode 100644 index 000000000..ebe747bf0 --- /dev/null +++ b/src/zencore/include/zencore/enumflags.h @@ -0,0 +1,61 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +namespace zen { + +// Enum class helpers + +// Defines all bitwise operators for enum classes so it can be (mostly) used as a regular flags enum +#define ENUM_CLASS_FLAGS(Enum) \ + inline Enum& operator|=(Enum& Lhs, Enum Rhs) { return Lhs = (Enum)((__underlying_type(Enum))Lhs | (__underlying_type(Enum))Rhs); } \ + inline Enum& operator&=(Enum& Lhs, Enum Rhs) { return Lhs = (Enum)((__underlying_type(Enum))Lhs & (__underlying_type(Enum))Rhs); } \ + inline Enum& operator^=(Enum& Lhs, Enum Rhs) { return Lhs = (Enum)((__underlying_type(Enum))Lhs ^ (__underlying_type(Enum))Rhs); } \ + inline constexpr Enum operator|(Enum Lhs, Enum Rhs) { return (Enum)((__underlying_type(Enum))Lhs | (__underlying_type(Enum))Rhs); } \ + inline constexpr Enum operator&(Enum Lhs, Enum Rhs) { return (Enum)((__underlying_type(Enum))Lhs & (__underlying_type(Enum))Rhs); } \ + inline constexpr Enum operator^(Enum Lhs, Enum Rhs) { return (Enum)((__underlying_type(Enum))Lhs ^ (__underlying_type(Enum))Rhs); } \ + inline constexpr bool operator!(Enum E) { return !(__underlying_type(Enum))E; } \ + inline constexpr Enum operator~(Enum E) { return (Enum) ~(__underlying_type(Enum))E; } + +// Friends all bitwise operators for enum classes so the definition can be kept private / protected. +#define FRIEND_ENUM_CLASS_FLAGS(Enum) \ + friend Enum& operator|=(Enum& Lhs, Enum Rhs); \ + friend Enum& operator&=(Enum& Lhs, Enum Rhs); \ + friend Enum& operator^=(Enum& Lhs, Enum Rhs); \ + friend constexpr Enum operator|(Enum Lhs, Enum Rhs); \ + friend constexpr Enum operator&(Enum Lhs, Enum Rhs); \ + friend constexpr Enum operator^(Enum Lhs, Enum Rhs); \ + friend constexpr bool operator!(Enum E); \ + friend constexpr Enum operator~(Enum E); + +template<typename Enum> +constexpr bool +EnumHasAllFlags(Enum Flags, Enum Contains) +{ + return (((__underlying_type(Enum))Flags) & (__underlying_type(Enum))Contains) == ((__underlying_type(Enum))Contains); +} + +template<typename Enum> +constexpr bool +EnumHasAnyFlags(Enum Flags, Enum Contains) +{ + return (((__underlying_type(Enum))Flags) & (__underlying_type(Enum))Contains) != 0; +} + +template<typename Enum> +void +EnumAddFlags(Enum& Flags, Enum FlagsToAdd) +{ + Flags |= FlagsToAdd; +} + +template<typename Enum> +void +EnumRemoveFlags(Enum& Flags, Enum FlagsToRemove) +{ + Flags &= ~FlagsToRemove; +} + +} // namespace zen diff --git a/src/zencore/include/zencore/except.h b/src/zencore/include/zencore/except.h new file mode 100644 index 000000000..c61db5ba9 --- /dev/null +++ b/src/zencore/include/zencore/except.h @@ -0,0 +1,57 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/string.h> +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <errno.h> +#endif +#if __has_include("source_location") +# include <source_location> +#endif +#include <string> +#include <system_error> + +namespace zen { + +#if ZEN_PLATFORM_WINDOWS +ZENCORE_API void ThrowSystemException [[noreturn]] (HRESULT hRes, std::string_view Message); +#endif // ZEN_PLATFORM_WINDOWS + +#if defined(__cpp_lib_source_location) +ZENCORE_API void ThrowLastErrorImpl [[noreturn]] (std::string_view Message, const std::source_location& Location); +# define ThrowLastError(Message) ThrowLastErrorImpl(Message, std::source_location::current()) +#else +ZENCORE_API void ThrowLastError [[noreturn]] (std::string_view Message); +#endif + +ZENCORE_API void ThrowSystemError [[noreturn]] (uint32_t ErrorCode, std::string_view Message); + +ZENCORE_API std::string GetLastErrorAsString(); +ZENCORE_API std::string GetSystemErrorAsString(uint32_t Win32ErrorCode); + +inline int32_t +GetLastError() +{ +#if ZEN_PLATFORM_WINDOWS + return ::GetLastError(); +#else + return errno; +#endif +} + +inline std::error_code +MakeErrorCode(uint32_t ErrorCode) noexcept +{ + return std::error_code(ErrorCode, std::system_category()); +} + +inline std::error_code +MakeErrorCodeFromLastError() noexcept +{ + return std::error_code(zen::GetLastError(), std::system_category()); +} + +} // namespace zen diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h new file mode 100644 index 000000000..fa5f94170 --- /dev/null +++ b/src/zencore/include/zencore/filesystem.h @@ -0,0 +1,190 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <zencore/iobuffer.h> +#include <zencore/string.h> + +#include <filesystem> +#include <functional> + +namespace zen { + +class IoBuffer; + +/** Delete directory (after deleting any contents) + */ +ZENCORE_API bool DeleteDirectories(const std::filesystem::path& dir); + +/** Ensure directory exists. + + Will also create any required parent directories + */ +ZENCORE_API bool CreateDirectories(const std::filesystem::path& dir); + +/** Ensure directory exists and delete contents (if any) before returning + */ +ZENCORE_API bool CleanDirectory(const std::filesystem::path& dir); + +/** Map native file handle to a path + */ +ZENCORE_API std::filesystem::path PathFromHandle(void* NativeHandle); + +ZENCORE_API std::filesystem::path GetRunningExecutablePath(); + +/** Set the max open file handle count to max allowed for the current process on Linux and MacOS + */ +ZENCORE_API void MaximizeOpenFileCount(); + +struct FileContents +{ + std::vector<IoBuffer> Data; + std::error_code ErrorCode; + + IoBuffer Flatten(); +}; + +ZENCORE_API FileContents ReadStdIn(); +ZENCORE_API FileContents ReadFile(std::filesystem::path Path); +ZENCORE_API bool ScanFile(std::filesystem::path Path, uint64_t ChunkSize, std::function<void(const void* Data, size_t Size)>&& ProcessFunc); +ZENCORE_API void WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t BufferCount); +ZENCORE_API void WriteFile(std::filesystem::path Path, IoBuffer Data); + +struct CopyFileOptions +{ + bool EnableClone = true; + bool MustClone = false; +}; + +ZENCORE_API bool CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options); +ZENCORE_API bool SupportsBlockRefCounting(std::filesystem::path Path); + +ZENCORE_API void PathToUtf8(const std::filesystem::path& Path, StringBuilderBase& Out); +ZENCORE_API std::string PathToUtf8(const std::filesystem::path& Path); + +extern template class StringBuilderImpl<std::filesystem::path::value_type>; + +/** + * Helper class for building paths. Backed by a string builder. + * + */ +class PathBuilderBase : public StringBuilderImpl<std::filesystem::path::value_type> +{ +private: + using Super = StringBuilderImpl<std::filesystem::path::value_type>; + +protected: + using CharType = std::filesystem::path::value_type; + using ViewType = std::basic_string_view<CharType>; + +public: + void Append(const std::filesystem::path& Rhs) { Super::Append(Rhs.c_str()); } + void operator/=(const std::filesystem::path& Rhs) { this->operator/=(Rhs.c_str()); }; + void operator/=(const CharType* Rhs) + { + AppendSeparator(); + Super::Append(Rhs); + } + operator ViewType() const { return ToView(); } + std::basic_string_view<CharType> ToView() const { return std::basic_string_view<CharType>(Data(), Size()); } + std::filesystem::path ToPath() const { return std::filesystem::path(ToView()); } + + std::string ToUtf8() const + { +#if ZEN_PLATFORM_WINDOWS + return WideToUtf8(ToView()); +#else + return std::string(ToView()); +#endif + } + + void AppendSeparator() + { + if (ToView().ends_with(std::filesystem::path::preferred_separator) +#if ZEN_PLATFORM_WINDOWS + || ToView().ends_with('/') +#endif + ) + return; + + Super::Append(std::filesystem::path::preferred_separator); + } +}; + +template<size_t N> +class PathBuilder : public PathBuilderBase +{ +public: + PathBuilder() { Init(m_Buffer, N); } + +private: + PathBuilderBase::CharType m_Buffer[N]; +}; + +template<size_t N> +class ExtendablePathBuilder : public PathBuilder<N> +{ +public: + ExtendablePathBuilder() { this->m_IsExtendable = true; } +}; + +struct DiskSpace +{ + uint64_t Free{}; + uint64_t Total{}; +}; + +ZENCORE_API DiskSpace DiskSpaceInfo(std::filesystem::path Directory, std::error_code& Error); + +inline bool +DiskSpaceInfo(std::filesystem::path Directory, DiskSpace& Space) +{ + std::error_code Err; + Space = DiskSpaceInfo(Directory, Err); + return !Err; +} + +/** + * Efficient file system traversal + * + * Uses the best available mechanism for the platform in question and could take + * advantage of any file system tracking mechanisms in the future + * + */ +class FileSystemTraversal +{ +public: + struct TreeVisitor + { + using path_view = std::basic_string_view<std::filesystem::path::value_type>; + using path_string = std::filesystem::path::string_type; + + virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t FileSize) = 0; + + // This should return true if we should recurse into the directory + virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName) = 0; + }; + + void TraverseFileSystem(const std::filesystem::path& RootDir, TreeVisitor& Visitor); +}; + +struct DirectoryContent +{ + static const uint8_t IncludeDirsFlag = 1u << 0; + static const uint8_t IncludeFilesFlag = 1u << 1; + static const uint8_t RecursiveFlag = 1u << 2; + std::vector<std::filesystem::path> Files; + std::vector<std::filesystem::path> Directories; +}; + +void GetDirectoryContent(const std::filesystem::path& RootDir, uint8_t Flags, DirectoryContent& OutContent); + +std::string GetEnvVariable(std::string_view VariableName); + +////////////////////////////////////////////////////////////////////////// + +void filesystem_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/fmtutils.h b/src/zencore/include/zencore/fmtutils.h new file mode 100644 index 000000000..70867fe72 --- /dev/null +++ b/src/zencore/include/zencore/fmtutils.h @@ -0,0 +1,52 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/string.h> +#include <zencore/uid.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <filesystem> +#include <string_view> + +// Custom formatting for some zencore types + +template<> +struct fmt::formatter<zen::IoHash> : formatter<string_view> +{ + template<typename FormatContext> + auto format(const zen::IoHash& Hash, FormatContext& ctx) + { + zen::IoHash::String_t String; + Hash.ToHexString(String); + return formatter<string_view>::format({String, zen::IoHash::StringLength}, ctx); + } +}; + +template<> +struct fmt::formatter<zen::Oid> : formatter<string_view> +{ + template<typename FormatContext> + auto format(const zen::Oid& Id, FormatContext& ctx) + { + zen::StringBuilder<32> String; + Id.ToString(String); + return formatter<string_view>::format({String.c_str(), zen::Oid::StringLength}, ctx); + } +}; + +template<> +struct fmt::formatter<std::filesystem::path> : formatter<string_view> +{ + template<typename FormatContext> + auto format(const std::filesystem::path& Path, FormatContext& ctx) + { + zen::ExtendableStringBuilder<128> String; + String << Path.u8string(); + return formatter<string_view>::format(String.ToView(), ctx); + } +}; diff --git a/src/zencore/include/zencore/intmath.h b/src/zencore/include/zencore/intmath.h new file mode 100644 index 000000000..f24caed6e --- /dev/null +++ b/src/zencore/include/zencore/intmath.h @@ -0,0 +1,183 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <stdint.h> + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_COMPILER_MSC || ZEN_PLATFORM_WINDOWS +# pragma intrinsic(_BitScanReverse) +# pragma intrinsic(_BitScanReverse64) +#else +inline uint8_t +_BitScanReverse(unsigned long* Index, uint32_t Mask) +{ + if (Mask == 0) + { + return 0; + } + + *Index = 31 - __builtin_clz(Mask); + return 1; +} + +inline uint8_t +_BitScanReverse64(unsigned long* Index, uint64_t Mask) +{ + if (Mask == 0) + { + return 0; + } + + *Index = 63 - __builtin_clzll(Mask); + return 1; +} + +inline uint8_t +_BitScanForward64(unsigned long* Index, uint64_t Mask) +{ + if (Mask == 0) + { + return 0; + } + + *Index = __builtin_ctzll(Mask); + return 1; +} +#endif + +namespace zen { + +inline constexpr bool +IsPow2(uint64_t n) +{ + return 0 == (n & (n - 1)); +} + +/// Round an integer up to the closest integer multiplier of 'base' ('base' must be a power of two) +template<Integral T> +T +RoundUp(T Value, auto Base) +{ + ZEN_ASSERT_SLOW(IsPow2(Base)); + return ((Value + T(Base - 1)) & (~T(Base - 1))); +} + +bool +IsMultipleOf(Integral auto Value, auto MultiplierPow2) +{ + ZEN_ASSERT_SLOW(IsPow2(MultiplierPow2)); + return (Value & (MultiplierPow2 - 1)) == 0; +} + +inline uint64_t +NextPow2(uint64_t n) +{ + // http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 + + --n; + + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n |= n >> 32; + + return n + 1; +} + +static inline uint32_t +FloorLog2(uint32_t Value) +{ + // Use BSR to return the log2 of the integer + unsigned long Log2; + if (_BitScanReverse(&Log2, Value) != 0) + { + return Log2; + } + + return 0; +} + +static inline uint32_t +CountLeadingZeros(uint32_t Value) +{ + unsigned long Log2 = 0; + _BitScanReverse64(&Log2, (uint64_t(Value) << 1) | 1); + return 32 - Log2; +} + +static inline uint64_t +FloorLog2_64(uint64_t Value) +{ + unsigned long Log2 = 0; + long Mask = -long(_BitScanReverse64(&Log2, Value) != 0); + return Log2 & Mask; +} + +static inline uint64_t +CountLeadingZeros64(uint64_t Value) +{ + unsigned long Log2 = 0; + long Mask = -long(_BitScanReverse64(&Log2, Value) != 0); + return ((63 - Log2) & Mask) | (64 & ~Mask); +} + +static inline uint64_t +CeilLogTwo64(uint64_t Arg) +{ + int64_t Bitmask = ((int64_t)(CountLeadingZeros64(Arg) << 57)) >> 63; + return (64 - CountLeadingZeros64(Arg - 1)) & (~Bitmask); +} + +static inline uint64_t +CountTrailingZeros64(uint64_t Value) +{ + if (Value == 0) + { + return 64; + } + unsigned long BitIndex; // 0-based, where the LSB is 0 and MSB is 31 + _BitScanForward64(&BitIndex, Value); // Scans from LSB to MSB + return BitIndex; +} + +////////////////////////////////////////////////////////////////////////// + +static inline bool +IsPointerAligned(const void* Ptr, uint64_t Alignment) +{ + ZEN_ASSERT_SLOW(IsPow2(Alignment)); + + return 0 == (reinterpret_cast<uintptr_t>(Ptr) & (Alignment - 1)); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_PLATFORM_WINDOWS +# ifdef min +# error "Looks like you did #include <windows.h> -- use <zencore/windows.h> instead" +# endif +#endif + +constexpr auto +Min(auto x, auto y) +{ + return x < y ? x : y; +} + +constexpr auto +Max(auto x, auto y) +{ + return x > y ? x : y; +} + +////////////////////////////////////////////////////////////////////////// + +void intmath_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h new file mode 100644 index 000000000..a39dbf6d6 --- /dev/null +++ b/src/zencore/include/zencore/iobuffer.h @@ -0,0 +1,423 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <memory.h> +#include <zencore/memory.h> +#include <atomic> +#include "refcount.h" +#include "zencore.h" + +#include <filesystem> + +namespace zen { + +struct IoHash; +struct IoBufferExtendedCore; + +enum class ZenContentType : uint8_t +{ + kBinary = 0, // Note that since this is zero, this will be the default value in IoBuffer + kText = 1, + kJSON = 2, + kCbObject = 3, + kCbPackage = 4, + kYAML = 5, + kCbPackageOffer = 6, + kCompressedBinary = 7, + kUnknownContentType = 8, + kHTML = 9, + kJavaScript = 10, + kCSS = 11, + kPNG = 12, + kIcon = 13, + kCOUNT +}; + +inline std::string_view +ToString(ZenContentType ContentType) +{ + using namespace std::literals; + + switch (ContentType) + { + default: + case ZenContentType::kUnknownContentType: + return "unknown"sv; + case ZenContentType::kBinary: + return "binary"sv; + case ZenContentType::kText: + return "text"sv; + case ZenContentType::kJSON: + return "json"sv; + case ZenContentType::kCbObject: + return "cb-object"sv; + case ZenContentType::kCbPackage: + return "cb-package"sv; + case ZenContentType::kCbPackageOffer: + return "cb-package-offer"sv; + case ZenContentType::kCompressedBinary: + return "compressed-binary"sv; + case ZenContentType::kYAML: + return "yaml"sv; + case ZenContentType::kHTML: + return "html"sv; + case ZenContentType::kJavaScript: + return "javascript"sv; + case ZenContentType::kCSS: + return "css"sv; + case ZenContentType::kPNG: + return "png"sv; + case ZenContentType::kIcon: + return "icon"sv; + } +} + +struct IoBufferFileReference +{ + void* FileHandle; + uint64_t FileChunkOffset; + uint64_t FileChunkSize; +}; + +struct IoBufferCore +{ +public: + inline IoBufferCore() : m_Flags(kIsNull) {} + inline IoBufferCore(const void* DataPtr, size_t SizeBytes) : m_DataPtr(DataPtr), m_DataBytes(SizeBytes) {} + inline IoBufferCore(const IoBufferCore* Outer, const void* DataPtr, size_t SizeBytes) + : m_DataPtr(DataPtr) + , m_DataBytes(SizeBytes) + , m_OuterCore(Outer) + { + } + + ZENCORE_API explicit IoBufferCore(size_t SizeBytes); + ZENCORE_API IoBufferCore(size_t SizeBytes, size_t Alignment); + ZENCORE_API ~IoBufferCore(); + + // Reference counting + + inline uint32_t AddRef() const { return AtomicIncrement(const_cast<IoBufferCore*>(this)->m_RefCount); } + inline uint32_t Release() const + { + const uint32_t NewRefCount = AtomicDecrement(const_cast<IoBufferCore*>(this)->m_RefCount); + if (NewRefCount == 0) + { + DeleteThis(); + } + return NewRefCount; + } + + // Copying reference counted objects doesn't make a lot of sense generally, so let's prevent it + + IoBufferCore(const IoBufferCore&) = delete; + IoBufferCore(IoBufferCore&&) = delete; + IoBufferCore& operator=(const IoBufferCore&) = delete; + IoBufferCore& operator=(IoBufferCore&&) = delete; + + // + + ZENCORE_API void Materialize() const; + ZENCORE_API void DeleteThis() const; + ZENCORE_API void MakeOwned(bool Immutable = true); + + inline void EnsureDataValid() const + { + const uint32_t LocalFlags = m_Flags.load(std::memory_order_acquire); + if ((LocalFlags & kIsExtended) && !(LocalFlags & kIsMaterialized)) + { + Materialize(); + } + } + + inline bool IsOwnedByThis() const { return !!(m_Flags.load(std::memory_order_relaxed) & kIsOwnedByThis); } + + inline void SetIsOwnedByThis(bool NewState) + { + if (NewState) + { + m_Flags.fetch_or(kIsOwnedByThis, std::memory_order_relaxed); + } + else + { + m_Flags.fetch_and(~kIsOwnedByThis, std::memory_order_relaxed); + } + } + + inline bool IsOwned() const + { + if (IsOwnedByThis()) + { + return true; + } + return m_OuterCore && m_OuterCore->IsOwned(); + } + + inline bool IsImmutable() const { return (m_Flags.load(std::memory_order_relaxed) & kIsMutable) == 0; } + inline bool IsWholeFile() const { return (m_Flags.load(std::memory_order_relaxed) & kIsWholeFile) != 0; } + inline bool IsNull() const { return (m_Flags.load(std::memory_order_relaxed) & kIsNull) != 0; } + + inline IoBufferExtendedCore* ExtendedCore(); + inline const IoBufferExtendedCore* ExtendedCore() const; + + ZENCORE_API void* MutableDataPointer() const; + + inline const void* DataPointer() const + { + EnsureDataValid(); + return m_DataPtr; + } + + inline size_t DataBytes() const { return m_DataBytes; } + + inline void Set(const void* Ptr, size_t Sz) + { + m_DataPtr = Ptr; + m_DataBytes = Sz; + } + + inline void SetIsImmutable(bool NewState) + { + if (!NewState) + { + m_Flags.fetch_or(kIsMutable, std::memory_order_relaxed); + } + else + { + m_Flags.fetch_and(~kIsMutable, std::memory_order_relaxed); + } + } + + inline void SetIsWholeFile(bool NewState) + { + if (NewState) + { + m_Flags.fetch_or(kIsWholeFile, std::memory_order_relaxed); + } + else + { + m_Flags.fetch_and(~kIsWholeFile, std::memory_order_relaxed); + } + } + + inline void SetContentType(ZenContentType ContentType) + { + ZEN_ASSERT_SLOW((uint32_t(ContentType) & kContentTypeMask) == uint32_t(ContentType)); + uint32_t OldValue = m_Flags.load(std::memory_order_relaxed); + uint32_t NewValue; + do + { + NewValue = (OldValue & ~(kContentTypeMask << kContentTypeShift)) | (uint32_t(ContentType) << kContentTypeShift); + } while (!m_Flags.compare_exchange_weak(OldValue, NewValue, std::memory_order_relaxed, std::memory_order_relaxed)); + } + + inline ZenContentType GetContentType() const + { + return ZenContentType((m_Flags.load(std::memory_order_relaxed) >> kContentTypeShift) & kContentTypeMask); + } + + inline uint32_t GetRefCount() const { return m_RefCount; } + +protected: + uint32_t m_RefCount = 0; + mutable std::atomic<uint32_t> m_Flags{0}; + mutable const void* m_DataPtr = nullptr; + size_t m_DataBytes = 0; + RefPtr<const IoBufferCore> m_OuterCore; + + enum + { + kContentTypeShift = 24, + kContentTypeMask = 0xf + }; + + static_assert((uint32_t(ZenContentType::kUnknownContentType) & ~kContentTypeMask) == 0); + + enum Flags : uint32_t + { + kIsNull = 1 << 0, // This is a null IoBuffer + kIsMutable = 1 << 1, + kIsExtended = 1 << 2, // Is actually a SharedBufferExtendedCore + kIsMaterialized = 1 << 3, // Data pointers are valid + kLowLevelAlloc = 1 << 4, // Using direct memory allocation + kIsWholeFile = 1 << 5, // References an entire file + kIoBufferAlloc = 1 << 6, // Using IoBuffer allocator + kIsOwnedByThis = 1 << 7, + + // Note that we have some extended flags defined below + // so not all bits are available to use here + + kContentTypeBit0 = 1 << (24 + 0), // These constants + kContentTypeBit1 = 1 << (24 + 1), // are here mostly to + kContentTypeBit2 = 1 << (24 + 2), // indicate that these + kContentTypeBit3 = 1 << (24 + 3), // bits are reserved + }; + + void AllocateBuffer(size_t InSize, size_t Alignment) const; + void FreeBuffer(); +}; + +/** + * An "Extended" core references a segment of a file + */ + +struct IoBufferExtendedCore : public IoBufferCore +{ + IoBufferExtendedCore(void* FileHandle, uint64_t Offset, uint64_t Size, bool TransferHandleOwnership); + IoBufferExtendedCore(const IoBufferExtendedCore* Outer, uint64_t Offset, uint64_t Size); + ~IoBufferExtendedCore(); + + enum ExtendedFlags + { + kOwnsFile = 1 << 16, + kOwnsMmap = 1 << 17 + }; + + void Materialize() const; + bool GetFileReference(IoBufferFileReference& OutRef) const; + void MarkAsDeleteOnClose(); + +private: + void* m_FileHandle = nullptr; + uint64_t m_FileOffset = 0; + mutable void* m_MmapHandle = nullptr; + mutable void* m_MappedPointer = nullptr; + bool m_DeleteOnClose = false; +}; + +inline IoBufferExtendedCore* +IoBufferCore::ExtendedCore() +{ + if (m_Flags.load(std::memory_order_relaxed) & kIsExtended) + { + return static_cast<IoBufferExtendedCore*>(this); + } + + return nullptr; +} + +inline const IoBufferExtendedCore* +IoBufferCore::ExtendedCore() const +{ + if (m_Flags.load(std::memory_order_relaxed) & kIsExtended) + { + return static_cast<const IoBufferExtendedCore*>(this); + } + + return nullptr; +} + +/** + * I/O buffer + * + * This represents a reference to a payload in memory or on disk + * + */ +class IoBuffer +{ +public: + enum ECloneTag + { + Clone + }; + enum EWrapTag + { + Wrap + }; + enum EFileTag + { + File + }; + enum EBorrowedFileTag + { + BorrowedFile + }; + + inline IoBuffer() = default; + inline IoBuffer(IoBuffer&& Rhs) noexcept = default; + inline IoBuffer(const IoBuffer& Rhs) = default; + inline IoBuffer& operator=(const IoBuffer& Rhs) = default; + inline IoBuffer& operator=(IoBuffer&& Rhs) noexcept = default; + + /** Create an uninitialized buffer of the given size + */ + ZENCORE_API explicit IoBuffer(size_t InSize); + + /** Create an uninitialized buffer of the given size with the specified alignment + */ + ZENCORE_API explicit IoBuffer(size_t InSize, uint64_t InAlignment); + + /** Create a buffer which references a sequence of bytes inside another buffer + */ + ZENCORE_API IoBuffer(const IoBuffer& OuterBuffer, size_t Offset, size_t SizeBytes = ~0ull); + + /** Create a buffer which references a range of bytes which we assume will live + * for the entire life time. + */ + inline IoBuffer(EWrapTag, const void* DataPtr, size_t SizeBytes) : m_Core(new IoBufferCore(DataPtr, SizeBytes)) {} + + inline IoBuffer(ECloneTag, const void* DataPtr, size_t SizeBytes) : m_Core(new IoBufferCore(SizeBytes)) + { + memcpy(const_cast<void*>(m_Core->DataPointer()), DataPtr, SizeBytes); + } + + ZENCORE_API IoBuffer(EFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize); + ZENCORE_API IoBuffer(EBorrowedFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize); + + inline explicit operator bool() const { return !m_Core->IsNull(); } + inline operator MemoryView() const& { return MemoryView(m_Core->DataPointer(), m_Core->DataBytes()); } + inline void MakeOwned() { return m_Core->MakeOwned(); } + [[nodiscard]] inline bool IsOwned() const { return m_Core->IsOwned(); } + [[nodiscard]] inline bool IsWholeFile() const { return m_Core->IsWholeFile(); } + [[nodiscard]] void* MutableData() const { return m_Core->MutableDataPointer(); } + void MakeImmutable() { m_Core->SetIsImmutable(true); } + [[nodiscard]] const void* Data() const { return m_Core->DataPointer(); } + [[nodiscard]] const void* GetData() const { return m_Core->DataPointer(); } + [[nodiscard]] size_t Size() const { return m_Core->DataBytes(); } + [[nodiscard]] size_t GetSize() const { return m_Core->DataBytes(); } + inline void SetContentType(ZenContentType ContentType) { m_Core->SetContentType(ContentType); } + [[nodiscard]] inline ZenContentType GetContentType() const { return m_Core->GetContentType(); } + [[nodiscard]] ZENCORE_API bool GetFileReference(IoBufferFileReference& OutRef) const; + void MarkAsDeleteOnClose(); + + inline MemoryView GetView() const { return MemoryView(m_Core->DataPointer(), m_Core->DataBytes()); } + inline MutableMemoryView GetMutableView() { return MutableMemoryView(m_Core->MutableDataPointer(), m_Core->DataBytes()); } + + template<typename T> + [[nodiscard]] const T* Data() const + { + return reinterpret_cast<const T*>(m_Core->DataPointer()); + } + + template<typename T> + [[nodiscard]] T* MutableData() const + { + return reinterpret_cast<T*>(m_Core->MutableDataPointer()); + } + +private: + RefPtr<IoBufferCore> m_Core = new IoBufferCore; + + IoBuffer(IoBufferCore* Core) : m_Core(Core) {} + + friend class SharedBuffer; + friend class IoBufferBuilder; +}; + +class IoBufferBuilder +{ +public: + ZENCORE_API static IoBuffer MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset = 0, uint64_t Size = ~0ull); + ZENCORE_API static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName); + ZENCORE_API static IoBuffer MakeFromFileHandle(void* FileHandle, uint64_t Offset = 0, uint64_t Size = ~0ull); + ZENCORE_API static IoBuffer ReadFromFileMaybe(IoBuffer& InBuffer); + inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz) { return IoBuffer(IoBuffer::Clone, Ptr, Sz); } + inline static IoBuffer MakeCloneFromMemory(MemoryView Memory) { return IoBuffer(IoBuffer::Clone, Memory.GetData(), Memory.GetSize()); } +}; + +IoHash HashBuffer(IoBuffer& Buffer); + +void iobuffer_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/iohash.h b/src/zencore/include/zencore/iohash.h new file mode 100644 index 000000000..fd0f4b2a7 --- /dev/null +++ b/src/zencore/include/zencore/iohash.h @@ -0,0 +1,115 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <zencore/blake3.h> +#include <zencore/memory.h> + +#include <compare> +#include <string_view> + +namespace zen { + +class StringBuilderBase; +class CompositeBuffer; + +/** + * Hash used for content addressable storage + * + * This is basically a BLAKE3-160 hash (note: this is probably not an officially + * recognized identifier). It is generated by computing a 32-byte BLAKE3 hash and + * picking the first 20 bytes of the resulting hash. + * + */ +struct IoHash +{ + alignas(uint32_t) uint8_t Hash[20] = {}; + + static IoHash MakeFrom(const void* data /* 20 bytes */) + { + IoHash Io; + memcpy(Io.Hash, data, sizeof Io); + return Io; + } + + static IoHash FromBLAKE3(const BLAKE3& Blake3) + { + IoHash Io; + memcpy(Io.Hash, Blake3.Hash, sizeof Io.Hash); + return Io; + } + + static IoHash HashBuffer(const void* data, size_t byteCount); + static IoHash HashBuffer(MemoryView Data) { return HashBuffer(Data.GetData(), Data.GetSize()); } + static IoHash HashBuffer(const CompositeBuffer& Buffer); + static IoHash FromHexString(const char* string); + static IoHash FromHexString(const std::string_view string); + const char* ToHexString(char* outString /* 40 characters + NUL terminator */) const; + StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const; + std::string ToHexString() const; + + static const int StringLength = 40; + typedef char String_t[StringLength + 1]; + + static const IoHash Zero; // Initialized to all zeros + + inline auto operator<=>(const IoHash& rhs) const = default; + + struct Hasher + { + size_t operator()(const IoHash& v) const + { + size_t h; + memcpy(&h, v.Hash, sizeof h); + return h; + } + }; +}; + +struct IoHashStream +{ + /// Begin streaming hash compute (not needed on freshly constructed instance) + void Reset() { m_Blake3Stream.Reset(); } + + /// Append another chunk + IoHashStream& Append(const void* data, size_t byteCount) + { + m_Blake3Stream.Append(data, byteCount); + return *this; + } + + /// Append another chunk + IoHashStream& Append(MemoryView Data) + { + m_Blake3Stream.Append(Data.GetData(), Data.GetSize()); + return *this; + } + + /// Obtain final hash. If you wish to reuse the instance call reset() + IoHash GetHash() + { + BLAKE3 b3 = m_Blake3Stream.GetHash(); + + IoHash Io; + memcpy(Io.Hash, b3.Hash, sizeof Io.Hash); + + return Io; + } + +private: + BLAKE3Stream m_Blake3Stream; +}; + +} // namespace zen + +namespace std { + +template<> +struct hash<zen::IoHash> : public zen::IoHash::Hasher +{ +}; + +} // namespace std diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h new file mode 100644 index 000000000..5cbe034cf --- /dev/null +++ b/src/zencore/include/zencore/logging.h @@ -0,0 +1,136 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <spdlog/spdlog.h> +#undef GetObject +ZEN_THIRD_PARTY_INCLUDES_END + +#include <string_view> + +namespace zen::logging { + +spdlog::logger& Default(); +void SetDefault(std::shared_ptr<spdlog::logger> NewDefaultLogger); +spdlog::logger& ConsoleLog(); +spdlog::logger& Get(std::string_view Name); +spdlog::logger* ErrorLog(); +void SetErrorLog(std::shared_ptr<spdlog::logger>&& NewErrorLogger); + +void InitializeLogging(); +void ShutdownLogging(); + +} // namespace zen::logging + +namespace zen { +extern spdlog::logger* TheDefaultLogger; + +inline spdlog::logger& +Log() +{ + return *TheDefaultLogger; +} + +using logging::ConsoleLog; +using logging::ErrorLog; +} // namespace zen + +using zen::ConsoleLog; +using zen::ErrorLog; +using zen::Log; + +struct LogCategory +{ + LogCategory(std::string_view InCategory) : Category(InCategory) {} + + spdlog::logger& Logger() + { + static spdlog::logger& Inst = zen::logging::Get(Category); + return Inst; + } + + std::string Category; +}; + +inline consteval bool +LogIsErrorLevel(int level) +{ + return (level == spdlog::level::err || level == spdlog::level::critical); +}; + +#define ZEN_LOG_WITH_LOCATION(logger, loc, level, fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + if (logger.should_log(level)) \ + { \ + logger.log(loc, level, fmtstr, ##__VA_ARGS__); \ + if (LogIsErrorLevel(level)) \ + { \ + if (auto ErrLogger = zen::logging::ErrorLog(); ErrLogger != nullptr) \ + { \ + ErrLogger->log(loc, level, fmtstr, ##__VA_ARGS__); \ + } \ + } \ + } \ + } while (false); + +#define ZEN_LOG(logger, level, fmtstr, ...) ZEN_LOG_WITH_LOCATION(logger, spdlog::source_loc{}, level, fmtstr, ##__VA_ARGS__) + +#define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \ + static struct LogCategory##Category : public LogCategory \ + { \ + LogCategory##Category() : LogCategory(Name) {} \ + } Category; + +#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::trace, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::debug, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::info, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_LOG_WARN(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::warn, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_LOG_ERROR(Category, fmtstr, ...) \ + ZEN_LOG_WITH_LOCATION(Category.Logger(), \ + spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), \ + spdlog::level::err, \ + fmtstr##sv, \ + ##__VA_ARGS__) + +#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) \ + ZEN_LOG_WITH_LOCATION(Category.Logger(), \ + spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), \ + spdlog::level::critical, \ + fmtstr##sv, \ + ##__VA_ARGS__) + + // Helper macros for logging + +#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::trace, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::debug, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::info, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::warn, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_ERROR(fmtstr, ...) \ + ZEN_LOG_WITH_LOCATION(Log(), spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), spdlog::level::err, fmtstr##sv, ##__VA_ARGS__) + +#define ZEN_CRITICAL(fmtstr, ...) \ + ZEN_LOG_WITH_LOCATION(Log(), \ + spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), \ + spdlog::level::critical, \ + fmtstr##sv, \ + ##__VA_ARGS__) + +#define ZEN_CONSOLE(fmtstr, ...) \ + do \ + { \ + using namespace std::literals; \ + ConsoleLog().info(fmtstr##sv, ##__VA_ARGS__); \ + } while (false) diff --git a/src/zencore/include/zencore/md5.h b/src/zencore/include/zencore/md5.h new file mode 100644 index 000000000..d934dd86b --- /dev/null +++ b/src/zencore/include/zencore/md5.h @@ -0,0 +1,50 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <stdint.h> +#include <compare> +#include "zencore.h" + +namespace zen { + +class StringBuilderBase; + +struct MD5 +{ + uint8_t Hash[16]; + + inline auto operator<=>(const MD5& rhs) const = default; + + static const int StringLength = 32; + typedef char String_t[StringLength + 1]; + + static MD5 HashMemory(const void* data, size_t byteCount); + static MD5 FromHexString(const char* string); + const char* ToHexString(char* outString /* 32 characters + NUL terminator */) const; + StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const; + + static MD5 Zero; // Initialized to all zeroes +}; + +/** + * Utility class for computing MD5 hashes + */ +class MD5Stream +{ +public: + MD5Stream(); + + /// Begin streaming MD5 compute (not needed on freshly constructed MD5Stream instance) + void Reset(); + /// Append another chunk + MD5Stream& Append(const void* data, size_t byteCount); + /// Obtain final MD5 hash. If you wish to reuse the MD5Stream instance call reset() + MD5 GetHash(); + +private: +}; + +void md5_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/memory.h b/src/zencore/include/zencore/memory.h new file mode 100644 index 000000000..560fa9ffc --- /dev/null +++ b/src/zencore/include/zencore/memory.h @@ -0,0 +1,401 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <zencore/intmath.h> +#include <zencore/thread.h> + +#include <cstddef> +#include <cstring> +#include <span> +#include <vector> + +namespace zen { + +#if defined(__cpp_lib_ranges) && __cpp_lib_ranges >= 201911L +template<typename T> +concept ContiguousRange = std::ranges::contiguous_range<T>; +#else +template<typename T> +concept ContiguousRange = true; +#endif + +struct MemoryView; + +class MemoryArena +{ +public: + ZENCORE_API MemoryArena(); + ZENCORE_API ~MemoryArena(); + + ZENCORE_API void* Alloc(size_t Size, size_t Alignment); + ZENCORE_API void Free(void* Ptr); + +private: +}; + +class Memory +{ +public: + ZENCORE_API static void* Alloc(size_t Size, size_t Alignment = sizeof(void*)); + ZENCORE_API static void Free(void* Ptr); +}; + +/** Allocator which claims fixed-size blocks from the underlying allocator. + + There is no way to free individual memory blocks. + + \note This is not thread-safe, you will need to provide synchronization yourself +*/ + +class ChunkingLinearAllocator +{ +public: + ChunkingLinearAllocator(uint64_t ChunkSize, uint64_t ChunkAlignment = sizeof(std::max_align_t)); + ~ChunkingLinearAllocator(); + + ZENCORE_API void Reset(); + + ZENCORE_API void* Alloc(size_t Size, size_t Alignment = sizeof(void*)); + inline void Free(void* Ptr) { ZEN_UNUSED(Ptr); /* no-op */ } + + ChunkingLinearAllocator(const ChunkingLinearAllocator&) = delete; + ChunkingLinearAllocator& operator=(const ChunkingLinearAllocator&) = delete; + +private: + uint8_t* m_ChunkCursor = nullptr; + uint64_t m_ChunkBytesRemain = 0; + const uint64_t m_ChunkSize = 0; + const uint64_t m_ChunkAlignment = 0; + std::vector<void*> m_ChunkList; +}; + +////////////////////////////////////////////////////////////////////////// + +struct MutableMemoryView +{ + MutableMemoryView() = default; + + MutableMemoryView(void* DataPtr, size_t DataSize) + : m_Data(reinterpret_cast<uint8_t*>(DataPtr)) + , m_DataEnd(reinterpret_cast<uint8_t*>(DataPtr) + DataSize) + { + } + + MutableMemoryView(void* DataPtr, void* DataEndPtr) + : m_Data(reinterpret_cast<uint8_t*>(DataPtr)) + , m_DataEnd(reinterpret_cast<uint8_t*>(DataEndPtr)) + { + } + + inline bool IsEmpty() const { return m_Data == m_DataEnd; } + void* GetData() const { return m_Data; } + void* GetDataEnd() const { return m_DataEnd; } + size_t GetSize() const { return reinterpret_cast<uint8_t*>(m_DataEnd) - reinterpret_cast<uint8_t*>(m_Data); } + + inline bool EqualBytes(const MutableMemoryView& InView) const + { + const size_t Size = GetSize(); + + return Size == InView.GetSize() && (memcmp(m_Data, InView.m_Data, Size) == 0); + } + + /** Modifies the view to be the given number of bytes from the right. */ + inline void RightInline(uint64_t InSize) + { + const uint64_t OldSize = GetSize(); + const uint64_t NewSize = zen::Min(OldSize, InSize); + m_Data = GetDataAtOffsetNoCheck(OldSize - NewSize); + m_DataEnd = m_Data + NewSize; + } + + /** Returns the right-most part of the view by taking the given number of bytes from the right. */ + [[nodiscard]] inline MutableMemoryView Right(uint64_t InSize) const + { + MutableMemoryView View(*this); + View.RightChopInline(InSize); + return View; + } + + /** Modifies the view by chopping the given number of bytes from the left. */ + inline void RightChopInline(uint64_t InSize) + { + const uint64_t Offset = zen::Min(GetSize(), InSize); + m_Data = GetDataAtOffsetNoCheck(Offset); + } + + /** Returns the left-most part of the view by taking the given number of bytes from the left. */ + constexpr inline MutableMemoryView Left(uint64_t InSize) const + { + MutableMemoryView View(*this); + View.LeftInline(InSize); + return View; + } + + /** Modifies the view to be the given number of bytes from the left. */ + constexpr inline void LeftInline(uint64_t InSize) { m_DataEnd = zen::Min(m_DataEnd, m_Data + InSize); } + + /** Modifies the view to be the middle part by taking up to the given number of bytes from the given offset. */ + inline void MidInline(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) + { + RightChopInline(InOffset); + LeftInline(InSize); + } + + /** Returns the middle part of the view by taking up to the given number of bytes from the given position. */ + [[nodiscard]] inline MutableMemoryView Mid(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) const + { + MutableMemoryView View(*this); + View.MidInline(InOffset, InSize); + return View; + } + + /** Returns the right-most part of the view by chopping the given number of bytes from the left. */ + [[nodiscard]] inline MutableMemoryView RightChop(uint64_t InSize) const + { + MutableMemoryView View(*this); + View.RightChopInline(InSize); + return View; + } + + inline MutableMemoryView& operator+=(size_t InSize) + { + RightChopInline(InSize); + return *this; + } + + /** Copies bytes from the input view into this view, and returns the remainder of this view. */ + inline MutableMemoryView CopyFrom(MemoryView InView) const; + +private: + uint8_t* m_Data = nullptr; + uint8_t* m_DataEnd = nullptr; + + /** Returns the data pointer advanced by an offset in bytes. */ + inline constexpr uint8_t* GetDataAtOffsetNoCheck(uint64_t InOffset) const { return m_Data + InOffset; } +}; + +////////////////////////////////////////////////////////////////////////// + +struct MemoryView +{ + MemoryView() = default; + + MemoryView(const MutableMemoryView& MutableView) + : m_Data(reinterpret_cast<const uint8_t*>(MutableView.GetData())) + , m_DataEnd(m_Data + MutableView.GetSize()) + { + } + + MemoryView(const void* DataPtr, size_t DataSize) + : m_Data(reinterpret_cast<const uint8_t*>(DataPtr)) + , m_DataEnd(reinterpret_cast<const uint8_t*>(DataPtr) + DataSize) + { + } + + MemoryView(const void* DataPtr, const void* DataEndPtr) + : m_Data(reinterpret_cast<const uint8_t*>(DataPtr)) + , m_DataEnd(reinterpret_cast<const uint8_t*>(DataEndPtr)) + { + } + + inline bool Contains(const MemoryView& Other) const { return (m_Data <= Other.m_Data) && (m_DataEnd >= Other.m_DataEnd); } + inline bool IsEmpty() const { return m_Data == m_DataEnd; } + const void* GetData() const { return m_Data; } + const void* GetDataEnd() const { return m_DataEnd; } + size_t GetSize() const { return reinterpret_cast<const uint8_t*>(m_DataEnd) - reinterpret_cast<const uint8_t*>(m_Data); } + inline bool operator==(const MemoryView& Rhs) const { return m_Data == Rhs.m_Data && m_DataEnd == Rhs.m_DataEnd; } + + inline bool EqualBytes(const MemoryView& InView) const + { + const size_t Size = GetSize(); + + return Size == InView.GetSize() && (memcmp(m_Data, InView.GetData(), Size) == 0); + } + + inline MemoryView& operator+=(size_t InSize) + { + RightChopInline(InSize); + return *this; + } + + /** Modifies the view by chopping the given number of bytes from the left. */ + inline void RightChopInline(uint64_t InSize) + { + const uint64_t Offset = zen::Min(GetSize(), InSize); + m_Data = GetDataAtOffsetNoCheck(Offset); + } + + inline MemoryView RightChop(uint64_t InSize) + { + MemoryView View(*this); + View.RightChopInline(InSize); + return View; + } + + /** Returns the right-most part of the view by taking the given number of bytes from the right. */ + [[nodiscard]] inline MemoryView Right(uint64_t InSize) const + { + MemoryView View(*this); + View.RightInline(InSize); + return View; + } + + /** Modifies the view to be the given number of bytes from the right. */ + inline void RightInline(uint64_t InSize) + { + const uint64_t OldSize = GetSize(); + const uint64_t NewSize = zen::Min(OldSize, InSize); + m_Data = GetDataAtOffsetNoCheck(OldSize - NewSize); + m_DataEnd = m_Data + NewSize; + } + + /** Returns the left-most part of the view by taking the given number of bytes from the left. */ + inline MemoryView Left(uint64_t InSize) const + { + MemoryView View(*this); + View.LeftInline(InSize); + return View; + } + + /** Modifies the view to be the given number of bytes from the left. */ + inline void LeftInline(uint64_t InSize) + { + InSize = zen::Min(GetSize(), InSize); + m_DataEnd = zen::Min(m_DataEnd, m_Data + InSize); + } + + /** Modifies the view to be the middle part by taking up to the given number of bytes from the given offset. */ + inline void MidInline(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) + { + RightChopInline(InOffset); + LeftInline(InSize); + } + + /** Returns the middle part of the view by taking up to the given number of bytes from the given position. */ + [[nodiscard]] inline MemoryView Mid(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) const + { + MemoryView View(*this); + View.MidInline(InOffset, InSize); + return View; + } + + constexpr void Reset() + { + m_Data = nullptr; + m_DataEnd = nullptr; + } + +private: + const uint8_t* m_Data = nullptr; + const uint8_t* m_DataEnd = nullptr; + + /** Returns the data pointer advanced by an offset in bytes. */ + inline constexpr const uint8_t* GetDataAtOffsetNoCheck(uint64_t InOffset) const { return m_Data + InOffset; } +}; + +inline MutableMemoryView +MutableMemoryView::CopyFrom(MemoryView InView) const +{ + ZEN_ASSERT(InView.GetSize() <= GetSize()); + memcpy(m_Data, InView.GetData(), InView.GetSize()); + return RightChop(InView.GetSize()); +} + +/** Advances the start of the view by an offset, which is clamped to stay within the view. */ +inline MemoryView +operator+(const MemoryView& View, uint64_t Offset) +{ + return MemoryView(View) += Offset; +} + +/** Advances the start of the view by an offset, which is clamped to stay within the view. */ +inline MemoryView +operator+(uint64_t Offset, const MemoryView& View) +{ + return MemoryView(View) += Offset; +} + +/** Advances the start of the view by an offset, which is clamped to stay within the view. */ +inline MutableMemoryView +operator+(const MutableMemoryView& View, uint64_t Offset) +{ + return MutableMemoryView(View) += Offset; +} + +/** Advances the start of the view by an offset, which is clamped to stay within the view. */ +inline MutableMemoryView +operator+(uint64_t Offset, const MutableMemoryView& View) +{ + return MutableMemoryView(View) += Offset; +} + +/** + * Make a non-owning view of the memory of the initializer list. + * + * This overload is only available when the element type does not need to be deduced. + */ +template<typename T> +[[nodiscard]] inline MemoryView +MakeMemoryView(std::initializer_list<typename std::type_identity<T>::type> List) +{ + return MemoryView(List.begin(), List.size() * sizeof(T)); +} + +/** Make a non-owning view of the memory of the contiguous container. */ +template<ContiguousRange R> +[[nodiscard]] constexpr inline MemoryView +MakeMemoryView(const R& Container) +{ + std::span Span = Container; + return MemoryView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type)); +} + +/** Make a non-owning const view starting at Data and ending at DataEnd. */ + +[[nodiscard]] inline MemoryView +MakeMemoryView(const void* Data, const void* DataEnd) +{ + return MemoryView(Data, DataEnd); +} + +[[nodiscard]] inline MemoryView +MakeMemoryView(const void* Data, uint64_t Size) +{ + return MemoryView(Data, reinterpret_cast<const uint8_t*>(Data) + Size); +} + +/** + * Make a non-owning mutable view of the memory of the initializer list. + * + * This overload is only available when the element type does not need to be deduced. + */ +template<typename T> +[[nodiscard]] inline MutableMemoryView +MakeMutableMemoryView(std::initializer_list<typename std::type_identity<T>::type> List) +{ + return MutableMemoryView(List.begin(), List.size() * sizeof(T)); +} + +/** Make a non-owning mutable view of the memory of the contiguous container. */ +template<ContiguousRange R> +[[nodiscard]] constexpr inline MutableMemoryView +MakeMutableMemoryView(R& Container) +{ + std::span Span = Container; + return MutableMemoryView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type)); +} + +/** Make a non-owning mutable view starting at Data and ending at DataEnd. */ + +[[nodiscard]] inline MutableMemoryView +MakeMutableMemoryView(void* Data, void* DataEnd) +{ + return MutableMemoryView(Data, DataEnd); +} + +void memory_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/meta.h b/src/zencore/include/zencore/meta.h new file mode 100644 index 000000000..82eb5cc30 --- /dev/null +++ b/src/zencore/include/zencore/meta.h @@ -0,0 +1,30 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +/* This file contains utility functions for meta programming + * + * Since you're in here you're probably quite observant, and you'll + * note that it's quite barren here. This is because template + * metaprogramming is awful and I try not to engage in it. However, + * sometimes these things are forced upon us. + * + */ + +namespace zen { + +/** + * Uses implicit conversion to create an instance of a specific type. + * Useful to make things clearer or circumvent unintended type deduction in templates. + * Safer than C casts and static_casts, e.g. does not allow down-casts + * + * @param Obj The object (usually pointer or reference) to convert. + * + * @return The object converted to the specified type. + */ +template<typename T> +inline T +ImplicitConv(typename std::type_identity<T>::type Obj) +{ + return Obj; +} + +} // namespace zen diff --git a/src/zencore/include/zencore/mpscqueue.h b/src/zencore/include/zencore/mpscqueue.h new file mode 100644 index 000000000..19e410d85 --- /dev/null +++ b/src/zencore/include/zencore/mpscqueue.h @@ -0,0 +1,110 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <atomic> +#include <memory> +#include <new> +#include <optional> + +#ifdef __cpp_lib_hardware_interference_size +using std::hardware_constructive_interference_size; +using std::hardware_destructive_interference_size; +#else +// 64 bytes on x86-64 │ L1_CACHE_BYTES │ L1_CACHE_SHIFT │ __cacheline_aligned │ ... +constexpr std::size_t hardware_constructive_interference_size = 64; +constexpr std::size_t hardware_destructive_interference_size = 64; +#endif + +namespace zen { + +/** An untyped array of data with compile-time alignment and size derived from another type. */ +template<typename ElementType> +struct TypeCompatibleStorage +{ + ElementType* Data() { return (ElementType*)this; } + const ElementType* Data() const { return (const ElementType*)this; } + + alignas(ElementType) char DataMember; +}; + +/** Fast multi-producer/single-consumer unbounded concurrent queue. + + Based on http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue + */ + +template<typename T> +class MpscQueue final +{ +public: + using ElementType = T; + + MpscQueue() + { + Node* Sentinel = new Node; + Head.store(Sentinel, std::memory_order_relaxed); + Tail = Sentinel; + } + + ~MpscQueue() + { + Node* Next = Tail->Next.load(std::memory_order_relaxed); + + // sentinel's value is already destroyed + delete Tail; + + while (Next != nullptr) + { + Tail = Next; + Next = Tail->Next.load(std::memory_order_relaxed); + + std::destroy_at((ElementType*)&Tail->Value); + delete Tail; + } + } + + template<typename... ArgTypes> + void Enqueue(ArgTypes&&... Args) + { + Node* New = new Node; + new (&New->Value) ElementType(std::forward<ArgTypes>(Args)...); + + Node* Prev = Head.exchange(New, std::memory_order_acq_rel); + Prev->Next.store(New, std::memory_order_release); + } + + std::optional<ElementType> Dequeue() + { + Node* Next = Tail->Next.load(std::memory_order_acquire); + + if (Next == nullptr) + { + return {}; + } + + ElementType* ValuePtr = (ElementType*)&Next->Value; + std::optional<ElementType> Res{std::move(*ValuePtr)}; + std::destroy_at(ValuePtr); + + delete Tail; // current sentinel + + Tail = Next; // new sentinel + return Res; + } + +private: + struct Node + { + std::atomic<Node*> Next{nullptr}; + TypeCompatibleStorage<ElementType> Value; + }; + +private: + std::atomic<Node*> Head; // accessed only by producers + alignas(hardware_constructive_interference_size) + Node* Tail; // accessed only by consumer, hence should be on a different cache line than `Head` +}; + +void mpscqueue_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/refcount.h b/src/zencore/include/zencore/refcount.h new file mode 100644 index 000000000..f0bb6b85e --- /dev/null +++ b/src/zencore/include/zencore/refcount.h @@ -0,0 +1,186 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +#pragma once + +#include "atomic.h" +#include "zencore.h" + +#include <compare> + +namespace zen { + +/** + * Helper base class for reference counted objects using intrusive reference counts + * + * This class is pretty straightforward but does one thing which may be unexpected: + * + * - Instances on the stack are initialized with a reference count of one to ensure + * nobody tries to accidentally delete it. (TODO: is this really useful?) + */ +class RefCounted +{ +public: + RefCounted() = default; + virtual ~RefCounted() = default; + + inline uint32_t AddRef() const { return AtomicIncrement(const_cast<RefCounted*>(this)->m_RefCount); } + inline uint32_t Release() const + { + uint32_t refCount = AtomicDecrement(const_cast<RefCounted*>(this)->m_RefCount); + if (refCount == 0) + { + delete this; + } + return refCount; + } + + // Copying reference counted objects doesn't make a lot of sense generally, so let's prevent it + + RefCounted(const RefCounted&) = delete; + RefCounted(RefCounted&&) = delete; + RefCounted& operator=(const RefCounted&) = delete; + RefCounted& operator=(RefCounted&&) = delete; + +protected: + inline uint32_t RefCount() const { return m_RefCount; } + +private: + uint32_t m_RefCount = 0; +}; + +/** + * Smart pointer for classes derived from RefCounted + */ + +template<class T> +class RefPtr +{ +public: + inline RefPtr() = default; + inline RefPtr(const RefPtr& Rhs) : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); } + inline RefPtr(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); } + inline ~RefPtr() { m_Ref && m_Ref->Release(); } + + [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; } + inline explicit operator bool() const { return m_Ref != nullptr; } + inline operator T*() const { return m_Ref; } + inline T* operator->() const { return m_Ref; } + + inline std::strong_ordering operator<=>(const RefPtr& Rhs) const = default; + + inline RefPtr& operator=(T* Rhs) + { + Rhs && Rhs->AddRef(); + m_Ref && m_Ref->Release(); + m_Ref = Rhs; + return *this; + } + inline RefPtr& operator=(const RefPtr& Rhs) + { + if (&Rhs != this) + { + Rhs && Rhs->AddRef(); + m_Ref && m_Ref->Release(); + m_Ref = Rhs.m_Ref; + } + return *this; + } + inline RefPtr& operator=(RefPtr&& Rhs) noexcept + { + if (&Rhs != this) + { + m_Ref && m_Ref->Release(); + m_Ref = Rhs.m_Ref; + Rhs.m_Ref = nullptr; + } + return *this; + } + template<typename OtherType> + inline RefPtr& operator=(RefPtr<OtherType>&& Rhs) noexcept + { + if ((RefPtr*)&Rhs != this) + { + m_Ref && m_Ref->Release(); + m_Ref = Rhs.m_Ref; + Rhs.m_Ref = nullptr; + } + return *this; + } + inline RefPtr(RefPtr&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; } + template<typename OtherType> + explicit inline RefPtr(RefPtr<OtherType>&& Rhs) noexcept : m_Ref(Rhs.m_Ref) + { + Rhs.m_Ref = nullptr; + } + +private: + T* m_Ref = nullptr; + template<typename U> + friend class RefPtr; +}; + +/** + * Smart pointer for classes derived from RefCounted + * + * This variant does not decay to a raw pointer + * + */ + +template<class T> +class Ref +{ +public: + inline Ref() = default; + inline Ref(const Ref& Rhs) : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); } + inline explicit Ref(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); } + inline ~Ref() { m_Ref && m_Ref->Release(); } + + template<typename DerivedType> + requires DerivedFrom<DerivedType, T> + inline Ref(const Ref<DerivedType>& Rhs) : Ref(Rhs.m_Ref) {} + + [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; } + inline explicit operator bool() const { return m_Ref != nullptr; } + inline T* operator->() const { return m_Ref; } + inline T* Get() const { return m_Ref; } + + inline std::strong_ordering operator<=>(const Ref& Rhs) const = default; + + inline Ref& operator=(T* Rhs) + { + Rhs && Rhs->AddRef(); + m_Ref && m_Ref->Release(); + m_Ref = Rhs; + return *this; + } + inline Ref& operator=(const Ref& Rhs) + { + if (&Rhs != this) + { + Rhs && Rhs->AddRef(); + m_Ref && m_Ref->Release(); + m_Ref = Rhs.m_Ref; + } + return *this; + } + inline Ref& operator=(Ref&& Rhs) noexcept + { + if (&Rhs != this) + { + m_Ref && m_Ref->Release(); + m_Ref = Rhs.m_Ref; + Rhs.m_Ref = nullptr; + } + return *this; + } + inline Ref(Ref&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; } + +private: + T* m_Ref = nullptr; + + template<class U> + friend class Ref; +}; + +void refcount_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/scopeguard.h b/src/zencore/include/zencore/scopeguard.h new file mode 100644 index 000000000..d04c8ed9c --- /dev/null +++ b/src/zencore/include/zencore/scopeguard.h @@ -0,0 +1,45 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <type_traits> +#include "logging.h" +#include "zencore.h" + +namespace zen { + +template<typename T> +class [[nodiscard]] ScopeGuardImpl +{ +public: + inline ScopeGuardImpl(T&& func) : m_guardFunc(func) {} + ~ScopeGuardImpl() + { + if (!m_dismissed) + { + try + { + m_guardFunc(); + } + catch (std::exception& Ex) + { + ZEN_ERROR("scope guard threw exception: '{}'", Ex.what()); + } + } + } + + void Dismiss() { m_dismissed = true; } + +private: + bool m_dismissed = false; + T m_guardFunc; +}; + +template<typename T> +ScopeGuardImpl<T> +MakeGuard(T&& fn) +{ + return ScopeGuardImpl<T>(std::move(fn)); +} + +} // namespace zen diff --git a/src/zencore/include/zencore/session.h b/src/zencore/include/zencore/session.h new file mode 100644 index 000000000..dd90197bf --- /dev/null +++ b/src/zencore/include/zencore/session.h @@ -0,0 +1,14 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +namespace zen { + +struct Oid; + +ZENCORE_API [[nodiscard]] Oid GetSessionId(); +ZENCORE_API [[nodiscard]] std::string_view GetSessionIdString(); + +} // namespace zen diff --git a/src/zencore/include/zencore/sha1.h b/src/zencore/include/zencore/sha1.h new file mode 100644 index 000000000..fc26f442b --- /dev/null +++ b/src/zencore/include/zencore/sha1.h @@ -0,0 +1,76 @@ +// ////////////////////////////////////////////////////////// +// sha1.h +// Copyright (c) 2014,2015 Stephan Brumme. All rights reserved. +// see http://create.stephan-brumme.com/disclaimer.html +// + +#pragma once + +#include <stdint.h> +#include <compare> +#include "zencore.h" + +namespace zen { + +class StringBuilderBase; + +struct SHA1 +{ + uint8_t Hash[20]; + + inline auto operator<=>(const SHA1& rhs) const = default; + + static const int StringLength = 40; + typedef char String_t[StringLength + 1]; + + static SHA1 HashMemory(const void* data, size_t byteCount); + static SHA1 FromHexString(const char* string); + const char* ToHexString(char* outString /* 40 characters + NUL terminator */) const; + StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const; + + static SHA1 Zero; // Initialized to all zeroes +}; + +/** + * Utility class for computing SHA1 hashes + */ +class SHA1Stream +{ +public: + SHA1Stream(); + + /** compute SHA1 of a memory block + + \note SHA1 class contains a slightly more convenient helper function for this use case + \see SHA1::fromMemory() + */ + SHA1 Compute(const void* data, size_t byteCount); + + /// Begin streaming SHA1 compute (not needed on freshly constructed SHA1Stream instance) + void Reset(); + /// Append another chunk + SHA1Stream& Append(const void* data, size_t byteCount); + /// Obtain final SHA1 hash. If you wish to reuse the SHA1Stream instance call reset() + SHA1 GetHash(); + +private: + void ProcessBlock(const void* data); + void ProcessBuffer(); + + enum + { + /// split into 64 byte blocks (=> 512 bits) + BlockSize = 512 / 8, + HashBytes = 20, + HashValues = HashBytes / 4 + }; + + uint64_t m_NumBytes; // size of processed data in bytes + size_t m_BufferSize; // valid bytes in m_buffer + uint8_t m_Buffer[BlockSize]; // bytes not processed yet + uint32_t m_Hash[HashValues]; +}; + +void sha1_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h new file mode 100644 index 000000000..97c5a9d21 --- /dev/null +++ b/src/zencore/include/zencore/sharedbuffer.h @@ -0,0 +1,167 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <zencore/iobuffer.h> +#include <zencore/memory.h> +#include <zencore/refcount.h> + +#include <memory.h> + +namespace zen { + +class SharedBuffer; + +/** + * Reference to a memory buffer with a single owner + * + * Internally + */ +class UniqueBuffer +{ +public: + UniqueBuffer() = default; + UniqueBuffer(UniqueBuffer&&) = default; + UniqueBuffer& operator=(UniqueBuffer&&) = default; + UniqueBuffer(const UniqueBuffer&) = delete; + UniqueBuffer& operator=(const UniqueBuffer&) = delete; + + ZENCORE_API explicit UniqueBuffer(IoBufferCore* Owner); + + [[nodiscard]] void* GetData() { return m_Buffer ? m_Buffer->MutableDataPointer() : nullptr; } + [[nodiscard]] const void* GetData() const { return m_Buffer ? m_Buffer->DataPointer() : nullptr; } + [[nodiscard]] size_t GetSize() const { return m_Buffer ? m_Buffer->DataBytes() : 0; } + + operator MutableMemoryView() { return GetMutableView(); } + operator MemoryView() const { return GetView(); } + + /** + * Returns true if this does not point to a buffer owner. + * + * A null buffer is always owned, materialized, and empty. + */ + [[nodiscard]] inline bool IsNull() const { return m_Buffer.IsNull(); } + + /** Reset this to null. */ + ZENCORE_API void Reset(); + + [[nodiscard]] inline MutableMemoryView GetMutableView() { return MutableMemoryView(GetData(), GetSize()); } + [[nodiscard]] inline MemoryView GetView() const { return MemoryView(GetData(), GetSize()); } + + /** Make an uninitialized owned buffer of the specified size. */ + [[nodiscard]] ZENCORE_API static UniqueBuffer Alloc(uint64_t Size); + + /** Make a non-owned view of the input. */ + [[nodiscard]] ZENCORE_API static UniqueBuffer MakeMutableView(void* DataPtr, uint64_t Size); + + /** + * Convert this to an immutable shared buffer, leaving this null. + * + * Steals the buffer owner from the unique buffer. + */ + [[nodiscard]] ZENCORE_API SharedBuffer MoveToShared(); + +private: + // This may be null, for a default constructed UniqueBuffer only + RefPtr<IoBufferCore> m_Buffer; + + friend class SharedBuffer; +}; + +/** + * Reference to a memory buffer with shared ownership + */ +class SharedBuffer +{ +public: + SharedBuffer() = default; + ZENCORE_API explicit SharedBuffer(UniqueBuffer&&); + inline explicit SharedBuffer(IoBufferCore* Owner) : m_Buffer(Owner) {} + ZENCORE_API explicit SharedBuffer(IoBuffer&& Buffer) : m_Buffer(std::move(Buffer.m_Core)) {} + ZENCORE_API explicit SharedBuffer(const IoBuffer& Buffer) : m_Buffer(Buffer.m_Core) {} + ZENCORE_API explicit SharedBuffer(RefPtr<IoBufferCore>&& Owner) : m_Buffer(std::move(Owner)) {} + + [[nodiscard]] const void* GetData() const + { + if (m_Buffer) + { + return m_Buffer->DataPointer(); + } + return nullptr; + } + + [[nodiscard]] size_t GetSize() const + { + if (m_Buffer) + { + return m_Buffer->DataBytes(); + } + return 0; + } + + inline void MakeImmutable() + { + ZEN_ASSERT(m_Buffer); + m_Buffer->SetIsImmutable(true); + } + + /** Returns a buffer that is owned, by cloning if not owned. */ + [[nodiscard]] ZENCORE_API SharedBuffer MakeOwned() const&; + [[nodiscard]] ZENCORE_API SharedBuffer MakeOwned() &&; + + [[nodiscard]] bool IsOwned() const { return !m_Buffer || m_Buffer->IsOwned(); } + [[nodiscard]] inline bool IsNull() const { return !m_Buffer; } + inline void Reset() { m_Buffer = nullptr; } + + [[nodiscard]] MemoryView GetView() const + { + if (m_Buffer) + { + return MemoryView(m_Buffer->DataPointer(), m_Buffer->DataBytes()); + } + else + { + return MemoryView(); + } + } + operator MemoryView() const { return GetView(); } + + /** Returns true if this points to a buffer owner. */ + [[nodiscard]] inline explicit operator bool() const { return !IsNull(); } + + [[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer); } + + SharedBuffer& operator=(UniqueBuffer&& Rhs) + { + m_Buffer = std::move(Rhs.m_Buffer); + return *this; + } + + std::strong_ordering operator<=>(const SharedBuffer& Rhs) const = default; + + /** Make a non-owned view of the input */ + [[nodiscard]] inline static SharedBuffer MakeView(MemoryView View) { return MakeView(View.GetData(), View.GetSize()); } + /** Make a non-owning view of the memory of the contiguous container. */ + [[nodiscard]] inline static SharedBuffer MakeView(const ContiguousRange auto& Container) + { + std::span Span = Container; + return MakeView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type)); + } + /** Make a non-owned view of the input */ + [[nodiscard]] ZENCORE_API static SharedBuffer MakeView(const void* Data, uint64_t Size); + /** Make a non-owned view of the input */ + [[nodiscard]] ZENCORE_API static SharedBuffer MakeView(MemoryView View, SharedBuffer OuterBuffer); + /** Make an owned clone of the buffer */ + [[nodiscard]] ZENCORE_API SharedBuffer Clone(); + /** Make an owned clone of the memory in the input view */ + [[nodiscard]] ZENCORE_API static SharedBuffer Clone(MemoryView View); + +private: + RefPtr<IoBufferCore> m_Buffer; +}; + +void sharedbuffer_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/stats.h b/src/zencore/include/zencore/stats.h new file mode 100644 index 000000000..1a0817b99 --- /dev/null +++ b/src/zencore/include/zencore/stats.h @@ -0,0 +1,295 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <atomic> +#include <random> + +namespace zen { +class CbObjectWriter; +} + +namespace zen::metrics { + +template<typename T> +class Gauge +{ +public: + Gauge() : m_Value{0} {} + + T Value() const { return m_Value; } + void SetValue(T Value) { m_Value = Value; } + +private: + std::atomic<T> m_Value; +}; + +/** Stats counter + * + * A counter is modified by adding or subtracting a value from a current value. + * This would typically be used to track number of requests in flight, number + * of active jobs etc + * + */ +class Counter +{ +public: + inline void SetValue(uint64_t Value) { m_count = Value; } + inline uint64_t Value() const { return m_count; } + + inline void Increment(int64_t AddValue) { m_count.fetch_add(AddValue); } + inline void Decrement(int64_t SubValue) { m_count.fetch_sub(SubValue); } + inline void Clear() { m_count.store(0, std::memory_order_release); } + +private: + std::atomic<uint64_t> m_count{0}; +}; + +/** Exponential Weighted Moving Average + + This is very raw, to use as little state as possible. If we + want to use this more broadly in user code we should perhaps + add a more user-friendly wrapper + */ + +class RawEWMA +{ +public: + /// <summary> + /// Update EWMA with new measure + /// </summary> + /// <param name="Alpha">Smoothing factor (between 0 and 1)</param> + /// <param name="Interval">Elapsed time since last</param> + /// <param name="Count">Value</param> + /// <param name="IsInitialUpdate">Whether this is the first update or not</param> + void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate); + double Rate() const; + +private: + std::atomic<double> m_Rate = 0; +}; + +/// <summary> +/// Tracks rate of events over time (i.e requests/sec), using +/// exponential moving averages +/// </summary> +class Meter +{ +public: + Meter(); + ~Meter(); + + inline uint64_t Count() const { return m_TotalCount; } + double Rate1(); // One-minute rate + double Rate5(); // Five-minute rate + double Rate15(); // Fifteen-minute rate + double MeanRate() const; // Mean rate since instantiation of this meter + void Mark(uint64_t Count = 1); // Register one or more events + +private: + std::atomic<uint64_t> m_TotalCount{0}; // Accumulator counting number of marks since beginning + std::atomic<uint64_t> m_PendingCount{0}; // Pending EWMA update accumulator + std::atomic<uint64_t> m_StartTick{0}; // Time this was instantiated (for mean) + std::atomic<uint64_t> m_LastTick{0}; // Timestamp of last EWMA tick + std::atomic<int64_t> m_Remainder{0}; // Tracks the "modulo" of tick time + bool m_IsFirstTick = true; + RawEWMA m_RateM1; + RawEWMA m_RateM5; + RawEWMA m_RateM15; + + void TickIfNecessary(); + void Tick(); +}; + +/** Moment-in-time snapshot of a distribution + */ +class SampleSnapshot +{ +public: + SampleSnapshot(std::vector<double>&& Values); + ~SampleSnapshot(); + + uint32_t Size() const { return (uint32_t)m_Values.size(); } + double GetQuantileValue(double Quantile); + double GetMedian() { return GetQuantileValue(0.5); } + double Get75Percentile() { return GetQuantileValue(0.75); } + double Get95Percentile() { return GetQuantileValue(0.95); } + double Get98Percentile() { return GetQuantileValue(0.98); } + double Get99Percentile() { return GetQuantileValue(0.99); } + double Get999Percentile() { return GetQuantileValue(0.999); } + const std::vector<double>& GetValues() const; + +private: + std::vector<double> m_Values; +}; + +/** Randomly selects samples from a stream. Uses Vitter's + Algorithm R to produce a statistically representative sample. + + http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir + */ + +class UniformSample +{ +public: + UniformSample(uint32_t ReservoirSize); + ~UniformSample(); + + void Clear(); + uint32_t Size() const; + void Update(int64_t Value); + SampleSnapshot Snapshot() const; + + template<Invocable<int64_t> T> + void IterateValues(T Callback) const + { + for (const auto& Value : m_Values) + { + Callback(Value); + } + } + +private: + std::atomic<uint64_t> m_SampleCounter{0}; + std::vector<std::atomic<int64_t>> m_Values; +}; + +/** Track (probabilistic) sample distribution along with min/max + */ +class Histogram +{ +public: + Histogram(int32_t SampleCount = 1028); + ~Histogram(); + + void Clear(); + void Update(int64_t Value); + int64_t Max() const; + int64_t Min() const; + double Mean() const; + uint64_t Count() const; + SampleSnapshot Snapshot() const { return m_Sample.Snapshot(); } + +private: + UniformSample m_Sample; + std::atomic<int64_t> m_Min{0}; + std::atomic<int64_t> m_Max{0}; + std::atomic<int64_t> m_Sum{0}; + std::atomic<int64_t> m_Count{0}; +}; + +/** Track timing and frequency of some operation + + Example usage would be to track frequency and duration of network + requests, or function calls. + + */ +class OperationTiming +{ +public: + OperationTiming(int32_t SampleCount = 514); + ~OperationTiming(); + + void Update(int64_t Duration); + int64_t Max() const; + int64_t Min() const; + double Mean() const; + uint64_t Count() const; + SampleSnapshot Snapshot() const { return m_Histogram.Snapshot(); } + + double Rate1() { return m_Meter.Rate1(); } + double Rate5() { return m_Meter.Rate5(); } + double Rate15() { return m_Meter.Rate15(); } + double MeanRate() const { return m_Meter.MeanRate(); } + + struct Scope + { + Scope(OperationTiming& Outer); + ~Scope(); + + void Stop(); + void Cancel(); + + private: + OperationTiming& m_Outer; + uint64_t m_StartTick; + }; + +private: + Meter m_Meter; + Histogram m_Histogram; +}; + +/** Metrics for network requests + + Aggregates tracking of duration, payload sizes into a single + class + + */ +class RequestStats +{ +public: + RequestStats(int32_t SampleCount = 514); + ~RequestStats(); + + void Update(int64_t Duration, int64_t Bytes); + uint64_t Count() const; + + // Timing + + int64_t MaxDuration() const { return m_BytesHistogram.Max(); } + int64_t MinDuration() const { return m_BytesHistogram.Min(); } + double MeanDuration() const { return m_BytesHistogram.Mean(); } + SampleSnapshot DurationSnapshot() const { return m_RequestTimeHistogram.Snapshot(); } + double Rate1() { return m_RequestMeter.Rate1(); } + double Rate5() { return m_RequestMeter.Rate5(); } + double Rate15() { return m_RequestMeter.Rate15(); } + double MeanRate() const { return m_RequestMeter.MeanRate(); } + + // Bytes + + int64_t MaxBytes() const { return m_BytesHistogram.Max(); } + int64_t MinBytes() const { return m_BytesHistogram.Min(); } + double MeanBytes() const { return m_BytesHistogram.Mean(); } + SampleSnapshot BytesSnapshot() const { return m_BytesHistogram.Snapshot(); } + double ByteRate1() { return m_BytesMeter.Rate1(); } + double ByteRate5() { return m_BytesMeter.Rate5(); } + double ByteRate15() { return m_BytesMeter.Rate15(); } + double ByteMeanRate() const { return m_BytesMeter.MeanRate(); } + + struct Scope + { + Scope(OperationTiming& Outer); + ~Scope(); + + void Cancel(); + + private: + OperationTiming& m_Outer; + uint64_t m_StartTick; + }; + + void EmitSnapshot(std::string_view Tag, CbObjectWriter& Cbo); + +private: + Meter m_RequestMeter; + Meter m_BytesMeter; + Histogram m_RequestTimeHistogram; + Histogram m_BytesHistogram; +}; + +void EmitSnapshot(std::string_view Tag, OperationTiming& Stat, CbObjectWriter& Cbo); +void EmitSnapshot(std::string_view Tag, const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor); +void EmitSnapshot(std::string_view Tag, Meter& Stat, CbObjectWriter& Cbo); + +void EmitSnapshot(const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor); + +} // namespace zen::metrics + +namespace zen { + +extern void stats_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/stream.h b/src/zencore/include/zencore/stream.h new file mode 100644 index 000000000..9e4996249 --- /dev/null +++ b/src/zencore/include/zencore/stream.h @@ -0,0 +1,90 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <zencore/memory.h> +#include <zencore/thread.h> + +#include <vector> + +namespace zen { + +/** + * Binary stream writer + */ + +class BinaryWriter +{ +public: + inline BinaryWriter() = default; + ~BinaryWriter() = default; + + inline void Write(const void* DataPtr, size_t ByteCount) + { + Write(DataPtr, ByteCount, m_Offset); + m_Offset += ByteCount; + } + + inline void Write(MemoryView Memory) { Write(Memory.GetData(), Memory.GetSize()); } + void Write(std::initializer_list<const MemoryView> Buffers); + + inline uint64_t CurrentOffset() const { return m_Offset; } + + inline const uint8_t* Data() const { return m_Buffer.data(); } + inline const uint8_t* GetData() const { return m_Buffer.data(); } + inline uint64_t Size() const { return m_Buffer.size(); } + inline uint64_t GetSize() const { return m_Buffer.size(); } + void Reset(); + + inline MemoryView GetView() const { return MemoryView(m_Buffer.data(), m_Offset); } + inline MutableMemoryView GetMutableView() { return MutableMemoryView(m_Buffer.data(), m_Offset); } + +private: + std::vector<uint8_t> m_Buffer; + uint64_t m_Offset = 0; + + void Write(const void* DataPtr, size_t ByteCount, uint64_t Offset); +}; + +inline MemoryView +MakeMemoryView(const BinaryWriter& Stream) +{ + return MemoryView(Stream.Data(), Stream.Size()); +} + +/** + * Binary stream reader + */ + +class BinaryReader +{ +public: + inline BinaryReader(const void* Buffer, uint64_t Size) : m_BufferBase(reinterpret_cast<const uint8_t*>(Buffer)), m_BufferSize(Size) {} + inline BinaryReader(MemoryView Buffer) + : m_BufferBase(reinterpret_cast<const uint8_t*>(Buffer.GetData())) + , m_BufferSize(Buffer.GetSize()) + { + } + + inline void Read(void* DataPtr, size_t ByteCount) + { + memcpy(DataPtr, m_BufferBase + m_Offset, ByteCount); + m_Offset += ByteCount; + } + + inline uint64_t Size() const { return m_BufferSize; } + inline uint64_t GetSize() const { return Size(); } + inline uint64_t CurrentOffset() const { return m_Offset; } + inline void Skip(size_t ByteCount) { m_Offset += ByteCount; }; + +private: + const uint8_t* m_BufferBase; + uint64_t m_BufferSize; + uint64_t m_Offset = 0; +}; + +void stream_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h new file mode 100644 index 000000000..ab111ff81 --- /dev/null +++ b/src/zencore/include/zencore/string.h @@ -0,0 +1,1115 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "intmath.h" +#include "zencore.h" + +#include <stdint.h> +#include <string.h> +#include <charconv> +#include <codecvt> +#include <compare> +#include <concepts> +#include <optional> +#include <span> +#include <string_view> + +#include <type_traits> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +inline bool +StringEquals(const char8_t* s1, const char* s2) +{ + return strcmp(reinterpret_cast<const char*>(s1), s2) == 0; +} + +inline bool +StringEquals(const char* s1, const char* s2) +{ + return strcmp(s1, s2) == 0; +} + +inline size_t +StringLength(const char* str) +{ + return strlen(str); +} + +inline bool +StringEquals(const wchar_t* s1, const wchar_t* s2) +{ + return wcscmp(s1, s2) == 0; +} + +inline size_t +StringLength(const wchar_t* str) +{ + return wcslen(str); +} + +////////////////////////////////////////////////////////////////////////// +// File name helpers +// + +ZENCORE_API const char* FilepathFindExtension(const std::string_view& path, const char* extensionToMatch = nullptr); + +////////////////////////////////////////////////////////////////////////// +// Text formatting of numbers +// + +ZENCORE_API bool ToString(std::span<char> Buffer, uint64_t Num); +ZENCORE_API bool ToString(std::span<char> Buffer, int64_t Num); + +struct TextNumBase +{ + inline const char* c_str() const { return m_Buffer; } + inline operator std::string_view() const { return std::string_view(m_Buffer); } + +protected: + char m_Buffer[24]; +}; + +struct IntNum : public TextNumBase +{ + inline IntNum(UnsignedIntegral auto Number) { ToString(m_Buffer, uint64_t(Number)); } + inline IntNum(SignedIntegral auto Number) { ToString(m_Buffer, int64_t(Number)); } +}; + +////////////////////////////////////////////////////////////////////////// +// +// Quick-and-dirty string builder. Good enough for me, but contains traps +// and not-quite-ideal behaviour especially when mixing character types etc +// + +template<typename C> +class StringBuilderImpl +{ +public: + StringBuilderImpl() = default; + ZENCORE_API ~StringBuilderImpl(); + + StringBuilderImpl(const StringBuilderImpl&) = delete; + StringBuilderImpl(const StringBuilderImpl&&) = delete; + const StringBuilderImpl& operator=(const StringBuilderImpl&) = delete; + const StringBuilderImpl& operator=(const StringBuilderImpl&&) = delete; + + inline size_t AddUninitialized(size_t Count) + { + EnsureCapacity(Count); + const size_t OldCount = Size(); + m_CurPos += Count; + return OldCount; + } + + StringBuilderImpl& Append(C OneChar) + { + EnsureCapacity(1); + + *m_CurPos++ = OneChar; + + return *this; + } + + inline StringBuilderImpl& AppendAscii(const std::string_view& String) + { + const size_t len = String.size(); + + EnsureCapacity(len); + + for (size_t i = 0; i < len; ++i) + m_CurPos[i] = String[i]; + + m_CurPos += len; + + return *this; + } + + inline StringBuilderImpl& AppendAscii(const std::u8string_view& String) + { + const size_t len = String.size(); + + EnsureCapacity(len); + + for (size_t i = 0; i < len; ++i) + m_CurPos[i] = String[i]; + + m_CurPos += len; + + return *this; + } + + inline StringBuilderImpl& AppendAscii(const char* NulTerminatedString) + { + size_t StringLen = StringLength(NulTerminatedString); + + return AppendAscii({NulTerminatedString, StringLen}); + } + + inline StringBuilderImpl& Append(const char8_t* NulTerminatedString) + { + // This is super hacky and not fully functional - needs better + // solution + if constexpr (sizeof(C) == 1) + { + size_t len = StringLength((const char*)NulTerminatedString); + + EnsureCapacity(len); + + for (size_t i = 0; i < len; ++i) + m_CurPos[i] = C(NulTerminatedString[i]); + + m_CurPos += len; + } + else + { + ZEN_NOT_IMPLEMENTED(); + } + + return *this; + } + + inline StringBuilderImpl& AppendAsciiRange(const char* BeginString, const char* EndString) + { + EnsureCapacity(EndString - BeginString); + + while (BeginString != EndString) + *m_CurPos++ = *BeginString++; + + return *this; + } + + inline StringBuilderImpl& Append(const C* NulTerminatedString) + { + size_t Len = StringLength(NulTerminatedString); + + EnsureCapacity(Len); + memcpy(m_CurPos, NulTerminatedString, Len * sizeof(C)); + m_CurPos += Len; + + return *this; + } + + inline StringBuilderImpl& Append(const C* NulTerminatedString, size_t MaxChars) + { + size_t len = Min(MaxChars, StringLength(NulTerminatedString)); + + EnsureCapacity(len); + memcpy(m_CurPos, NulTerminatedString, len * sizeof(C)); + m_CurPos += len; + + return *this; + } + + inline StringBuilderImpl& AppendRange(const C* BeginString, const C* EndString) + { + size_t Len = EndString - BeginString; + + EnsureCapacity(Len); + memcpy(m_CurPos, BeginString, Len * sizeof(C)); + m_CurPos += Len; + + return *this; + } + + inline StringBuilderImpl& Append(const std::basic_string_view<C>& String) + { + return AppendRange(String.data(), String.data() + String.size()); + } + + inline StringBuilderImpl& AppendBool(bool v) + { + // This is a method instead of a << operator overload as the latter can + // easily get called with non-bool types like pointers. It is a very + // subtle behaviour that can cause bugs. + using namespace std::literals; + if (v) + { + return AppendAscii("true"sv); + } + return AppendAscii("false"sv); + } + + inline void RemoveSuffix(uint32_t Count) + { + ZEN_ASSERT(Count <= Size()); + m_CurPos -= Count; + } + + inline const C* c_str() const + { + EnsureNulTerminated(); + return m_Base; + } + + inline C* Data() + { + EnsureNulTerminated(); + return m_Base; + } + + inline const C* Data() const + { + EnsureNulTerminated(); + return m_Base; + } + + inline size_t Size() const { return m_CurPos - m_Base; } + inline bool IsDynamic() const { return m_IsDynamic; } + inline void Reset() { m_CurPos = m_Base; } + + inline StringBuilderImpl& operator<<(uint64_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + inline StringBuilderImpl& operator<<(int64_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + inline StringBuilderImpl& operator<<(uint32_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + inline StringBuilderImpl& operator<<(int32_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + inline StringBuilderImpl& operator<<(uint16_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + inline StringBuilderImpl& operator<<(int16_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + inline StringBuilderImpl& operator<<(uint8_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + inline StringBuilderImpl& operator<<(int8_t n) + { + IntNum Str(n); + return AppendAscii(Str); + } + + inline StringBuilderImpl& operator<<(const char* str) { return AppendAscii(str); } + inline StringBuilderImpl& operator<<(const std::string_view str) { return AppendAscii(str); } + inline StringBuilderImpl& operator<<(const std::u8string_view str) { return AppendAscii(str); } + +protected: + inline void Init(C* Base, size_t Capacity) + { + m_Base = m_CurPos = Base; + m_End = Base + Capacity; + } + + inline void EnsureNulTerminated() const { *m_CurPos = '\0'; } + + inline void EnsureCapacity(size_t ExtraRequired) + { + // precondition: we know the current buffer has enough capacity + // for the existing string including NUL terminator + + if ((m_CurPos + ExtraRequired) < m_End) + return; + + Extend(ExtraRequired); + } + + ZENCORE_API void Extend(size_t ExtraCapacity); + ZENCORE_API void* AllocBuffer(size_t ByteCount); + ZENCORE_API void FreeBuffer(void* Buffer, size_t ByteCount); + + ZENCORE_API [[noreturn]] void Fail(const char* FailReason); // note: throws exception + + C* m_Base; + C* m_CurPos; + C* m_End; + bool m_IsDynamic = false; + bool m_IsExtendable = false; +}; + +////////////////////////////////////////////////////////////////////////// + +extern template class StringBuilderImpl<char>; + +inline StringBuilderImpl<char>& +operator<<(StringBuilderImpl<char>& Builder, char Char) +{ + return Builder.Append(Char); +} + +class StringBuilderBase : public StringBuilderImpl<char> +{ +public: + inline StringBuilderBase(char* bufferPointer, size_t bufferCapacity) { Init(bufferPointer, bufferCapacity); } + inline ~StringBuilderBase() = default; + + // Note that we don't need a terminator for the string_view so we avoid calling data() here + inline operator std::string_view() const { return std::string_view(m_Base, m_CurPos - m_Base); } + inline std::string_view ToView() const { return std::string_view(m_Base, m_CurPos - m_Base); } + inline std::string ToString() const { return std::string{Data(), Size()}; } + + inline void AppendCodepoint(uint32_t cp) + { + if (cp < 0x80) // one octet + { + Append(static_cast<char8_t>(cp)); + } + else if (cp < 0x800) + { + EnsureCapacity(2); // two octets + m_CurPos[0] = static_cast<char8_t>((cp >> 6) | 0xc0); + m_CurPos[1] = static_cast<char8_t>((cp & 0x3f) | 0x80); + m_CurPos += 2; + } + else if (cp < 0x10000) + { + EnsureCapacity(3); // three octets + m_CurPos[0] = static_cast<char8_t>((cp >> 12) | 0xe0); + m_CurPos[1] = static_cast<char8_t>(((cp >> 6) & 0x3f) | 0x80); + m_CurPos[2] = static_cast<char8_t>((cp & 0x3f) | 0x80); + m_CurPos += 3; + } + else + { + EnsureCapacity(4); // four octets + m_CurPos[0] = static_cast<char8_t>((cp >> 18) | 0xf0); + m_CurPos[1] = static_cast<char8_t>(((cp >> 12) & 0x3f) | 0x80); + m_CurPos[2] = static_cast<char8_t>(((cp >> 6) & 0x3f) | 0x80); + m_CurPos[3] = static_cast<char8_t>((cp & 0x3f) | 0x80); + m_CurPos += 4; + } + } +}; + +template<size_t N> +class StringBuilder : public StringBuilderBase +{ +public: + inline StringBuilder() : StringBuilderBase(m_StringBuffer, sizeof m_StringBuffer) {} + inline ~StringBuilder() = default; + +private: + char m_StringBuffer[N]; +}; + +template<size_t N> +class ExtendableStringBuilder : public StringBuilderBase +{ +public: + inline ExtendableStringBuilder() : StringBuilderBase(m_StringBuffer, sizeof m_StringBuffer) { m_IsExtendable = true; } + inline ~ExtendableStringBuilder() = default; + +private: + char m_StringBuffer[N]; +}; + +template<size_t N> +class WriteToString : public ExtendableStringBuilder<N> +{ +public: + template<typename... ArgTypes> + explicit WriteToString(ArgTypes&&... Args) + { + (*this << ... << std::forward<ArgTypes>(Args)); + } +}; + +////////////////////////////////////////////////////////////////////////// + +extern template class StringBuilderImpl<wchar_t>; + +class WideStringBuilderBase : public StringBuilderImpl<wchar_t> +{ +public: + inline WideStringBuilderBase(wchar_t* BufferPointer, size_t BufferCapacity) { Init(BufferPointer, BufferCapacity); } + inline ~WideStringBuilderBase() = default; + + inline operator std::wstring_view() const { return std::wstring_view{Data(), Size()}; } + inline std::wstring_view ToView() const { return std::wstring_view{Data(), Size()}; } + inline std::wstring ToString() const { return std::wstring{Data(), Size()}; } + + inline StringBuilderImpl& operator<<(const std::wstring_view str) { return Append((const wchar_t*)str.data(), str.size()); } + inline StringBuilderImpl& operator<<(const wchar_t* str) { return Append(str); } + using StringBuilderImpl:: operator<<; +}; + +template<size_t N> +class WideStringBuilder : public WideStringBuilderBase +{ +public: + inline WideStringBuilder() : WideStringBuilderBase(m_Buffer, N) {} + ~WideStringBuilder() = default; + +private: + wchar_t m_Buffer[N]; +}; + +template<size_t N> +class ExtendableWideStringBuilder : public WideStringBuilderBase +{ +public: + inline ExtendableWideStringBuilder() : WideStringBuilderBase(m_Buffer, N) { m_IsExtendable = true; } + ~ExtendableWideStringBuilder() = default; + +private: + wchar_t m_Buffer[N]; +}; + +template<size_t N> +class WriteToWideString : public ExtendableWideStringBuilder<N> +{ +public: + template<typename... ArgTypes> + explicit WriteToWideString(ArgTypes&&... Args) + { + (*this << ... << Forward<ArgTypes>(Args)); + } +}; + +////////////////////////////////////////////////////////////////////////// + +void Utf8ToWide(const char8_t* str, WideStringBuilderBase& out); +void Utf8ToWide(const std::u8string_view& wstr, WideStringBuilderBase& out); +void Utf8ToWide(const std::string_view& wstr, WideStringBuilderBase& out); +std::wstring Utf8ToWide(const std::string_view& wstr); + +void WideToUtf8(const wchar_t* wstr, StringBuilderBase& out); +std::string WideToUtf8(const wchar_t* wstr); +void WideToUtf8(const std::wstring_view& wstr, StringBuilderBase& out); +std::string WideToUtf8(const std::wstring_view Wstr); + +inline uint8_t +Char2Nibble(char c) +{ + if (c >= '0' && c <= '9') + { + return uint8_t(c - '0'); + } + if (c >= 'a' && c <= 'f') + { + return uint8_t(c - 'a' + 10); + } + if (c >= 'A' && c <= 'F') + { + return uint8_t(c - 'A' + 10); + } + return uint8_t(0xff); +}; + +static constexpr const char HexChars[] = "0123456789abcdef"; + +/// <summary> +/// Parse hex string into a byte buffer +/// </summary> +/// <param name="string">Input string</param> +/// <param name="characterCount">Number of characters in string</param> +/// <param name="outPtr">Pointer to output buffer</param> +/// <returns>true if the input consisted of all valid hexadecimal characters</returns> + +inline bool +ParseHexBytes(const char* InputString, size_t CharacterCount, uint8_t* OutPtr) +{ + ZEN_ASSERT((CharacterCount & 1) == 0); + + uint8_t allBits = 0; + + while (CharacterCount) + { + uint8_t n0 = Char2Nibble(InputString[0]); + uint8_t n1 = Char2Nibble(InputString[1]); + + allBits |= n0 | n1; + + *OutPtr = (n0 << 4) | n1; + + OutPtr += 1; + InputString += 2; + CharacterCount -= 2; + } + + return (allBits & 0x80) == 0; +} + +inline void +ToHexBytes(const uint8_t* InputData, size_t ByteCount, char* OutString) +{ + while (ByteCount--) + { + uint8_t byte = *InputData++; + + *OutString++ = HexChars[byte >> 4]; + *OutString++ = HexChars[byte & 15]; + } +} + +inline bool +ParseHexNumber(const char* InputString, size_t CharacterCount, uint8_t* OutPtr) +{ + ZEN_ASSERT((CharacterCount & 1) == 0); + + uint8_t allBits = 0; + + InputString += CharacterCount; + while (CharacterCount) + { + InputString -= 2; + uint8_t n0 = Char2Nibble(InputString[0]); + uint8_t n1 = Char2Nibble(InputString[1]); + + allBits |= n0 | n1; + + *OutPtr = (n0 << 4) | n1; + + OutPtr += 1; + CharacterCount -= 2; + } + + return (allBits & 0x80) == 0; +} + +inline void +ToHexNumber(const uint8_t* InputData, size_t ByteCount, char* OutString) +{ + InputData += ByteCount; + while (ByteCount--) + { + uint8_t byte = *(--InputData); + + *OutString++ = HexChars[byte >> 4]; + *OutString++ = HexChars[byte & 15]; + } +} + +/// <summary> +/// Generates a hex number from a pointer to an integer type, this formats the number in the correct order for a hexadecimal number +/// </summary> +/// <param name="Value">Integer value type</param> +/// <param name="outString">Output buffer where resulting string is written</param> +void +ToHexNumber(UnsignedIntegral auto Value, char* OutString) +{ + ToHexNumber((const uint8_t*)&Value, sizeof(Value), OutString); + OutString[sizeof(Value) * 2] = 0; +} + +/// <summary> +/// Parse hex number string into a value, this formats the number in the correct order for a hexadecimal number +/// </summary> +/// <param name="string">Input string</param> +/// <param name="characterCount">Number of characters in string</param> +/// <param name="OutValue">Pointer to output value</param> +/// <returns>true if the input consisted of all valid hexadecimal characters</returns> +bool +ParseHexNumber(const std::string HexString, UnsignedIntegral auto& OutValue) +{ + return ParseHexNumber(HexString.c_str(), sizeof(OutValue) * 2, (uint8_t*)&OutValue); +} + +////////////////////////////////////////////////////////////////////////// +// Format numbers for humans +// + +ZENCORE_API size_t NiceNumToBuffer(uint64_t Num, std::span<char> Buffer); +ZENCORE_API size_t NiceBytesToBuffer(uint64_t Num, std::span<char> Buffer); +ZENCORE_API size_t NiceByteRateToBuffer(uint64_t Num, uint64_t ms, std::span<char> Buffer); +ZENCORE_API size_t NiceLatencyNsToBuffer(uint64_t NanoSeconds, std::span<char> Buffer); +ZENCORE_API size_t NiceTimeSpanMsToBuffer(uint64_t Milliseconds, std::span<char> Buffer); + +struct NiceBase +{ + inline const char* c_str() const { return m_Buffer; } + inline operator std::string_view() const { return std::string_view(m_Buffer); } + +protected: + char m_Buffer[16]; +}; + +struct NiceNum : public NiceBase +{ + inline NiceNum(uint64_t Num) { NiceNumToBuffer(Num, m_Buffer); } +}; + +struct NiceBytes : public NiceBase +{ + inline NiceBytes(uint64_t Num) { NiceBytesToBuffer(Num, m_Buffer); } +}; + +struct NiceByteRate : public NiceBase +{ + inline NiceByteRate(uint64_t Bytes, uint64_t TimeMilliseconds) { NiceByteRateToBuffer(Bytes, TimeMilliseconds, m_Buffer); } +}; + +struct NiceLatencyNs : public NiceBase +{ + inline NiceLatencyNs(uint64_t Milliseconds) { NiceLatencyNsToBuffer(Milliseconds, m_Buffer); } +}; + +struct NiceTimeSpanMs : public NiceBase +{ + inline NiceTimeSpanMs(uint64_t Milliseconds) { NiceTimeSpanMsToBuffer(Milliseconds, m_Buffer); } +}; + +////////////////////////////////////////////////////////////////////////// + +inline std::string +NiceRate(uint64_t Num, uint32_t DurationMilliseconds, const char* Unit = "B") +{ + char Buffer[32]; + + if (DurationMilliseconds) + { + // Leave a little of 'Buffer' for the "Unit/s" suffix + std::span<char> BufferSpan(Buffer, sizeof(Buffer) - 8); + NiceNumToBuffer(Num * 1000 / DurationMilliseconds, BufferSpan); + } + else + { + strcpy(Buffer, "0"); + } + + strncat(Buffer, Unit, 4); + strcat(Buffer, "/s"); + + return Buffer; +} + +////////////////////////////////////////////////////////////////////////// + +template<Integral T> +std::optional<T> +ParseInt(const std::string_view& Input) +{ + T Out = 0; + const std::from_chars_result Result = std::from_chars(Input.data(), Input.data() + Input.size(), Out); + if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range) + { + return std::nullopt; + } + return Out; +} + +////////////////////////////////////////////////////////////////////////// + +constexpr uint32_t +HashStringDjb2(const std::string_view& InString) +{ + uint32_t HashValue = 5381; + + for (int CurChar : InString) + { + HashValue = HashValue * 33 + CurChar; + } + + return HashValue; +} + +constexpr uint32_t +HashStringAsLowerDjb2(const std::string_view& InString) +{ + uint32_t HashValue = 5381; + + for (uint8_t CurChar : InString) + { + CurChar -= ((CurChar - 'A') <= ('Z' - 'A')) * ('A' - 'a'); // this should be compiled into branchless logic + HashValue = HashValue * 33 + CurChar; + } + + return HashValue; +} + +////////////////////////////////////////////////////////////////////////// + +inline std::string +ToLower(const std::string_view& InString) +{ + std::string Out(InString); + + for (char& CurChar : Out) + { + CurChar -= (uint8_t(CurChar - 'A') <= ('Z' - 'A')) * ('A' - 'a'); // this should be compiled into branchless logic + } + + return Out; +} + +////////////////////////////////////////////////////////////////////////// + +template<typename Fn> +uint32_t +ForEachStrTok(const std::string_view& Str, char Delim, Fn&& Func) +{ + const char* It = Str.data(); + const char* End = It + Str.length(); + uint32_t Count = 0; + + while (It != End) + { + if (*It == Delim) + { + It++; + continue; + } + + std::string_view Remaining{It, size_t(ptrdiff_t(End - It))}; + size_t Idx = Remaining.find(Delim, 0); + + if (Idx == std::string_view::npos) + { + Idx = Remaining.size(); + } + + Count++; + std::string_view Token{It, Idx}; + if (!Func(Token)) + { + break; + } + + It = It + Idx; + } + + return Count; +} + +////////////////////////////////////////////////////////////////////////// + +inline int32_t +StrCaseCompare(const char* Lhs, const char* Rhs, int64_t Length = -1) +{ + // A helper for cross-platform case-insensitive string comparison. +#if ZEN_PLATFORM_WINDOWS + return (Length < 0) ? _stricmp(Lhs, Rhs) : _strnicmp(Lhs, Rhs, size_t(Length)); +#else + return (Length < 0) ? strcasecmp(Lhs, Rhs) : strncasecmp(Lhs, Rhs, size_t(Length)); +#endif +} + +/** + * @brief + * Helper function to implement case sensitive spaceship operator for strings. + * MacOS clang version we use does not implement <=> for std::string + * @param Lhs string + * @param Rhs string + * @return std::strong_ordering indicating relationship between Lhs and Rhs + */ +inline auto +caseSensitiveCompareStrings(const std::string& Lhs, const std::string& Rhs) +{ + int r = Lhs.compare(Rhs); + return r == 0 ? std::strong_ordering::equal : r < 0 ? std::strong_ordering::less : std::strong_ordering::greater; +} + +////////////////////////////////////////////////////////////////////////// + +/** + * ASCII character bitset useful for fast and readable parsing + * + * Entirely constexpr. Works with both wide and narrow strings. + * + * Example use cases: + * + * constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n"); + * bool bIsWhitespace = WhitespaceCharacters.Test(MyChar); + * const char* HelloWorld = AsciiSet::Skip(" \t\tHello world!", WhitespaceCharacters); + * + * constexpr AsciiSet XmlEscapeChars("&<>\"'"); + * check(AsciiSet::HasNone(EscapedXmlString, XmlEscapeChars)); + * + * constexpr AsciiSet Delimiters(".:;"); + * const TCHAR* DelimiterOrEnd = AsciiSet::FindFirstOrEnd(PrefixedName, Delimiters); + * FString Prefix(PrefixedName, DelimiterOrEnd - PrefixedName); + * + * constexpr AsciiSet Slashes("/\\"); + * const TCHAR* SlashOrEnd = AsciiSet::FindLastOrEnd(PathName, Slashes); + * const TCHAR* FileName = *SlashOrEnd ? SlashOrEnd + 1 : PathName; + */ +class AsciiSet +{ +public: + template<typename CharType, int N> + constexpr AsciiSet(const CharType (&Chars)[N]) : AsciiSet(StringToBitset(Chars)) + { + } + + /** Returns true if a character is part of the set */ + template<typename CharType> + constexpr inline bool Contains(CharType Char) const + { + using UnsignedCharType = typename std::make_unsigned<CharType>::type; + + return !!TestImpl((UnsignedCharType)Char); + } + + /** Returns non-zero if a character is part of the set. Prefer Contains() to avoid VS2019 conversion warnings. */ + template<typename CharType> + constexpr inline uint64_t Test(CharType Char) const + { + using UnsignedCharType = typename std::make_unsigned<CharType>::type; + + return TestImpl((UnsignedCharType)Char); + } + + /** Create new set with specified character in it */ + constexpr inline AsciiSet operator+(char Char) const + { + using UnsignedCharType = typename std::make_unsigned<char>::type; + + InitData Bitset = {LoMask, HiMask}; + SetImpl(Bitset, (UnsignedCharType)Char); + return AsciiSet(Bitset); + } + + /** Create new set containing inverse set of characters - likely including null-terminator */ + constexpr inline AsciiSet operator~() const { return AsciiSet(~LoMask, ~HiMask); } + + ////////// Algorithms for C strings ////////// + + /** Find first character of string inside set or end pointer. Never returns null. */ + template<class CharType> + static constexpr const CharType* FindFirstOrEnd(const CharType* Str, AsciiSet Set) + { + for (AsciiSet SetOrNil(Set.LoMask | NilMask, Set.HiMask); !SetOrNil.Test(*Str); ++Str) + ; + + return Str; + } + + /** Find last character of string inside set or end pointer. Never returns null. */ + template<class CharType> + static constexpr const CharType* FindLastOrEnd(const CharType* Str, AsciiSet Set) + { + const CharType* Last = FindFirstOrEnd(Str, Set); + + for (const CharType* It = Last; *It; It = FindFirstOrEnd(It + 1, Set)) + { + Last = It; + } + + return Last; + } + + /** Find first character of string outside of set. Never returns null. */ + template<typename CharType> + static constexpr const CharType* Skip(const CharType* Str, AsciiSet Set) + { + while (Set.Contains(*Str)) + { + ++Str; + } + + return Str; + } + + /** Test if string contains any character in set */ + template<typename CharType> + static constexpr bool HasAny(const CharType* Str, AsciiSet Set) + { + return *FindFirstOrEnd(Str, Set) != '\0'; + } + + /** Test if string contains no character in set */ + template<typename CharType> + static constexpr bool HasNone(const CharType* Str, AsciiSet Set) + { + return *FindFirstOrEnd(Str, Set) == '\0'; + } + + /** Test if string contains any character outside of set */ + template<typename CharType> + static constexpr bool HasOnly(const CharType* Str, AsciiSet Set) + { + return *Skip(Str, Set) == '\0'; + } + + ////////// Algorithms for string types like std::string_view and std::string ////////// + + /** Get initial substring with all characters in set */ + template<class StringType> + static constexpr StringType FindPrefixWith(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Forward, EInclude::Members, EKeep::Head>(Str, Set); + } + + /** Get initial substring with no characters in set */ + template<class StringType> + static constexpr StringType FindPrefixWithout(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Forward, EInclude::NonMembers, EKeep::Head>(Str, Set); + } + + /** Trim initial characters in set */ + template<class StringType> + static constexpr StringType TrimPrefixWith(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Forward, EInclude::Members, EKeep::Tail>(Str, Set); + } + + /** Trim initial characters not in set */ + template<class StringType> + static constexpr StringType TrimPrefixWithout(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Forward, EInclude::NonMembers, EKeep::Tail>(Str, Set); + } + + /** Get trailing substring with all characters in set */ + template<class StringType> + static constexpr StringType FindSuffixWith(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Reverse, EInclude::Members, EKeep::Tail>(Str, Set); + } + + /** Get trailing substring with no characters in set */ + template<class StringType> + static constexpr StringType FindSuffixWithout(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Reverse, EInclude::NonMembers, EKeep::Tail>(Str, Set); + } + + /** Trim trailing characters in set */ + template<class StringType> + static constexpr StringType TrimSuffixWith(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Reverse, EInclude::Members, EKeep::Head>(Str, Set); + } + + /** Trim trailing characters not in set */ + template<class StringType> + static constexpr StringType TrimSuffixWithout(const StringType& Str, AsciiSet Set) + { + return Scan<EDir::Reverse, EInclude::NonMembers, EKeep::Head>(Str, Set); + } + + /** Test if string contains any character in set */ + template<class StringType> + static constexpr bool HasAny(const StringType& Str, AsciiSet Set) + { + return !HasNone(Str, Set); + } + + /** Test if string contains no character in set */ + template<class StringType> + static constexpr bool HasNone(const StringType& Str, AsciiSet Set) + { + uint64_t Match = 0; + for (auto Char : Str) + { + Match |= Set.Test(Char); + } + return Match == 0; + } + + /** Test if string contains any character outside of set */ + template<class StringType> + static constexpr bool HasOnly(const StringType& Str, AsciiSet Set) + { + auto End = Str.data() + Str.size(); + return FindFirst<EInclude::Members>(Set, Str.data(), End) == End; + } + +private: + enum class EDir + { + Forward, + Reverse + }; + enum class EInclude + { + Members, + NonMembers + }; + enum class EKeep + { + Head, + Tail + }; + + template<EInclude Include, typename CharType> + static constexpr const CharType* FindFirst(AsciiSet Set, const CharType* It, const CharType* End) + { + for (; It != End && (Include == EInclude::Members) == !!Set.Test(*It); ++It) + ; + return It; + } + + template<EInclude Include, typename CharType> + static constexpr const CharType* FindLast(AsciiSet Set, const CharType* It, const CharType* End) + { + for (; It != End && (Include == EInclude::Members) == !!Set.Test(*It); --It) + ; + return It; + } + + template<EDir Dir, EInclude Include, EKeep Keep, class StringType> + static constexpr StringType Scan(const StringType& Str, AsciiSet Set) + { + auto Begin = Str.data(); + auto End = Begin + Str.size(); + auto It = Dir == EDir::Forward ? FindFirst<Include>(Set, Begin, End) : FindLast<Include>(Set, End - 1, Begin - 1) + 1; + + return Keep == EKeep::Head ? StringType(Begin, static_cast<int32_t>(It - Begin)) : StringType(It, static_cast<int32_t>(End - It)); + } + + // Work-around for constexpr limitations + struct InitData + { + uint64_t Lo, Hi; + }; + static constexpr uint64_t NilMask = uint64_t(1) << '\0'; + + static constexpr inline void SetImpl(InitData& Bitset, uint32_t Char) + { + uint64_t IsLo = uint64_t(0) - (Char >> 6 == 0); + uint64_t IsHi = uint64_t(0) - (Char >> 6 == 1); + uint64_t Bit = uint64_t(1) << uint8_t(Char & 0x3f); + + Bitset.Lo |= Bit & IsLo; + Bitset.Hi |= Bit & IsHi; + } + + constexpr inline uint64_t TestImpl(uint32_t Char) const + { + uint64_t IsLo = uint64_t(0) - (Char >> 6 == 0); + uint64_t IsHi = uint64_t(0) - (Char >> 6 == 1); + uint64_t Bit = uint64_t(1) << (Char & 0x3f); + + return (Bit & IsLo & LoMask) | (Bit & IsHi & HiMask); + } + + template<typename CharType, int N> + static constexpr InitData StringToBitset(const CharType (&Chars)[N]) + { + using UnsignedCharType = typename std::make_unsigned<CharType>::type; + + InitData Bitset = {0, 0}; + for (int I = 0; I < N - 1; ++I) + { + SetImpl(Bitset, UnsignedCharType(Chars[I])); + } + + return Bitset; + } + + constexpr AsciiSet(InitData Bitset) : LoMask(Bitset.Lo), HiMask(Bitset.Hi) {} + + constexpr AsciiSet(uint64_t Lo, uint64_t Hi) : LoMask(Lo), HiMask(Hi) {} + + uint64_t LoMask, HiMask; +}; + +////////////////////////////////////////////////////////////////////////// + +void string_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h new file mode 100644 index 000000000..a00ee3166 --- /dev/null +++ b/src/zencore/include/zencore/testing.h @@ -0,0 +1,67 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <memory> + +#ifndef ZEN_TEST_WITH_RUNNER +# define ZEN_TEST_WITH_RUNNER 0 +#endif + +#if ZEN_TEST_WITH_RUNNER +# define DOCTEST_CONFIG_IMPLEMENT +#endif + +#if ZEN_WITH_TESTS +# include <doctest/doctest.h> +inline auto +Approx(auto Value) +{ + return doctest::Approx(Value); +} +#endif + +/** + * Test runner helper + * + * This acts as a thin layer between the test app and the test + * framework, which is used to customize configuration logic + * and to set up logging. + * + * If you don't want to implement custom setup then the + * ZEN_RUN_TESTS macro can be used instead. + */ + +#if ZEN_WITH_TESTS +namespace zen::testing { + +class TestRunner +{ +public: + TestRunner(); + ~TestRunner(); + + int ApplyCommandLine(int argc, char const* const* argv); + int Run(); + +private: + struct Impl; + + std::unique_ptr<Impl> m_Impl; +}; + +# define ZEN_RUN_TESTS(argC, argV) \ + [&] { \ + zen::testing::TestRunner Runner; \ + Runner.ApplyCommandLine(argC, argV); \ + return Runner.Run(); \ + }() + +} // namespace zen::testing +#endif + +#if ZEN_TEST_WITH_RUNNER +# undef DOCTEST_CONFIG_IMPLEMENT +#endif diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h new file mode 100644 index 000000000..04648c6de --- /dev/null +++ b/src/zencore/include/zencore/testutils.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <filesystem> + +namespace zen { + +std::filesystem::path CreateTemporaryDirectory(); + +class ScopedTemporaryDirectory +{ +public: + explicit ScopedTemporaryDirectory(std::filesystem::path Directory); + ScopedTemporaryDirectory(); + ~ScopedTemporaryDirectory(); + + std::filesystem::path& Path() { return m_RootPath; } + +private: + std::filesystem::path m_RootPath; +}; + +struct ScopedCurrentDirectoryChange +{ + std::filesystem::path OldPath{std::filesystem::current_path()}; + + ScopedCurrentDirectoryChange() { std::filesystem::current_path(CreateTemporaryDirectory()); } + ~ScopedCurrentDirectoryChange() { std::filesystem::current_path(OldPath); } +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h new file mode 100644 index 000000000..a9c96d422 --- /dev/null +++ b/src/zencore/include/zencore/thread.h @@ -0,0 +1,273 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <atomic> +#include <filesystem> +#include <shared_mutex> +#include <string_view> +#include <vector> + +namespace zen { + +void SetCurrentThreadName(std::string_view ThreadName); + +/** + * Reader-writer lock + * + * - A single thread may hold an exclusive lock at any given moment + * + * - Multiple threads may hold shared locks, but only if no thread has + * acquired an exclusive lock + */ +class RwLock +{ +public: + ZENCORE_API void AcquireShared(); + ZENCORE_API void ReleaseShared(); + + ZENCORE_API void AcquireExclusive(); + ZENCORE_API void ReleaseExclusive(); + + struct SharedLockScope + { + SharedLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireShared(); } + ~SharedLockScope() { ReleaseNow(); } + + void ReleaseNow() + { + if (m_Lock) + { + m_Lock->ReleaseShared(); + m_Lock = nullptr; + } + } + + private: + RwLock* m_Lock; + }; + + struct ExclusiveLockScope + { + ExclusiveLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireExclusive(); } + ~ExclusiveLockScope() { ReleaseNow(); } + + void ReleaseNow() + { + if (m_Lock) + { + m_Lock->ReleaseExclusive(); + m_Lock = nullptr; + } + } + + private: + RwLock* m_Lock; + }; + +private: + std::shared_mutex m_Mutex; +}; + +/** Basic abstraction of a simple event synchronization mechanism (aka 'binary semaphore') + */ +class Event +{ +public: + ZENCORE_API Event(); + ZENCORE_API ~Event(); + + Event(Event&& Rhs) noexcept : m_EventHandle(Rhs.m_EventHandle) { Rhs.m_EventHandle = nullptr; } + + Event(const Event& Rhs) = delete; + Event& operator=(const Event& Rhs) = delete; + + inline Event& operator=(Event&& Rhs) noexcept + { + std::swap(m_EventHandle, Rhs.m_EventHandle); + return *this; + } + + ZENCORE_API void Set(); + ZENCORE_API void Reset(); + ZENCORE_API bool Wait(int TimeoutMs = -1); + ZENCORE_API void Close(); + +protected: + explicit Event(void* EventHandle) : m_EventHandle(EventHandle) {} + + void* m_EventHandle = nullptr; +}; + +/** Basic abstraction of an IPC mechanism (aka 'binary semaphore') + */ +class NamedEvent +{ +public: + NamedEvent() = default; + ZENCORE_API explicit NamedEvent(std::string_view EventName); + ZENCORE_API ~NamedEvent(); + ZENCORE_API void Close(); + ZENCORE_API void Set(); + ZENCORE_API bool Wait(int TimeoutMs = -1); + + NamedEvent(NamedEvent&& Rhs) noexcept : m_EventHandle(Rhs.m_EventHandle) { Rhs.m_EventHandle = nullptr; } + + inline NamedEvent& operator=(NamedEvent&& Rhs) noexcept + { + std::swap(m_EventHandle, Rhs.m_EventHandle); + return *this; + } + +protected: + void* m_EventHandle = nullptr; + +private: + NamedEvent(const NamedEvent& Rhs) = delete; + NamedEvent& operator=(const NamedEvent& Rhs) = delete; +}; + +/** Basic abstraction of a named (system wide) mutex primitive + */ +class NamedMutex +{ +public: + ~NamedMutex(); + + ZENCORE_API [[nodiscard]] bool Create(std::string_view MutexName); + + ZENCORE_API static bool Exists(std::string_view MutexName); + +private: + void* m_MutexHandle = nullptr; +}; + +/** + * Downward counter of type std::ptrdiff_t which can be used to synchronize threads + */ +class Latch +{ +public: + Latch(std::ptrdiff_t Count) : Counter(Count) {} + + void CountDown() + { + std::ptrdiff_t Old = Counter.fetch_sub(1); + if (Old == 1) + { + Complete.Set(); + } + } + + std::ptrdiff_t Remaining() const { return Counter.load(); } + + // If you want to add dynamic count, make sure to set the initial counter to 1 + // and then do a CountDown() just before wait to not trigger the event causing + // false positive completion results. + void AddCount(std::ptrdiff_t Count) + { + std::atomic_ptrdiff_t Old = Counter.fetch_add(Count); + ZEN_ASSERT_SLOW(Old > 0); + } + + bool Wait(int TimeoutMs = -1) + { + std::ptrdiff_t Old = Counter.load(); + if (Old == 0) + { + return true; + } + return Complete.Wait(TimeoutMs); + } + +private: + std::atomic_ptrdiff_t Counter; + Event Complete; +}; + +/** Basic process abstraction + */ +class ProcessHandle +{ +public: + ZENCORE_API ProcessHandle(); + + ProcessHandle(const ProcessHandle&) = delete; + ProcessHandle& operator=(const ProcessHandle&) = delete; + + ZENCORE_API ~ProcessHandle(); + + ZENCORE_API void Initialize(int Pid); + ZENCORE_API void Initialize(void* ProcessHandle); /// Initialize with an existing handle - takes ownership of the handle + ZENCORE_API [[nodiscard]] bool IsRunning() const; + ZENCORE_API [[nodiscard]] bool IsValid() const; + ZENCORE_API bool Wait(int TimeoutMs = -1); + ZENCORE_API void Terminate(int ExitCode); + ZENCORE_API void Reset(); + [[nodiscard]] inline int Pid() const { return m_Pid; } + +private: + void* m_ProcessHandle = nullptr; + int m_Pid = 0; +}; + +/** Basic process creation + */ +struct CreateProcOptions +{ + enum + { + Flag_NewConsole = 1 << 0, + Flag_Elevated = 1 << 1, + Flag_Unelevated = 1 << 2, + }; + + const std::filesystem::path* WorkingDirectory = nullptr; + uint32_t Flags = 0; +}; + +#if ZEN_PLATFORM_WINDOWS +using CreateProcResult = void*; // handle to the process +#else +using CreateProcResult = int32_t; // pid +#endif + +ZENCORE_API CreateProcResult CreateProc(const std::filesystem::path& Executable, + std::string_view CommandLine, // should also include arg[0] (executable name) + const CreateProcOptions& Options = {}); + +/** Process monitor - monitors a list of running processes via polling + + Intended to be used to monitor a set of "sponsor" processes, where + we need to determine when none of them remain alive + + */ + +class ProcessMonitor +{ +public: + ProcessMonitor(); + ~ProcessMonitor(); + + ZENCORE_API bool IsRunning(); + ZENCORE_API void AddPid(int Pid); + ZENCORE_API bool IsActive() const; + +private: + using HandleType = void*; + + mutable RwLock m_Lock; + std::vector<HandleType> m_ProcessHandles; +}; + +ZENCORE_API bool IsProcessRunning(int pid); +ZENCORE_API int GetCurrentProcessId(); +ZENCORE_API int GetCurrentThreadId(); + +ZENCORE_API void Sleep(int ms); + +void thread_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/timer.h b/src/zencore/include/zencore/timer.h new file mode 100644 index 000000000..e4ddc3505 --- /dev/null +++ b/src/zencore/include/zencore/timer.h @@ -0,0 +1,58 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#if ZEN_COMPILER_MSC +# include <intrin.h> +#elif ZEN_ARCH_X64 +# include <x86intrin.h> +#endif + +#include <stdint.h> + +namespace zen { + +// High frequency timers + +ZENCORE_API uint64_t GetHifreqTimerValue(); +ZENCORE_API uint64_t GetHifreqTimerFrequency(); +ZENCORE_API double GetHifreqTimerToSeconds(); +ZENCORE_API uint64_t GetHifreqTimerFrequencySafe(); // May be used during static init + +class Stopwatch +{ +public: + inline Stopwatch() : m_StartValue(GetHifreqTimerValue()) {} + + inline uint64_t GetElapsedTimeMs() const { return (GetHifreqTimerValue() - m_StartValue) * 1'000 / GetHifreqTimerFrequency(); } + inline uint64_t GetElapsedTimeUs() const { return (GetHifreqTimerValue() - m_StartValue) * 1'000'000 / GetHifreqTimerFrequency(); } + inline uint64_t GetElapsedTicks() const { return GetHifreqTimerValue() - m_StartValue; } + inline void Reset() { m_StartValue = GetHifreqTimerValue(); } + + static inline uint64_t GetElapsedTimeMs(uint64_t Ticks) { return Ticks * 1'000 / GetHifreqTimerFrequency(); } + static inline uint64_t GetElapsedTimeUs(uint64_t Ticks) { return Ticks * 1'000'000 / GetHifreqTimerFrequency(); } + +private: + uint64_t m_StartValue; +}; + +// Low frequency timers + +namespace detail { + extern ZENCORE_API uint64_t g_LofreqTimerValue; +} // namespace detail + +inline uint64_t +GetLofreqTimerValue() +{ + return detail::g_LofreqTimerValue; +} + +ZENCORE_API void UpdateLofreqTimerValue(); +ZENCORE_API uint64_t GetLofreqTimerFrequency(); + +void timer_forcelink(); // internal + +} // namespace zen diff --git a/src/zencore/include/zencore/trace.h b/src/zencore/include/zencore/trace.h new file mode 100644 index 000000000..0af490f23 --- /dev/null +++ b/src/zencore/include/zencore/trace.h @@ -0,0 +1,36 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +/* clang-format off */ + +#include <zencore/zencore.h> + +#if ZEN_WITH_TRACE + +ZEN_THIRD_PARTY_INCLUDES_START +#if !defined(TRACE_IMPLEMENT) +# define TRACE_IMPLEMENT 0 +#endif +#include <trace.h> +#undef TRACE_IMPLEMENT +ZEN_THIRD_PARTY_INCLUDES_END + +#define ZEN_TRACE_CPU(x) TRACE_CPU_SCOPE(x) + +enum class TraceType +{ + File, + Network, + None +}; + +void TraceInit(const char* HostOrPath, TraceType Type); + +#else + +#define ZEN_TRACE_CPU(x) + +#endif // ZEN_WITH_TRACE + +/* clang-format on */ diff --git a/src/zencore/include/zencore/uid.h b/src/zencore/include/zencore/uid.h new file mode 100644 index 000000000..9659f5893 --- /dev/null +++ b/src/zencore/include/zencore/uid.h @@ -0,0 +1,87 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> +#include <compare> + +namespace zen { + +class StringBuilderBase; + +/** Object identifier + + Can be used as a GUID essentially, but is more compact (12 bytes) and as such + is more susceptible to collisions than a 16-byte GUID but also I don't expect + the population to be large so in practice the risk should be minimal due to + how the identifiers work. + + Similar in spirit to MongoDB ObjectId + + When serialized, object identifiers generated in a given session in sequence + will sort in chronological order since the timestamp is in the MSB in big + endian format. This makes it suitable as a database key since most indexing + structures work better when keys are inserted in lexicographically + increasing order. + + The current layout is basically: + + |----------------|----------------|----------------| + | timestamp | serial # | run id | + |----------------|----------------|----------------| + MSB LSB + + - Timestamp is a unsigned 32-bit value (seconds since 00:00:00 Jan 1 2021) + - Serial # is another unsigned 32-bit value which is assigned a (strong) + random number at initialization time which is incremented when a new Oid + is generated + - The run id is generated from a strong random number generator + at initialization time and stays fixed for the duration of the program + + Timestamp and serial are stored in memory in such a way that they can be + ordered lexicographically. I.e they are in big-endian byte order. + + NOTE: The information above is only meant to explain the properties of + the identifiers. Client code should simply treat the identifier as an + opaque value and may not make any assumptions on the structure, as there + may be other ways of generating the identifiers in the future if an + application benefits. + + */ + +struct Oid +{ + static const int StringLength = 24; + typedef char String_t[StringLength + 1]; + + static void Initialize(); + [[nodiscard]] static Oid NewOid(); + + const Oid& Generate(); + [[nodiscard]] static Oid FromHexString(const std::string_view String); + StringBuilderBase& ToString(StringBuilderBase& OutString) const; + void ToString(char OutString[StringLength]); + [[nodiscard]] static Oid FromMemory(const void* Ptr); + + auto operator<=>(const Oid& rhs) const = default; + [[nodiscard]] inline operator bool() const { return *this != Zero; } + + static const Oid Zero; // Min (can be used to signify a "null" value, or for open range queries) + static const Oid Max; // Max (can be used for open range queries) + + struct Hasher + { + size_t operator()(const Oid& id) const + { + const size_t seed = id.OidBits[0]; + return ((seed << 6) + (seed >> 2) + 0x9e3779b9 + uint64_t(id.OidBits[1])) | (uint64_t(id.OidBits[2]) << 32); + } + }; + + // You should not assume anything about these words + uint32_t OidBits[3]; +}; + +extern void uid_forcelink(); + +} // namespace zen diff --git a/src/zencore/include/zencore/varint.h b/src/zencore/include/zencore/varint.h new file mode 100644 index 000000000..e57e1d497 --- /dev/null +++ b/src/zencore/include/zencore/varint.h @@ -0,0 +1,277 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "intmath.h" + +#include <algorithm> + +namespace zen { + +// Variable-Length Integer Encoding +// +// ZigZag encoding is used to convert signed integers into unsigned integers in a way that allows +// integers with a small magnitude to have a smaller encoded representation. +// +// An unsigned integer is encoded into 1-9 bytes based on its magnitude. The first byte indicates +// how many additional bytes are used by the number of leading 1-bits that it has. The additional +// bytes are stored in big endian order, and the most significant bits of the value are stored in +// the remaining bits in the first byte. The encoding of the first byte allows the reader to skip +// over the encoded integer without consuming its bytes individually. +// +// Encoded unsigned integers sort the same in a byte-wise comparison as when their decoded values +// are compared. The same property does not hold for signed integers due to ZigZag encoding. +// +// 32-bit inputs encode to 1-5 bytes. +// 64-bit inputs encode to 1-9 bytes. +// +// 0x0000'0000'0000'0000 - 0x0000'0000'0000'007f : 0b0_______ 1 byte +// 0x0000'0000'0000'0080 - 0x0000'0000'0000'3fff : 0b10______ 2 bytes +// 0x0000'0000'0000'4000 - 0x0000'0000'001f'ffff : 0b110_____ 3 bytes +// 0x0000'0000'0020'0000 - 0x0000'0000'0fff'ffff : 0b1110____ 4 bytes +// 0x0000'0000'1000'0000 - 0x0000'0007'ffff'ffff : 0b11110___ 5 bytes +// 0x0000'0008'0000'0000 - 0x0000'03ff'ffff'ffff : 0b111110__ 6 bytes +// 0x0000'0400'0000'0000 - 0x0001'ffff'ffff'ffff : 0b1111110_ 7 bytes +// 0x0002'0000'0000'0000 - 0x00ff'ffff'ffff'ffff : 0b11111110 8 bytes +// 0x0100'0000'0000'0000 - 0xffff'ffff'ffff'ffff : 0b11111111 9 bytes +// +// Encoding Examples +// -42 => ZigZag => 0x53 => 0x53 +// 42 => ZigZag => 0x54 => 0x54 +// 0x1 => 0x01 +// 0x12 => 0x12 +// 0x123 => 0x81 0x23 +// 0x1234 => 0x92 0x34 +// 0x12345 => 0xc1 0x23 0x45 +// 0x123456 => 0xd2 0x34 0x56 +// 0x1234567 => 0xe1 0x23 0x45 0x67 +// 0x12345678 => 0xf0 0x12 0x34 0x56 0x78 +// 0x123456789 => 0xf1 0x23 0x45 0x67 0x89 +// 0x123456789a => 0xf8 0x12 0x34 0x56 0x78 0x9a +// 0x123456789ab => 0xfb 0x23 0x45 0x67 0x89 0xab +// 0x123456789abc => 0xfc 0x12 0x34 0x56 0x78 0x9a 0xbc +// 0x123456789abcd => 0xfd 0x23 0x45 0x67 0x89 0xab 0xcd +// 0x123456789abcde => 0xfe 0x12 0x34 0x56 0x78 0x9a 0xbc 0xde +// 0x123456789abcdef => 0xff 0x01 0x23 0x45 0x67 0x89 0xab 0xcd 0xef +// 0x123456789abcdef0 => 0xff 0x12 0x34 0x56 0x78 0x9a 0xbc 0xde 0xf0 + +/** + * Measure the length in bytes (1-9) of an encoded variable-length integer. + * + * @param InData A variable-length encoding of an (signed or unsigned) integer. + * @return The number of bytes used to encode the integer, in the range 1-9. + */ +inline uint32_t +MeasureVarUInt(const void* InData) +{ + return CountLeadingZeros(uint8_t(~*static_cast<const uint8_t*>(InData))) - 23; +} + +/** Measure the length in bytes (1-9) of an encoded variable-length integer. \see \ref MeasureVarUInt */ +inline uint32_t +MeasureVarInt(const void* InData) +{ + return MeasureVarUInt(InData); +} + +/** Measure the number of bytes (1-5) required to encode the 32-bit input. */ +inline uint32_t +MeasureVarUInt(uint32_t InValue) +{ + return uint32_t(int32_t(FloorLog2(InValue)) / 7 + 1); +} + +/** Measure the number of bytes (1-9) required to encode the 64-bit input. */ +inline uint32_t +MeasureVarUInt(uint64_t InValue) +{ + return uint32_t(std::min(int32_t(FloorLog2_64(InValue)) / 7 + 1, 9)); +} + +/** Measure the number of bytes (1-5) required to encode the 32-bit input. \see \ref MeasureVarUInt */ +inline uint32_t +MeasureVarInt(int32_t InValue) +{ + return MeasureVarUInt(uint32_t((InValue >> 31) ^ (InValue << 1))); +} + +/** Measure the number of bytes (1-9) required to encode the 64-bit input. \see \ref MeasureVarUInt */ +inline uint32_t +MeasureVarInt(int64_t InValue) +{ + return MeasureVarUInt(uint64_t((InValue >> 63) ^ (InValue << 1))); +} + +/** + * Read a variable-length unsigned integer. + * + * @param InData A variable-length encoding of an unsigned integer. + * @param OutByteCount The number of bytes consumed from the input. + * @return An unsigned integer. + */ +inline uint64_t +ReadVarUInt(const void* InData, uint32_t& OutByteCount) +{ + const uint32_t ByteCount = MeasureVarUInt(InData); + OutByteCount = ByteCount; + + const uint8_t* InBytes = static_cast<const uint8_t*>(InData); + uint64_t Value = *InBytes++ & uint8_t(0xff >> ByteCount); + switch (ByteCount - 1) + { + case 8: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + case 7: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + case 6: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + case 5: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + case 4: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + case 3: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + case 2: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + case 1: + Value <<= 8; + Value |= *InBytes++; + [[fallthrough]]; + default: + return Value; + } +} + +/** + * Read a variable-length signed integer. + * + * @param InData A variable-length encoding of a signed integer. + * @param OutByteCount The number of bytes consumed from the input. + * @return A signed integer. + */ +inline int64_t +ReadVarInt(const void* InData, uint32_t& OutByteCount) +{ + const uint64_t Value = ReadVarUInt(InData, OutByteCount); + return -int64_t(Value & 1) ^ int64_t(Value >> 1); +} + +/** + * Write a variable-length unsigned integer. + * + * @param InValue An unsigned integer to encode. + * @param OutData A buffer of at least 5 bytes to write the output to. + * @return The number of bytes used in the output. + */ +inline uint32_t +WriteVarUInt(uint32_t InValue, void* OutData) +{ + const uint32_t ByteCount = MeasureVarUInt(InValue); + uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1; + switch (ByteCount - 1) + { + case 4: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 3: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 2: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 1: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + default: + break; + } + *OutBytes = uint8_t(0xff << (9 - ByteCount)) | uint8_t(InValue); + return ByteCount; +} + +/** + * Write a variable-length unsigned integer. + * + * @param InValue An unsigned integer to encode. + * @param OutData A buffer of at least 9 bytes to write the output to. + * @return The number of bytes used in the output. + */ +inline uint32_t +WriteVarUInt(uint64_t InValue, void* OutData) +{ + const uint32_t ByteCount = MeasureVarUInt(InValue); + uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1; + switch (ByteCount - 1) + { + case 8: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 7: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 6: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 5: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 4: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 3: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 2: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + case 1: + *OutBytes-- = uint8_t(InValue); + InValue >>= 8; + [[fallthrough]]; + default: + break; + } + *OutBytes = uint8_t(0xff << (9 - ByteCount)) | uint8_t(InValue); + return ByteCount; +} + +/** Write a variable-length signed integer. \see \ref WriteVarUInt */ +inline uint32_t +WriteVarInt(int32_t InValue, void* OutData) +{ + const uint32_t Value = uint32_t((InValue >> 31) ^ (InValue << 1)); + return WriteVarUInt(Value, OutData); +} + +/** Write a variable-length signed integer. \see \ref WriteVarUInt */ +inline uint32_t +WriteVarInt(int64_t InValue, void* OutData) +{ + const uint64_t Value = uint64_t((InValue >> 63) ^ (InValue << 1)); + return WriteVarUInt(Value, OutData); +} + +} // namespace zen diff --git a/src/zencore/include/zencore/windows.h b/src/zencore/include/zencore/windows.h new file mode 100644 index 000000000..91828f0ec --- /dev/null +++ b/src/zencore/include/zencore/windows.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +ZEN_THIRD_PARTY_INCLUDES_START + +struct IUnknown; // Workaround for "combaseapi.h(229): error C2187: syntax error: 'identifier' was unexpected here" when using /permissive- +#ifndef NOMINMAX +# define NOMINMAX // We don't want your min/max macros +#endif +#ifndef NOGDI +# define NOGDI // We don't want your GetObject define +#endif +#ifndef WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +#endif +#ifndef _WIN32_WINNT +# define _WIN32_WINNT 0x0A00 +#endif +#include <windows.h> +#undef GetObject + +ZEN_THIRD_PARTY_INCLUDES_END diff --git a/src/zencore/include/zencore/workthreadpool.h b/src/zencore/include/zencore/workthreadpool.h new file mode 100644 index 000000000..0ddc65298 --- /dev/null +++ b/src/zencore/include/zencore/workthreadpool.h @@ -0,0 +1,48 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/blockingqueue.h> +#include <zencore/refcount.h> + +#include <exception> +#include <functional> +#include <system_error> +#include <thread> +#include <vector> + +namespace zen { + +struct IWork : public RefCounted +{ + virtual void Execute() = 0; + + inline std::exception_ptr GetException() { return m_Exception; } + +private: + std::exception_ptr m_Exception; + + friend class WorkerThreadPool; +}; + +class WorkerThreadPool +{ +public: + WorkerThreadPool(int InThreadCount); + ~WorkerThreadPool(); + + void ScheduleWork(Ref<IWork> Work); + void ScheduleWork(std::function<void()>&& Work); + + [[nodiscard]] size_t PendingWork() const; + +private: + void WorkerThreadFunction(); + + std::vector<std::thread> m_WorkerThreads; + BlockingQueue<Ref<IWork>> m_WorkQueue; +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/xxhash.h b/src/zencore/include/zencore/xxhash.h new file mode 100644 index 000000000..04872f4c3 --- /dev/null +++ b/src/zencore/include/zencore/xxhash.h @@ -0,0 +1,89 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zencore.h" + +#include <zencore/memory.h> + +#include <xxh3.h> + +#include <compare> +#include <string_view> + +namespace zen { + +class StringBuilderBase; + +/** + * XXH3 hash + */ +struct XXH3_128 +{ + uint8_t Hash[16]; + + static XXH3_128 MakeFrom(const void* data /* 16 bytes */) + { + XXH3_128 Xx; + memcpy(Xx.Hash, data, sizeof Xx); + return Xx; + } + + static inline XXH3_128 HashMemory(const void* data, size_t byteCount) + { + XXH3_128 Hash; + XXH128_canonicalFromHash((XXH128_canonical_t*)Hash.Hash, XXH3_128bits(data, byteCount)); + return Hash; + } + static XXH3_128 HashMemory(MemoryView Data) { return HashMemory(Data.GetData(), Data.GetSize()); } + static XXH3_128 FromHexString(const char* string); + static XXH3_128 FromHexString(const std::string_view string); + const char* ToHexString(char* outString /* 32 characters + NUL terminator */) const; + StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const; + + static const int StringLength = 32; + typedef char String_t[StringLength + 1]; + + static XXH3_128 Zero; // Initialized to all zeros + + inline auto operator<=>(const XXH3_128& rhs) const = default; + + struct Hasher + { + size_t operator()(const XXH3_128& v) const + { + size_t h; + memcpy(&h, v.Hash, sizeof h); + return h; + } + }; +}; + +struct XXH3_128Stream +{ + /// Begin streaming hash compute (not needed on freshly constructed instance) + void Reset() { memset(&m_State, 0, sizeof m_State); } + + /// Append another chunk + XXH3_128Stream& Append(const void* Data, size_t ByteCount) + { + XXH3_128bits_update(&m_State, Data, ByteCount); + return *this; + } + + /// Append another chunk + XXH3_128Stream& Append(MemoryView Data) { return Append(Data.GetData(), Data.GetSize()); } + + /// Obtain final hash. If you wish to reuse the instance call reset() + XXH3_128 GetHash() + { + XXH3_128 Hash; + XXH128_canonicalFromHash((XXH128_canonical_t*)Hash.Hash, XXH3_128bits_digest(&m_State)); + return Hash; + } + +private: + XXH3_state_s m_State{}; +}; + +} // namespace zen diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h new file mode 100644 index 000000000..5bcd77239 --- /dev/null +++ b/src/zencore/include/zencore/zencore.h @@ -0,0 +1,383 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <cinttypes> +#include <stdexcept> +#include <string> +#include <version> + +#ifndef ZEN_WITH_TESTS +# define ZEN_WITH_TESTS 1 +#endif + +////////////////////////////////////////////////////////////////////////// +// Platform +// + +#define ZEN_PLATFORM_WINDOWS 0 +#define ZEN_PLATFORM_LINUX 0 +#define ZEN_PLATFORM_MAC 0 + +#ifdef _WIN32 +# undef ZEN_PLATFORM_WINDOWS +# define ZEN_PLATFORM_WINDOWS 1 +#elif defined(__linux__) +# undef ZEN_PLATFORM_LINUX +# define ZEN_PLATFORM_LINUX 1 +#elif defined(__APPLE__) +# undef ZEN_PLATFORM_MAC +# define ZEN_PLATFORM_MAC 1 +#endif + +#if ZEN_PLATFORM_WINDOWS +# if !defined(NOMINMAX) +# define NOMINMAX // stops Windows.h from defining 'min/max' macros +# endif +# if !defined(NOGDI) +# define NOGDI +# endif +# if !defined(WIN32_LEAN_AND_MEAN) +# define WIN32_LEAN_AND_MEAN // cut-down what Windows.h defines +# endif +#endif + +////////////////////////////////////////////////////////////////////////// +// Compiler +// + +#define ZEN_COMPILER_CLANG 0 +#define ZEN_COMPILER_MSC 0 +#define ZEN_COMPILER_GCC 0 + +// Clang can define __GNUC__ and/or _MSC_VER so we check for Clang first +#ifdef __clang__ +# undef ZEN_COMPILER_CLANG +# define ZEN_COMPILER_CLANG 1 +#elif defined(_MSC_VER) +# undef ZEN_COMPILER_MSC +# define ZEN_COMPILER_MSC 1 +#elif defined(__GNUC__) +# undef ZEN_COMPILER_GCC +# define ZEN_COMPILER_GCC 1 +#else +# error Unknown compiler +#endif + +#if ZEN_COMPILER_MSC +# pragma warning(disable : 4324) // warning C4324: '<type>': structure was padded due to alignment specifier +# pragma warning(default : 4668) // warning C4668: 'symbol' is not defined as a preprocessor macro, replacing with '0' for 'directives' +# pragma warning(default : 4100) // warning C4100: 'identifier' : unreferenced formal parameter +#endif + +#ifndef ZEN_THIRD_PARTY_INCLUDES_START +# if ZEN_COMPILER_MSC +# define ZEN_THIRD_PARTY_INCLUDES_START \ + __pragma(warning(push)) __pragma(warning(disable : 4668)) /* C4668: use of undefined preprocessor macro */ \ + __pragma(warning(disable : 4305)) /* C4305: 'if': truncation from 'uint32' to 'bool' */ \ + __pragma(warning(disable : 4267)) /* C4267: '=': conversion from 'size_t' to 'US' */ \ + __pragma(warning(disable : 4127)) /* C4127: conditional expression is constant */ \ + __pragma(warning(disable : 4189)) /* C4189: local variable is initialized but not referenced */ +# elif ZEN_COMPILER_CLANG +# define ZEN_THIRD_PARTY_INCLUDES_START \ + _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wundef\"") \ + _Pragma("clang diagnostic ignored \"-Wunused-parameter\"") _Pragma("clang diagnostic ignored \"-Wunused-variable\"") +# elif ZEN_COMPILER_GCC +# define ZEN_THIRD_PARTY_INCLUDES_START \ + _Pragma("GCC diagnostic push") /* NB. ignoring -Wundef doesn't work with GCC */ \ + _Pragma("GCC diagnostic ignored \"-Wunused-parameter\"") _Pragma("GCC diagnostic ignored \"-Wunused-variable\"") +# endif +#endif + +#ifndef ZEN_THIRD_PARTY_INCLUDES_END +# if ZEN_COMPILER_MSC +# define ZEN_THIRD_PARTY_INCLUDES_END __pragma(warning(pop)) +# elif ZEN_COMPILER_CLANG +# define ZEN_THIRD_PARTY_INCLUDES_END _Pragma("clang diagnostic pop") +# elif ZEN_COMPILER_GCC +# define ZEN_THIRD_PARTY_INCLUDES_END _Pragma("GCC diagnostic pop") +# endif +#endif + +#if ZEN_COMPILER_MSC +# define ZEN_DEBUG_BREAK() \ + do \ + { \ + __debugbreak(); \ + } while (0) +#else +# define ZEN_DEBUG_BREAK() \ + do \ + { \ + __builtin_trap(); \ + } while (0) +#endif + +////////////////////////////////////////////////////////////////////////// +// C++20 support +// + +// Clang +#if ZEN_COMPILER_CLANG && __clang_major__ < 12 +# error clang-12 onwards is required for C++20 support +#endif + +// GCC +#if ZEN_COMPILER_GCC && __GNUC__ < 11 +# error GCC-11 onwards is required for C++20 support +#endif + +// GNU libstdc++ +#if defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE < 11 +# error GNU libstdc++-11 onwards is required for C++20 support +#endif + +// LLVM libc++ +#if defined(_LIBCPP_VERSION) && _LIBCPP_VERSION < 12000 +# error LLVM libc++-12 onwards is required for C++20 support +#endif + +// At the time of writing only ver >= 13 of LLVM's libc++ has an implementation +// of std::integral. Some platforms like Ubuntu and Mac OS are still on 12. +#if defined(__cpp_lib_concepts) +# include <concepts> +template<class T> +concept Integral = std::integral<T>; +template<class T> +concept SignedIntegral = std::signed_integral<T>; +template<class T> +concept UnsignedIntegral = std::unsigned_integral<T>; +template<class F, class... A> +concept Invocable = std::invocable<F, A...>; +template<class D, class B> +concept DerivedFrom = std::derived_from<D, B>; +#else +template<class T> +concept Integral = std::is_integral_v<T>; +template<class T> +concept SignedIntegral = Integral<T> && std::is_signed_v<T>; +template<class T> +concept UnsignedIntegral = Integral<T> && !std::is_signed_v<T>; +template<class F, class... A> +concept Invocable = requires(F&& f, A&&... a) +{ + std::invoke(std::forward<F>(f), std::forward<A>(a)...); +}; +template<class D, class B> +concept DerivedFrom = std::is_base_of_v<B, D> && std::is_convertible_v<const volatile D*, const volatile B*>; +#endif + +#if defined(__cpp_lib_ranges) +template<typename T> +concept ContiguousRange = std::ranges::contiguous_range<T>; +#else +template<typename T> +concept ContiguousRange = true; +#endif + +////////////////////////////////////////////////////////////////////////// +// Architecture +// + +#if defined(__amd64__) || defined(_M_X64) +# define ZEN_ARCH_X64 1 +# define ZEN_ARCH_ARM64 0 +#elif defined(__arm64__) || defined(_M_ARM64) +# define ZEN_ARCH_X64 0 +# define ZEN_ARCH_ARM64 1 +#else +# error Unknown architecture +#endif + +////////////////////////////////////////////////////////////////////////// +// Build flavor +// + +#ifdef NDEBUG +# define ZEN_BUILD_DEBUG 0 +# define ZEN_BUILD_RELEASE 1 +#else +# define ZEN_BUILD_DEBUG 1 +# define ZEN_BUILD_RELEASE 0 +#endif + +////////////////////////////////////////////////////////////////////////// + +#define ZEN_PLATFORM_SUPPORTS_UNALIGNED_LOADS 1 + +#if defined(__SIZEOF_WCHAR_T__) && __SIZEOF_WCHAR_T__ == 4 +# define ZEN_SIZEOF_WCHAR_T 4 +#else +static_assert(sizeof(wchar_t) == 2, "wchar_t is expected to be two bytes in size"); +# define ZEN_SIZEOF_WCHAR_T 2 +#endif + +////////////////////////////////////////////////////////////////////////// +// Assert +// + +#if ZEN_PLATFORM_WINDOWS +// Tells the compiler to put the decorated function in a certain section (aka. segment) of the executable. +# define ZEN_CODE_SECTION(Name) __declspec(code_seg(Name)) +# define ZEN_FORCENOINLINE __declspec(noinline) /* Force code to NOT be inline */ +# define LINE_TERMINATOR_ANSI "\r\n" +#else +# define ZEN_CODE_SECTION(Name) +# define ZEN_FORCENOINLINE +# define LINE_TERMINATOR_ANSI "\n" +#endif + +#if ZEN_ARCH_ARM64 +// On ARM we can't do this because the executable will require jumps larger +// than the branch instruction can handle. Clang will only generate +// the trampolines in the .text segment of the binary. If the zcold segment +// is present it will generate code that it cannot link. +# define ZEN_DEBUG_SECTION +#else +// We'll put all assert implementation code into a separate section in the linked +// executable. This code should never execute so using a separate section keeps +// it well off the hot path and hopefully out of the instruction cache. It also +// facilitates reasoning about the makeup of a compiled/linked binary. +# define ZEN_DEBUG_SECTION ZEN_CODE_SECTION(".zcold") +#endif + +namespace zen +{ + class AssertException : public std::logic_error + { + public: + AssertException(const char* Msg) : std::logic_error(Msg) {} + }; + + struct AssertImpl + { + static void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ExecAssert + [[noreturn]] (const char* Filename, int LineNumber, const char* FunctionName, const char* Msg) + { + CurrentAssertImpl->OnAssert(Filename, LineNumber, FunctionName, Msg); + throw AssertException{Msg}; + } + + protected: + virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename, + int LineNumber, + const char* FunctionName, + const char* Msg) + { + (void(Filename)); + (void(LineNumber)); + (void(FunctionName)); + (void(Msg)); + } + static AssertImpl DefaultAssertImpl; + static AssertImpl* CurrentAssertImpl; + }; + +} // namespace zen + +#define ZEN_ASSERT(x, ...) \ + do \ + { \ + if (x) [[unlikely]] \ + break; \ + zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \ + } while (false) + +#ifndef NDEBUG +# define ZEN_ASSERT_SLOW(x, ...) \ + do \ + { \ + if (x) [[unlikely]] \ + break; \ + zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \ + } while (false) +#else +# define ZEN_ASSERT_SLOW(x, ...) +#endif + +////////////////////////////////////////////////////////////////////////// + +#ifdef __clang__ +template<typename T> +auto ZenArrayCountHelper(T& t) -> typename std::enable_if<__is_array(T), char (&)[sizeof(t) / sizeof(t[0]) + 1]>::type; +#else +template<typename T, uint32_t N> +char (&ZenArrayCountHelper(const T (&)[N]))[N + 1]; +#endif + +#define ZEN_ARRAY_COUNT(array) (sizeof(ZenArrayCountHelper(array)) - 1) + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_COMPILER_MSC +# define ZEN_NOINLINE __declspec(noinline) +#else +# define ZEN_NOINLINE __attribute__((noinline)) +#endif + +#if ZEN_PLATFORM_WINDOWS +# define ZEN_EXE_SUFFIX_LITERAL ".exe" +#else +# define ZEN_EXE_SUFFIX_LITERAL +#endif + +#define ZEN_UNUSED(...) ((void)__VA_ARGS__) +#define ZEN_NOT_IMPLEMENTED(...) ZEN_ASSERT(false, __VA_ARGS__) +#define ZENCORE_API // Placeholder to allow DLL configs in the future (maybe) + +namespace zen { + +ZENCORE_API bool IsApplicationExitRequested(); +ZENCORE_API void RequestApplicationExit(int ExitCode); +ZENCORE_API bool IsDebuggerPresent(); +ZENCORE_API void SetIsInteractiveSession(bool Value); +ZENCORE_API bool IsInteractiveSession(); + +ZENCORE_API void zencore_forcelinktests(); + +} // namespace zen + +////////////////////////////////////////////////////////////////////////// + +#ifndef ZEN_USE_MIMALLOC +# if ZEN_ARCH_ARM64 + // The vcpkg mimalloc port doesn't support Arm targets +# define ZEN_USE_MIMALLOC 0 +# else +# define ZEN_USE_MIMALLOC 1 +# endif +#endif + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_COMPILER_MSC +# define ZEN_DISABLE_OPTIMIZATION_ACTUAL __pragma(optimize("", off)) +# define ZEN_ENABLE_OPTIMIZATION_ACTUAL __pragma(optimize("", on)) +#elif ZEN_COMPILER_GCC +# define ZEN_DISABLE_OPTIMIZATION_ACTUAL _Pragma("GCC push_options") _Pragma("GCC optimize (\"O0\")") +# define ZEN_ENABLE_OPTIMIZATION_ACTUAL _Pragma("GCC pop_options") +#elif ZEN_COMPILER_CLANG +# define ZEN_DISABLE_OPTIMIZATION_ACTUAL _Pragma("clang optimize off") +# define ZEN_ENABLE_OPTIMIZATION_ACTUAL _Pragma("clang optimize on") +#endif + +// Set up optimization control macros, now that we have both the build settings and the platform macros +#define ZEN_DISABLE_OPTIMIZATION ZEN_DISABLE_OPTIMIZATION_ACTUAL + +#if ZEN_BUILD_DEBUG +# define ZEN_ENABLE_OPTIMIZATION ZEN_DISABLE_OPTIMIZATION_ACTUAL +#else +# define ZEN_ENABLE_OPTIMIZATION ZEN_ENABLE_OPTIMIZATION_ACTUAL +#endif + +#define ZEN_ENABLE_OPTIMIZATION_ALWAYS ZEN_ENABLE_OPTIMIZATION_ACTUAL + +////////////////////////////////////////////////////////////////////////// + +#ifndef ZEN_WITH_TRACE +# define ZEN_WITH_TRACE 0 +#endif + +////////////////////////////////////////////////////////////////////////// + +using ThreadId_t = uint32_t; diff --git a/src/zencore/intmath.cpp b/src/zencore/intmath.cpp new file mode 100644 index 000000000..5a686dc8e --- /dev/null +++ b/src/zencore/intmath.cpp @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/endian.h> +#include <zencore/intmath.h> + +#include <zencore/testing.h> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +void +intmath_forcelink() +{ +} + +TEST_CASE("intmath") +{ + CHECK(FloorLog2(0x00) == 0); + CHECK(FloorLog2(0x01) == 0); + CHECK(FloorLog2(0x0f) == 3); + CHECK(FloorLog2(0x10) == 4); + CHECK(FloorLog2(0x11) == 4); + CHECK(FloorLog2(0x12) == 4); + CHECK(FloorLog2(0x22) == 5); + CHECK(FloorLog2(0x0001'0000) == 16); + CHECK(FloorLog2(0x0001'000f) == 16); + CHECK(FloorLog2(0x8000'0000) == 31); + + CHECK(FloorLog2_64(0x00ull) == 0); + CHECK(FloorLog2_64(0x01ull) == 0); + CHECK(FloorLog2_64(0x0full) == 3); + CHECK(FloorLog2_64(0x10ull) == 4); + CHECK(FloorLog2_64(0x11ull) == 4); + CHECK(FloorLog2_64(0x0001'0000ull) == 16); + CHECK(FloorLog2_64(0x0001'000full) == 16); + CHECK(FloorLog2_64(0x8000'0000ull) == 31); + CHECK(FloorLog2_64(0x0000'0001'0000'0000ull) == 32); + CHECK(FloorLog2_64(0x8000'0000'0000'0000ull) == 63); + + CHECK(CountLeadingZeros64(0x8000'0000'0000'0000ull) == 0); + CHECK(CountLeadingZeros64(0x0000'0000'0000'0000ull) == 64); + CHECK(CountLeadingZeros64(0x0000'0000'0000'0001ull) == 63); + CHECK(CountLeadingZeros64(0x0000'0000'8000'0000ull) == 32); + CHECK(CountLeadingZeros64(0x0000'0001'0000'0000ull) == 31); + + CHECK(CountTrailingZeros64(0x8000'0000'0000'0000ull) == 63); + CHECK(CountTrailingZeros64(0x0000'0000'0000'0000ull) == 64); + CHECK(CountTrailingZeros64(0x0000'0000'0000'0001ull) == 0); + CHECK(CountTrailingZeros64(0x0000'0000'8000'0000ull) == 31); + CHECK(CountTrailingZeros64(0x0000'0001'0000'0000ull) == 32); + + CHECK(ByteSwap(uint16_t(0x6d72)) == 0x726d); + CHECK(ByteSwap(uint32_t(0x2741'3965)) == 0x6539'4127); + CHECK(ByteSwap(uint64_t(0x214d'6172'7469'6e21ull)) == 0x216e'6974'7261'4d21ull); +} + +#endif + +} // namespace zen diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp new file mode 100644 index 000000000..1d7d47695 --- /dev/null +++ b/src/zencore/iobuffer.cpp @@ -0,0 +1,653 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/iobuffer.h> + +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/testing.h> +#include <zencore/thread.h> + +#include <memory.h> +#include <system_error> + +#if ZEN_USE_MIMALLOC +ZEN_THIRD_PARTY_INCLUDES_START +# include <mimalloc.h> +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +#if ZEN_PLATFORM_WINDOWS +# include <atlfile.h> +#else +# include <sys/stat.h> +# include <sys/mman.h> +#endif + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +void +IoBufferCore::AllocateBuffer(size_t InSize, size_t Alignment) const +{ +#if ZEN_PLATFORM_WINDOWS + if (((InSize & 0xffFF) == 0) && (Alignment == 0x10000)) + { + m_Flags.fetch_or(kLowLevelAlloc, std::memory_order_relaxed); + m_DataPtr = VirtualAlloc(nullptr, InSize, MEM_COMMIT, PAGE_READWRITE); + + return; + } +#endif // ZEN_PLATFORM_WINDOWS + +#if ZEN_USE_MIMALLOC + void* Ptr = mi_aligned_alloc(Alignment, RoundUp(InSize, Alignment)); + m_Flags.fetch_or(kIoBufferAlloc, std::memory_order_relaxed); +#else + void* Ptr = Memory::Alloc(InSize, Alignment); +#endif + + ZEN_ASSERT(Ptr); + + m_DataPtr = Ptr; +} + +void +IoBufferCore::FreeBuffer() +{ + if (!m_DataPtr) + { + return; + } + + const uint32_t LocalFlags = m_Flags.load(std::memory_order_relaxed); +#if ZEN_PLATFORM_WINDOWS + if (LocalFlags & kLowLevelAlloc) + { + VirtualFree(const_cast<void*>(m_DataPtr), 0, MEM_DECOMMIT); + + return; + } +#endif // ZEN_PLATFORM_WINDOWS + +#if ZEN_USE_MIMALLOC + if (LocalFlags & kIoBufferAlloc) + { + return mi_free(const_cast<void*>(m_DataPtr)); + } +#endif + + ZEN_UNUSED(LocalFlags); + return Memory::Free(const_cast<void*>(m_DataPtr)); +} + +////////////////////////////////////////////////////////////////////////// + +static_assert(sizeof(IoBufferCore) == 32); + +IoBufferCore::IoBufferCore(size_t InSize) +{ + ZEN_ASSERT(InSize); + + AllocateBuffer(InSize, sizeof(void*)); + m_DataBytes = InSize; + + SetIsOwnedByThis(true); +} + +IoBufferCore::IoBufferCore(size_t InSize, size_t Alignment) +{ + ZEN_ASSERT(InSize); + + AllocateBuffer(InSize, Alignment); + m_DataBytes = InSize; + + SetIsOwnedByThis(true); +} + +IoBufferCore::~IoBufferCore() +{ + if (IsOwnedByThis() && m_DataPtr) + { + FreeBuffer(); + m_DataPtr = nullptr; + } +} + +void +IoBufferCore::DeleteThis() const +{ + // We do this just to avoid paying for the cost of a vtable + if (const IoBufferExtendedCore* _ = ExtendedCore()) + { + delete _; + } + else + { + delete this; + } +} + +void +IoBufferCore::Materialize() const +{ + if (const IoBufferExtendedCore* _ = ExtendedCore()) + { + _->Materialize(); + } +} + +void +IoBufferCore::MakeOwned(bool Immutable) +{ + if (!IsOwned()) + { + const void* OldDataPtr = m_DataPtr; + AllocateBuffer(m_DataBytes, sizeof(void*)); + memcpy(const_cast<void*>(m_DataPtr), OldDataPtr, m_DataBytes); + SetIsOwnedByThis(true); + } + + SetIsImmutable(Immutable); +} + +void* +IoBufferCore::MutableDataPointer() const +{ + EnsureDataValid(); + ZEN_ASSERT(!IsImmutable()); + return const_cast<void*>(m_DataPtr); +} + +////////////////////////////////////////////////////////////////////////// + +IoBufferExtendedCore::IoBufferExtendedCore(void* FileHandle, uint64_t Offset, uint64_t Size, bool TransferHandleOwnership) +: IoBufferCore(nullptr, Size) +, m_FileHandle(FileHandle) +, m_FileOffset(Offset) +{ + uint32_t NewFlags = kIsOwnedByThis | kIsExtended; + + if (TransferHandleOwnership) + { + NewFlags |= kOwnsFile; + } + m_Flags.fetch_or(NewFlags, std::memory_order_relaxed); +} + +IoBufferExtendedCore::IoBufferExtendedCore(const IoBufferExtendedCore* Outer, uint64_t Offset, uint64_t Size) +: IoBufferCore(Outer, nullptr, Size) +, m_FileHandle(Outer->m_FileHandle) +, m_FileOffset(Outer->m_FileOffset + Offset) +{ + m_Flags.fetch_or(kIsExtended, std::memory_order_relaxed); +} + +IoBufferExtendedCore::~IoBufferExtendedCore() +{ + if (m_MappedPointer) + { +#if ZEN_PLATFORM_WINDOWS + UnmapViewOfFile(m_MappedPointer); +#else + uint64_t MapSize = ~uint64_t(uintptr_t(m_MmapHandle)); + munmap(m_MappedPointer, MapSize); +#endif + } + + const uint32_t LocalFlags = m_Flags.load(std::memory_order_relaxed); +#if ZEN_PLATFORM_WINDOWS + if (LocalFlags & kOwnsMmap) + { + CloseHandle(m_MmapHandle); + } +#endif + + if (LocalFlags & kOwnsFile) + { + if (m_DeleteOnClose) + { +#if ZEN_PLATFORM_WINDOWS + // Mark file for deletion when final handle is closed + FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; + + SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); +#else + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle); + unlink(FilePath.c_str()); +#endif + } +#if ZEN_PLATFORM_WINDOWS + BOOL Success = CloseHandle(m_FileHandle); +#else + int Fd = int(uintptr_t(m_FileHandle)); + bool Success = (close(Fd) == 0); +#endif + if (!Success) + { + ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString()); + } + } + + m_DataPtr = nullptr; +} + +static constexpr size_t MappingLockCount = 128; +static_assert(IsPow2(MappingLockCount), "MappingLockCount must be power of two"); + +static RwLock g_MappingLocks[MappingLockCount]; + +static RwLock& +MappingLockForInstance(const IoBufferExtendedCore* instance) +{ + intptr_t base = (intptr_t)instance; + size_t lock_index = ((base >> 5) ^ (base >> 13)) & (MappingLockCount - 1u); + return g_MappingLocks[lock_index]; +} + +void +IoBufferExtendedCore::Materialize() const +{ + // The synchronization scheme here is very primitive, if we end up with + // a lot of contention we can make it more fine-grained + + if (m_Flags.load(std::memory_order_acquire) & kIsMaterialized) + return; + + RwLock::ExclusiveLockScope _(MappingLockForInstance(this)); + + // Someone could have gotten here first + // We can use memory_order_relaxed on this load because the mutex has already provided the fence + if (m_Flags.load(std::memory_order_relaxed) & kIsMaterialized) + return; + + uint32_t NewFlags = kIsMaterialized; + + if (m_DataBytes == 0) + { + // Fake a "valid" pointer, nobody should read this as size is zero + m_DataPtr = reinterpret_cast<uint8_t*>(&m_MmapHandle); + m_Flags.fetch_or(NewFlags, std::memory_order_release); + return; + } + + const size_t DisableMMapSizeLimit = 0x1000ull; + + if (m_DataBytes < DisableMMapSizeLimit) + { + AllocateBuffer(m_DataBytes, sizeof(void*)); + NewFlags |= kIsOwnedByThis; + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(m_FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(m_FileOffset >> 32); + + DWORD dwNumberOfBytesRead = 0; + BOOL Success = ::ReadFile(m_FileHandle, (void*)m_DataPtr, DWORD(m_DataBytes), &dwNumberOfBytesRead, &Ovl); + + ZEN_ASSERT(Success); + ZEN_ASSERT(dwNumberOfBytesRead == m_DataBytes); +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesRead = pread(Fd, (void*)m_DataPtr, m_DataBytes, m_FileOffset); + bool Success = (BytesRead > 0); +#endif // ZEN_PLATFORM_WINDOWS + + m_Flags.fetch_or(NewFlags, std::memory_order_release); + return; + } + + void* NewMmapHandle; + + const uint64_t MapOffset = m_FileOffset & ~0xffffull; + const uint64_t MappedOffsetDisplacement = m_FileOffset - MapOffset; + const uint64_t MapSize = m_DataBytes + MappedOffsetDisplacement; + + ZEN_ASSERT(MapSize > 0); + +#if ZEN_PLATFORM_WINDOWS + NewMmapHandle = CreateFileMapping(m_FileHandle, + /* lpFileMappingAttributes */ nullptr, + /* flProtect */ PAGE_READONLY, + /* dwMaximumSizeLow */ 0, + /* dwMaximumSizeHigh */ 0, + /* lpName */ nullptr); + + if (NewMmapHandle == nullptr) + { + int32_t Error = zen::GetLastError(); + ZEN_ERROR("CreateFileMapping failed on file '{}', {}", zen::PathFromHandle(m_FileHandle), GetSystemErrorAsString(Error)); + throw std::system_error(std::error_code(Error, std::system_category()), + fmt::format("CreateFileMapping failed on file '{}'", zen::PathFromHandle(m_FileHandle))); + } + + NewFlags |= kOwnsMmap; + + void* MappedBase = MapViewOfFile(NewMmapHandle, + /* dwDesiredAccess */ FILE_MAP_READ, + /* FileOffsetHigh */ uint32_t(MapOffset >> 32), + /* FileOffsetLow */ uint32_t(MapOffset & 0xffFFffFFu), + /* dwNumberOfBytesToMap */ MapSize); +#else + NewMmapHandle = (void*)uintptr_t(~MapSize); // ~ so it's never null (assuming MapSize >= 0) + NewFlags |= kOwnsMmap; + + void* MappedBase = mmap( + /* addr */ nullptr, + /* length */ MapSize, + /* prot */ PROT_READ, + /* flags */ MAP_SHARED | MAP_NORESERVE, + /* fd */ int(uintptr_t(m_FileHandle)), + /* offset */ MapOffset); +#endif // ZEN_PLATFORM_WINDOWS + + if (MappedBase == nullptr) + { + int32_t Error = zen::GetLastError(); +#if ZEN_PLATFORM_WINDOWS + CloseHandle(NewMmapHandle); +#endif // ZEN_PLATFORM_WINDOWS + ZEN_ERROR("MapViewOfFile failed (offset {:#x}, size {:#x}) file: '{}', {}", + MapOffset, + MapSize, + zen::PathFromHandle(m_FileHandle), + GetSystemErrorAsString(Error)); + throw std::system_error(std::error_code(Error, std::system_category()), + fmt::format("MapViewOfFile failed (offset {:#x}, size {:#x}) file: '{}'", + MapOffset, + MapSize, + zen::PathFromHandle(m_FileHandle))); + } + + m_MappedPointer = MappedBase; + m_DataPtr = reinterpret_cast<uint8_t*>(MappedBase) + MappedOffsetDisplacement; + m_MmapHandle = NewMmapHandle; + + m_Flags.fetch_or(NewFlags, std::memory_order_release); +} + +bool +IoBufferExtendedCore::GetFileReference(IoBufferFileReference& OutRef) const +{ + if (m_FileHandle == nullptr) + { + return false; + } + + OutRef.FileHandle = m_FileHandle; + OutRef.FileChunkOffset = m_FileOffset; + OutRef.FileChunkSize = m_DataBytes; + + return true; +} + +void +IoBufferExtendedCore::MarkAsDeleteOnClose() +{ + m_DeleteOnClose = true; +} + +////////////////////////////////////////////////////////////////////////// + +IoBuffer::IoBuffer(size_t InSize) : m_Core(new IoBufferCore(InSize)) +{ + m_Core->SetIsImmutable(false); +} + +IoBuffer::IoBuffer(size_t InSize, uint64_t InAlignment) : m_Core(new IoBufferCore(InSize, InAlignment)) +{ + m_Core->SetIsImmutable(false); +} + +IoBuffer::IoBuffer(const IoBuffer& OuterBuffer, size_t Offset, size_t Size) +{ + if (Size == ~(0ull)) + { + Size = std::clamp<size_t>(Size, 0, OuterBuffer.Size() - Offset); + } + + ZEN_ASSERT(Offset <= OuterBuffer.Size()); + ZEN_ASSERT((Offset + Size) <= OuterBuffer.Size()); + + if (IoBufferExtendedCore* Extended = OuterBuffer.m_Core->ExtendedCore()) + { + m_Core = new IoBufferExtendedCore(Extended, Offset, Size); + } + else + { + m_Core = new IoBufferCore(OuterBuffer.m_Core, reinterpret_cast<const uint8_t*>(OuterBuffer.Data()) + Offset, Size); + } +} + +IoBuffer::IoBuffer(EFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize) +: m_Core(new IoBufferExtendedCore(FileHandle, ChunkFileOffset, ChunkSize, /* owned */ true)) +{ +} + +IoBuffer::IoBuffer(EBorrowedFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize) +: m_Core(new IoBufferExtendedCore(FileHandle, ChunkFileOffset, ChunkSize, /* owned */ false)) +{ +} + +bool +IoBuffer::GetFileReference(IoBufferFileReference& OutRef) const +{ + if (IoBufferExtendedCore* ExtCore = m_Core->ExtendedCore()) + { + if (ExtCore->GetFileReference(OutRef)) + { + return true; + } + } + + // Not a file reference + + OutRef.FileHandle = 0; + OutRef.FileChunkOffset = ~0ull; + OutRef.FileChunkSize = 0; + + return false; +} + +void +IoBuffer::MarkAsDeleteOnClose() +{ + if (IoBufferExtendedCore* ExtCore = m_Core->ExtendedCore()) + { + ExtCore->MarkAsDeleteOnClose(); + } +} + +////////////////////////////////////////////////////////////////////////// + +IoBuffer +IoBufferBuilder::ReadFromFileMaybe(IoBuffer& InBuffer) +{ + IoBufferFileReference FileRef; + if (InBuffer.GetFileReference(/* out */ FileRef)) + { + IoBuffer OutBuffer(FileRef.FileChunkSize); + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + const uint64_t NumberOfBytesToRead = FileRef.FileChunkSize; + const uint64_t& FileOffset = FileRef.FileChunkOffset; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD dwNumberOfBytesRead = 0; + BOOL Success = ::ReadFile(FileRef.FileHandle, OutBuffer.MutableData(), DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl); +#else + int Fd = int(intptr_t(FileRef.FileHandle)); + int Result = pread(Fd, OutBuffer.MutableData(), size_t(FileRef.FileChunkSize), off_t(FileRef.FileChunkOffset)); + bool Success = (Result < 0); + + uint32_t dwNumberOfBytesRead = uint32_t(Result); +#endif + + if (!Success) + { + ThrowLastError("ReadFile failed in IoBufferBuilder::ReadFromFileMaybe"); + } + + ZEN_ASSERT(dwNumberOfBytesRead == FileRef.FileChunkSize); + + return OutBuffer; + } + else + { + return InBuffer; + } +} + +IoBuffer +IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size) +{ + return IoBuffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size); +} + +IoBuffer +IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset, uint64_t Size) +{ + uint64_t FileSize; + +#if ZEN_PLATFORM_WINDOWS + CAtlFile DataFile; + + DWORD ShareOptions = FILE_SHARE_DELETE | FILE_SHARE_WRITE | FILE_SHARE_DELETE | FILE_SHARE_READ; + HRESULT hRes = DataFile.Create(FileName.c_str(), GENERIC_READ, ShareOptions, OPEN_EXISTING); + + if (FAILED(hRes)) + { + return {}; + } + + DataFile.GetSize((ULONGLONG&)FileSize); +#else + int Flags = O_RDONLY | O_CLOEXEC; + int Fd = open(FileName.c_str(), Flags); + if (Fd < 0) + { + return {}; + } + + static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); + struct stat Stat; + fstat(Fd, &Stat); + FileSize = Stat.st_size; +#endif // ZEN_PLATFORM_WINDOWS + + // TODO: should validate that offset is in range + + if (Size == ~0ull) + { + Size = FileSize - Offset; + } + else + { + // Clamp size + if ((Offset + Size) > FileSize) + { + Size = FileSize - Offset; + } + } + + if (Size) + { +#if ZEN_PLATFORM_WINDOWS + void* Fd = DataFile.Detach(); +#endif + IoBuffer Iob(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size); + Iob.m_Core->SetIsWholeFile(Offset == 0 && Size == FileSize); + return Iob; + } + +#if !ZEN_PLATFORM_WINDOWS + close(Fd); +#endif + + // For an empty file, we may as well just return an empty memory IoBuffer + return IoBuffer(IoBuffer::Wrap, "", 0); +} + +IoBuffer +IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName) +{ + uint64_t FileSize; + void* Handle; + +#if ZEN_PLATFORM_WINDOWS + CAtlFile DataFile; + + // We need to open with DELETE since this is used for the case + // when a file has been written to a staging directory, and is going + // to be moved in place + + HRESULT hRes = DataFile.Create(FileName.native().c_str(), GENERIC_READ | DELETE, FILE_SHARE_READ | FILE_SHARE_DELETE, OPEN_EXISTING); + + if (FAILED(hRes)) + { + return {}; + } + + DataFile.GetSize((ULONGLONG&)FileSize); + + Handle = DataFile.Detach(); +#else + int Fd = open(FileName.native().c_str(), O_RDONLY); + if (Fd < 0) + { + return {}; + } + + static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); + struct stat Stat; + fstat(Fd, &Stat); + FileSize = Stat.st_size; + + Handle = (void*)uintptr_t(Fd); +#endif // ZEN_PLATFORM_WINDOWS + + IoBuffer Iob(IoBuffer::File, Handle, 0, FileSize); + Iob.m_Core->SetIsWholeFile(true); + + return Iob; +} + +IoHash +HashBuffer(IoBuffer& Buffer) +{ + // TODO: handle disk buffers with special path + return IoHash::HashBuffer(Buffer.Data(), Buffer.Size()); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +void +iobuffer_forcelink() +{ +} + +TEST_CASE("IoBuffer") +{ + zen::IoBuffer buffer1; + zen::IoBuffer buffer2(16384); + zen::IoBuffer buffer3(buffer2, 0, buffer2.Size()); +} + +#endif + +} // namespace zen diff --git a/src/zencore/iohash.cpp b/src/zencore/iohash.cpp new file mode 100644 index 000000000..77076c133 --- /dev/null +++ b/src/zencore/iohash.cpp @@ -0,0 +1,87 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/iohash.h> + +#include <zencore/blake3.h> +#include <zencore/compositebuffer.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +const IoHash IoHash::Zero{}; // Initialized to all zeros + +IoHash +IoHash::HashBuffer(const void* data, size_t byteCount) +{ + BLAKE3 b3 = BLAKE3::HashMemory(data, byteCount); + + IoHash io; + memcpy(io.Hash, b3.Hash, sizeof io.Hash); + + return io; +} + +IoHash +IoHash::HashBuffer(const CompositeBuffer& Buffer) +{ + IoHashStream Hasher; + + for (const SharedBuffer& Segment : Buffer.GetSegments()) + { + Hasher.Append(Segment.GetData(), Segment.GetSize()); + } + + return Hasher.GetHash(); +} + +IoHash +IoHash::FromHexString(const char* string) +{ + return FromHexString({string, sizeof(IoHash::Hash) * 2}); +} + +IoHash +IoHash::FromHexString(std::string_view string) +{ + ZEN_ASSERT(string.size() == 2 * sizeof(IoHash::Hash)); + + IoHash io; + + ParseHexBytes(string.data(), string.size(), io.Hash); + + return io; +} + +const char* +IoHash::ToHexString(char* outString /* 40 characters + NUL terminator */) const +{ + ToHexBytes(Hash, sizeof(IoHash), outString); + outString[2 * sizeof(IoHash)] = '\0'; + + return outString; +} + +StringBuilderBase& +IoHash::ToHexString(StringBuilderBase& outBuilder) const +{ + String_t Str; + ToHexString(Str); + + outBuilder.AppendRange(Str, &Str[StringLength]); + + return outBuilder; +} + +std::string +IoHash::ToHexString() const +{ + String_t Str; + ToHexString(Str); + + return Str; +} + +} // namespace zen diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp new file mode 100644 index 000000000..a6423e2dc --- /dev/null +++ b/src/zencore/logging.cpp @@ -0,0 +1,85 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/logging.h" + +#include <spdlog/sinks/stdout_color_sinks.h> + +namespace zen { + +// We shadow the underlying spdlog default logger, in order to avoid a bunch of overhead +spdlog::logger* TheDefaultLogger; + +} // namespace zen + +namespace zen::logging { + +spdlog::logger& +Default() +{ + return *TheDefaultLogger; +} + +void +SetDefault(std::shared_ptr<spdlog::logger> NewDefaultLogger) +{ + spdlog::set_default_logger(NewDefaultLogger); + TheDefaultLogger = spdlog::default_logger_raw(); +} + +spdlog::logger& +Get(std::string_view Name) +{ + std::shared_ptr<spdlog::logger> Logger = spdlog::get(std::string(Name)); + + if (!Logger) + { + Logger = Default().clone(std::string(Name)); + spdlog::register_logger(Logger); + } + + return *Logger; +} + +std::once_flag ConsoleInitFlag; +std::shared_ptr<spdlog::logger> ConLogger; + +spdlog::logger& +ConsoleLog() +{ + std::call_once(ConsoleInitFlag, [&] { + ConLogger = spdlog::stdout_color_mt("console"); + + ConLogger->set_pattern("%v"); + }); + + return *ConLogger; +} + +std::shared_ptr<spdlog::logger> TheErrorLogger; + +spdlog::logger* +ErrorLog() +{ + return TheErrorLogger.get(); +} + +void +SetErrorLog(std::shared_ptr<spdlog::logger>&& NewErrorLogger) +{ + TheErrorLogger = std::move(NewErrorLogger); +} + +void +InitializeLogging() +{ + TheDefaultLogger = spdlog::default_logger_raw(); +} + +void +ShutdownLogging() +{ + spdlog::drop_all(); + spdlog::shutdown(); +} + +} // namespace zen::logging diff --git a/src/zencore/md5.cpp b/src/zencore/md5.cpp new file mode 100644 index 000000000..4ec145697 --- /dev/null +++ b/src/zencore/md5.cpp @@ -0,0 +1,463 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/md5.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/zencore.h> + +#include <string.h> +#include <string_view> + +/* + ********************************************************************** + ** md5.h -- Header file for implementation of MD5 ** + ** RSA Data Security, Inc. MD5 Message Digest Algorithm ** + ** Created: 2/17/90 RLR ** + ** Revised: 12/27/90 SRD,AJ,BSK,JT Reference C version ** + ** Revised (for MD5): RLR 4/27/91 ** + ** -- G modified to have y&~z instead of y&z ** + ** -- FF, GG, HH modified to add in last register done ** + ** -- Access pattern: round 2 works mod 5, round 3 works mod 3 ** + ** -- distinct additive constant for each step ** + ** -- round 4 added, working mod 7 ** + ********************************************************************** + */ + +/* + ********************************************************************** + ** Copyright (C) 1990, RSA Data Security, Inc. All rights reserved. ** + ** ** + ** License to copy and use this software is granted provided that ** + ** it is identified as the "RSA Data Security, Inc. MD5 Message ** + ** Digest Algorithm" in all material mentioning or referencing this ** + ** software or this function. ** + ** ** + ** License is also granted to make and use derivative works ** + ** provided that such works are identified as "derived from the RSA ** + ** Data Security, Inc. MD5 Message Digest Algorithm" in all ** + ** material mentioning or referencing the derived work. ** + ** ** + ** RSA Data Security, Inc. makes no representations concerning ** + ** either the merchantability of this software or the suitability ** + ** of this software for any particular purpose. It is provided "as ** + ** is" without express or implied warranty of any kind. ** + ** ** + ** These notices must be retained in any copies of any part of this ** + ** documentation and/or software. ** + ********************************************************************** + */ + +/* Data structure for MD5 (Message Digest) computation */ +struct MD5_CTX +{ + uint32_t i[2]; /* number of _bits_ handled mod 2^64 */ + uint32_t buf[4]; /* scratch buffer */ + unsigned char in[64]; /* input buffer */ + unsigned char digest[16]; /* actual digest after MD5Final call */ +}; + +void MD5Init(); +void MD5Update(); +void MD5Final(); + +/* + ********************************************************************** + ** End of md5.h ** + ******************************* (cut) ******************************** + */ + +/* + ********************************************************************** + ** md5.c ** + ** RSA Data Security, Inc. MD5 Message Digest Algorithm ** + ** Created: 2/17/90 RLR ** + ** Revised: 1/91 SRD,AJ,BSK,JT Reference C Version ** + ********************************************************************** + */ + +/* + ********************************************************************** + ** Copyright (C) 1990, RSA Data Security, Inc. All rights reserved. ** + ** ** + ** License to copy and use this software is granted provided that ** + ** it is identified as the "RSA Data Security, Inc. MD5 Message ** + ** Digest Algorithm" in all material mentioning or referencing this ** + ** software or this function. ** + ** ** + ** License is also granted to make and use derivative works ** + ** provided that such works are identified as "derived from the RSA ** + ** Data Security, Inc. MD5 Message Digest Algorithm" in all ** + ** material mentioning or referencing the derived work. ** + ** ** + ** RSA Data Security, Inc. makes no representations concerning ** + ** either the merchantability of this software or the suitability ** + ** of this software for any particular purpose. It is provided "as ** + ** is" without express or implied warranty of any kind. ** + ** ** + ** These notices must be retained in any copies of any part of this ** + ** documentation and/or software. ** + ********************************************************************** + */ + +/* -- include the following line if the md5.h header file is separate -- */ +/* #include "md5.h" */ + +/* forward declaration */ +static void Transform(uint32_t* buf, uint32_t* in); + +static unsigned char PADDING[64] = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + +/* F, G and H are basic MD5 functions: selection, majority, parity */ +#define F(x, y, z) (((x) & (y)) | ((~x) & (z))) +#define G(x, y, z) (((x) & (z)) | ((y) & (~z))) +#define H(x, y, z) ((x) ^ (y) ^ (z)) +#define I(x, y, z) ((y) ^ ((x) | (~z))) + +/* ROTATE_LEFT rotates x left n bits */ +#define ROTATE_LEFT(x, n) (((x) << (n)) | ((x) >> (32 - (n)))) + +/* FF, GG, HH, and II transformations for rounds 1, 2, 3, and 4 */ +/* Rotation is separate from addition to prevent recomputation */ +#define FF(a, b, c, d, x, s, ac) \ + { \ + (a) += F((b), (c), (d)) + (x) + (uint32_t)(ac); \ + (a) = ROTATE_LEFT((a), (s)); \ + (a) += (b); \ + } +#define GG(a, b, c, d, x, s, ac) \ + { \ + (a) += G((b), (c), (d)) + (x) + (uint32_t)(ac); \ + (a) = ROTATE_LEFT((a), (s)); \ + (a) += (b); \ + } +#define HH(a, b, c, d, x, s, ac) \ + { \ + (a) += H((b), (c), (d)) + (x) + (uint32_t)(ac); \ + (a) = ROTATE_LEFT((a), (s)); \ + (a) += (b); \ + } +#define II(a, b, c, d, x, s, ac) \ + { \ + (a) += I((b), (c), (d)) + (x) + (uint32_t)(ac); \ + (a) = ROTATE_LEFT((a), (s)); \ + (a) += (b); \ + } + +void +MD5Init(MD5_CTX* mdContext) +{ + mdContext->i[0] = mdContext->i[1] = (uint32_t)0; + + /* Load magic initialization constants. + */ + mdContext->buf[0] = (uint32_t)0x67452301; + mdContext->buf[1] = (uint32_t)0xefcdab89; + mdContext->buf[2] = (uint32_t)0x98badcfe; + mdContext->buf[3] = (uint32_t)0x10325476; +} + +void +MD5Update(MD5_CTX* mdContext, unsigned char* inBuf, unsigned int inLen) +{ + uint32_t in[16]; + int mdi; + unsigned int i, ii; + + /* compute number of bytes mod 64 */ + mdi = (int)((mdContext->i[0] >> 3) & 0x3F); + + /* update number of bits */ + if ((mdContext->i[0] + ((uint32_t)inLen << 3)) < mdContext->i[0]) + mdContext->i[1]++; + mdContext->i[0] += ((uint32_t)inLen << 3); + mdContext->i[1] += ((uint32_t)inLen >> 29); + + while (inLen--) + { + /* add new character to buffer, increment mdi */ + mdContext->in[mdi++] = *inBuf++; + + /* transform if necessary */ + if (mdi == 0x40) + { + for (i = 0, ii = 0; i < 16; i++, ii += 4) + in[i] = (((uint32_t)mdContext->in[ii + 3]) << 24) | (((uint32_t)mdContext->in[ii + 2]) << 16) | + (((uint32_t)mdContext->in[ii + 1]) << 8) | ((uint32_t)mdContext->in[ii]); + Transform(mdContext->buf, in); + mdi = 0; + } + } +} + +void +MD5Final(MD5_CTX* mdContext) +{ + uint32_t in[16]; + int mdi; + unsigned int i, ii; + unsigned int padLen; + + /* save number of bits */ + in[14] = mdContext->i[0]; + in[15] = mdContext->i[1]; + + /* compute number of bytes mod 64 */ + mdi = (int)((mdContext->i[0] >> 3) & 0x3F); + + /* pad out to 56 mod 64 */ + padLen = (mdi < 56) ? (56 - mdi) : (120 - mdi); + MD5Update(mdContext, PADDING, padLen); + + /* append length in bits and transform */ + for (i = 0, ii = 0; i < 14; i++, ii += 4) + in[i] = (((uint32_t)mdContext->in[ii + 3]) << 24) | (((uint32_t)mdContext->in[ii + 2]) << 16) | + (((uint32_t)mdContext->in[ii + 1]) << 8) | ((uint32_t)mdContext->in[ii]); + Transform(mdContext->buf, in); + + /* store buffer in digest */ + for (i = 0, ii = 0; i < 4; i++, ii += 4) + { + mdContext->digest[ii] = (unsigned char)(mdContext->buf[i] & 0xFF); + mdContext->digest[ii + 1] = (unsigned char)((mdContext->buf[i] >> 8) & 0xFF); + mdContext->digest[ii + 2] = (unsigned char)((mdContext->buf[i] >> 16) & 0xFF); + mdContext->digest[ii + 3] = (unsigned char)((mdContext->buf[i] >> 24) & 0xFF); + } +} + +/* Basic MD5 step. Transform buf based on in. + */ +static void +Transform(uint32_t* buf, uint32_t* in) +{ + uint32_t a = buf[0], b = buf[1], c = buf[2], d = buf[3]; + + /* Round 1 */ +#define S11 7 +#define S12 12 +#define S13 17 +#define S14 22 + FF(a, b, c, d, in[0], S11, 3614090360); /* 1 */ + FF(d, a, b, c, in[1], S12, 3905402710); /* 2 */ + FF(c, d, a, b, in[2], S13, 606105819); /* 3 */ + FF(b, c, d, a, in[3], S14, 3250441966); /* 4 */ + FF(a, b, c, d, in[4], S11, 4118548399); /* 5 */ + FF(d, a, b, c, in[5], S12, 1200080426); /* 6 */ + FF(c, d, a, b, in[6], S13, 2821735955); /* 7 */ + FF(b, c, d, a, in[7], S14, 4249261313); /* 8 */ + FF(a, b, c, d, in[8], S11, 1770035416); /* 9 */ + FF(d, a, b, c, in[9], S12, 2336552879); /* 10 */ + FF(c, d, a, b, in[10], S13, 4294925233); /* 11 */ + FF(b, c, d, a, in[11], S14, 2304563134); /* 12 */ + FF(a, b, c, d, in[12], S11, 1804603682); /* 13 */ + FF(d, a, b, c, in[13], S12, 4254626195); /* 14 */ + FF(c, d, a, b, in[14], S13, 2792965006); /* 15 */ + FF(b, c, d, a, in[15], S14, 1236535329); /* 16 */ + + /* Round 2 */ +#define S21 5 +#define S22 9 +#define S23 14 +#define S24 20 + GG(a, b, c, d, in[1], S21, 4129170786); /* 17 */ + GG(d, a, b, c, in[6], S22, 3225465664); /* 18 */ + GG(c, d, a, b, in[11], S23, 643717713); /* 19 */ + GG(b, c, d, a, in[0], S24, 3921069994); /* 20 */ + GG(a, b, c, d, in[5], S21, 3593408605); /* 21 */ + GG(d, a, b, c, in[10], S22, 38016083); /* 22 */ + GG(c, d, a, b, in[15], S23, 3634488961); /* 23 */ + GG(b, c, d, a, in[4], S24, 3889429448); /* 24 */ + GG(a, b, c, d, in[9], S21, 568446438); /* 25 */ + GG(d, a, b, c, in[14], S22, 3275163606); /* 26 */ + GG(c, d, a, b, in[3], S23, 4107603335); /* 27 */ + GG(b, c, d, a, in[8], S24, 1163531501); /* 28 */ + GG(a, b, c, d, in[13], S21, 2850285829); /* 29 */ + GG(d, a, b, c, in[2], S22, 4243563512); /* 30 */ + GG(c, d, a, b, in[7], S23, 1735328473); /* 31 */ + GG(b, c, d, a, in[12], S24, 2368359562); /* 32 */ + + /* Round 3 */ +#define S31 4 +#define S32 11 +#define S33 16 +#define S34 23 + HH(a, b, c, d, in[5], S31, 4294588738); /* 33 */ + HH(d, a, b, c, in[8], S32, 2272392833); /* 34 */ + HH(c, d, a, b, in[11], S33, 1839030562); /* 35 */ + HH(b, c, d, a, in[14], S34, 4259657740); /* 36 */ + HH(a, b, c, d, in[1], S31, 2763975236); /* 37 */ + HH(d, a, b, c, in[4], S32, 1272893353); /* 38 */ + HH(c, d, a, b, in[7], S33, 4139469664); /* 39 */ + HH(b, c, d, a, in[10], S34, 3200236656); /* 40 */ + HH(a, b, c, d, in[13], S31, 681279174); /* 41 */ + HH(d, a, b, c, in[0], S32, 3936430074); /* 42 */ + HH(c, d, a, b, in[3], S33, 3572445317); /* 43 */ + HH(b, c, d, a, in[6], S34, 76029189); /* 44 */ + HH(a, b, c, d, in[9], S31, 3654602809); /* 45 */ + HH(d, a, b, c, in[12], S32, 3873151461); /* 46 */ + HH(c, d, a, b, in[15], S33, 530742520); /* 47 */ + HH(b, c, d, a, in[2], S34, 3299628645); /* 48 */ + + /* Round 4 */ +#define S41 6 +#define S42 10 +#define S43 15 +#define S44 21 + II(a, b, c, d, in[0], S41, 4096336452); /* 49 */ + II(d, a, b, c, in[7], S42, 1126891415); /* 50 */ + II(c, d, a, b, in[14], S43, 2878612391); /* 51 */ + II(b, c, d, a, in[5], S44, 4237533241); /* 52 */ + II(a, b, c, d, in[12], S41, 1700485571); /* 53 */ + II(d, a, b, c, in[3], S42, 2399980690); /* 54 */ + II(c, d, a, b, in[10], S43, 4293915773); /* 55 */ + II(b, c, d, a, in[1], S44, 2240044497); /* 56 */ + II(a, b, c, d, in[8], S41, 1873313359); /* 57 */ + II(d, a, b, c, in[15], S42, 4264355552); /* 58 */ + II(c, d, a, b, in[6], S43, 2734768916); /* 59 */ + II(b, c, d, a, in[13], S44, 1309151649); /* 60 */ + II(a, b, c, d, in[4], S41, 4149444226); /* 61 */ + II(d, a, b, c, in[11], S42, 3174756917); /* 62 */ + II(c, d, a, b, in[2], S43, 718787259); /* 63 */ + II(b, c, d, a, in[9], S44, 3951481745); /* 64 */ + + buf[0] += a; + buf[1] += b; + buf[2] += c; + buf[3] += d; +} + +/* + ********************************************************************** + ** End of md5.c ** + ******************************* (cut) ******************************** + */ + +#undef FF +#undef GG +#undef HH +#undef II +#undef F +#undef G +#undef H +#undef I + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +MD5 MD5::Zero; // Initialized to all zeroes + +////////////////////////////////////////////////////////////////////////// + +MD5Stream::MD5Stream() +{ + Reset(); +} + +void +MD5Stream::Reset() +{ +} + +MD5Stream& +MD5Stream::Append(const void* Data, size_t ByteCount) +{ + ZEN_UNUSED(Data); + ZEN_UNUSED(ByteCount); + + return *this; +} + +MD5 +MD5Stream::GetHash() +{ + MD5 md5{}; + + return md5; +} + +////////////////////////////////////////////////////////////////////////// + +MD5 +MD5::HashMemory(const void* data, size_t byteCount) +{ + return MD5Stream().Append(data, byteCount).GetHash(); +} + +MD5 +MD5::FromHexString(const char* string) +{ + MD5 md5; + + ParseHexBytes(string, 40, md5.Hash); + + return md5; +} + +const char* +MD5::ToHexString(char* outString /* 32 characters + NUL terminator */) const +{ + ToHexBytes(Hash, sizeof(MD5), outString); + outString[2 * sizeof(MD5)] = '\0'; + + return outString; +} + +StringBuilderBase& +MD5::ToHexString(StringBuilderBase& outBuilder) const +{ + char str[41]; + ToHexString(str); + + outBuilder.AppendRange(str, &str[40]); + + return outBuilder; +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +void +md5_forcelink() +{ +} + +// doctest::String +// toString(const MD5& value) +// { +// char md5text[2 * sizeof(MD5) + 1]; +// value.ToHexString(md5text); + +// return md5text; +// } + +TEST_CASE("MD5") +{ + using namespace std::literals; + + auto Input = "jumblesmcgee"sv; + auto Output = "28f2200a59c60b75947099d750c2cc50"sv; + + MD5Stream Stream; + Stream.Append(Input.data(), Input.length()); + MD5 Result = Stream.GetHash(); + + MD5::String_t Buffer; + Result.ToHexString(Buffer); + + CHECK(Output.compare(Buffer)); + + MD5 Reresult = MD5::FromHexString(Buffer); + Reresult.ToHexString(Buffer); + CHECK(Output.compare(Buffer)); +} + +#endif + +} // namespace zen diff --git a/src/zencore/memory.cpp b/src/zencore/memory.cpp new file mode 100644 index 000000000..1f148cede --- /dev/null +++ b/src/zencore/memory.cpp @@ -0,0 +1,211 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/intmath.h> +#include <zencore/memory.h> +#include <zencore/testing.h> +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS +# include <malloc.h> +ZEN_THIRD_PARTY_INCLUDES_START +# include <mimalloc.h> +ZEN_THIRD_PARTY_INCLUDES_END +#else +# include <cstdlib> +#endif + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +static void* +AlignedAllocImpl(size_t Size, size_t Alignment) +{ +#if ZEN_PLATFORM_WINDOWS +# if ZEN_USE_MIMALLOC && 0 /* this path is not functional */ + return mi_aligned_alloc(Alignment, Size); +# else + return _aligned_malloc(Size, Alignment); +# endif +#else + // aligned_alloc() states that size must be a multiple of alignment. Some + // platforms return null if this requirement isn't met. + Size = (Size + Alignment - 1) & ~(Alignment - 1); + return std::aligned_alloc(Alignment, Size); +#endif +} + +void +AlignedFreeImpl(void* ptr) +{ + if (ptr == nullptr) + return; + +#if ZEN_PLATFORM_WINDOWS +# if ZEN_USE_MIMALLOC && 0 /* this path is not functional */ + return mi_free(ptr); +# else + _aligned_free(ptr); +# endif +#else + std::free(ptr); +#endif +} + +////////////////////////////////////////////////////////////////////////// + +MemoryArena::MemoryArena() +{ +} + +MemoryArena::~MemoryArena() +{ +} + +void* +MemoryArena::Alloc(size_t Size, size_t Alignment) +{ + return AlignedAllocImpl(Size, Alignment); +} + +void +MemoryArena::Free(void* ptr) +{ + AlignedFreeImpl(ptr); +} + +////////////////////////////////////////////////////////////////////////// + +void* +Memory::Alloc(size_t Size, size_t Alignment) +{ + return AlignedAllocImpl(Size, Alignment); +} + +void +Memory::Free(void* ptr) +{ + AlignedFreeImpl(ptr); +} + +////////////////////////////////////////////////////////////////////////// + +ChunkingLinearAllocator::ChunkingLinearAllocator(uint64_t ChunkSize, uint64_t ChunkAlignment) +: m_ChunkSize(ChunkSize) +, m_ChunkAlignment(ChunkAlignment) +{ +} + +ChunkingLinearAllocator::~ChunkingLinearAllocator() +{ + Reset(); +} + +void +ChunkingLinearAllocator::Reset() +{ + for (void* ChunkEntry : m_ChunkList) + { + Memory::Free(ChunkEntry); + } + m_ChunkList.clear(); + + m_ChunkCursor = nullptr; + m_ChunkBytesRemain = 0; +} + +void* +ChunkingLinearAllocator::Alloc(size_t Size, size_t Alignment) +{ + ZEN_ASSERT_SLOW(zen::IsPow2(Alignment)); + + // This could be improved in a bunch of ways + // + // * We pessimistically allocate memory even though there may be enough memory available for a single allocation due to the way we take + // alignment into account below + // * The block allocation size could be chosen to minimize slack for the case when multiple oversize allocations are made rather than + // minimizing the number of chunks + // * ... + + const uint64_t AllocationSize = zen::RoundUp(Size, Alignment); + + if (m_ChunkBytesRemain < (AllocationSize + Alignment - 1)) + { + const uint64_t ChunkSize = zen::RoundUp(zen::Max(m_ChunkSize, Size), m_ChunkSize); + void* ChunkPtr = Memory::Alloc(ChunkSize, m_ChunkAlignment); + m_ChunkCursor = reinterpret_cast<uint8_t*>(ChunkPtr); + m_ChunkBytesRemain = ChunkSize; + m_ChunkList.push_back(ChunkPtr); + } + + const uint64_t AlignFixup = (Alignment - reinterpret_cast<uintptr_t>(m_ChunkCursor)) & (Alignment - 1); + void* ReturnPtr = m_ChunkCursor + AlignFixup; + const uint64_t Delta = AlignFixup + AllocationSize; + + ZEN_ASSERT_SLOW(m_ChunkBytesRemain >= Delta); + + m_ChunkCursor += Delta; + m_ChunkBytesRemain -= Delta; + + ZEN_ASSERT_SLOW(IsPointerAligned(ReturnPtr, Alignment)); + + return ReturnPtr; +} + +////////////////////////////////////////////////////////////////////////// +// +// Unit tests +// + +#if ZEN_WITH_TESTS + +TEST_CASE("ChunkingLinearAllocator") +{ + ChunkingLinearAllocator Allocator(4096); + + void* p1 = Allocator.Alloc(1, 1); + void* p2 = Allocator.Alloc(1, 1); + + CHECK(p1 != p2); + + void* p3 = Allocator.Alloc(1, 4); + CHECK(IsPointerAligned(p3, 4)); + + void* p3_2 = Allocator.Alloc(1, 4); + CHECK(IsPointerAligned(p3_2, 4)); + + void* p4 = Allocator.Alloc(1, 8); + CHECK(IsPointerAligned(p4, 8)); + + for (int i = 0; i < 100; ++i) + { + void* p0 = Allocator.Alloc(64); + ZEN_UNUSED(p0); + } +} + +TEST_CASE("MemoryView") +{ + { + uint8_t Array1[16] = {}; + MemoryView View1 = MakeMemoryView(Array1); + CHECK(View1.GetSize() == 16); + } + + { + uint32_t Array2[16] = {}; + MemoryView View2 = MakeMemoryView(Array2); + CHECK(View2.GetSize() == 64); + } + + CHECK(MakeMemoryView<float>({1.0f, 1.2f}).GetSize() == 8); +} + +void +memory_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/mpscqueue.cpp b/src/zencore/mpscqueue.cpp new file mode 100644 index 000000000..29c76c3ca --- /dev/null +++ b/src/zencore/mpscqueue.cpp @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/mpscqueue.h> + +#include <zencore/testing.h> +#include <string> + +namespace zen { + +#if ZEN_WITH_TESTS && 0 +TEST_CASE("mpsc") +{ + MpscQueue<std::string> Queue; + Queue.Enqueue("hello"); + std::optional<std::string> Value = Queue.Dequeue(); + CHECK_EQ(Value, "hello"); +} +#endif + +void +mpscqueue_forcelink() +{ +} + +} // namespace zen
\ No newline at end of file diff --git a/src/zencore/refcount.cpp b/src/zencore/refcount.cpp new file mode 100644 index 000000000..c6c47b04d --- /dev/null +++ b/src/zencore/refcount.cpp @@ -0,0 +1,65 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/refcount.h> + +#include <zencore/testing.h> + +#include <functional> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +struct TestRefClass : public RefCounted +{ + ~TestRefClass() + { + if (OnDestroy) + OnDestroy(); + } + + using RefCounted::RefCount; + + std::function<void()> OnDestroy; +}; + +void +refcount_forcelink() +{ +} + +TEST_CASE("RefPtr") +{ + RefPtr<TestRefClass> Ref; + Ref = new TestRefClass; + + bool IsDestroyed = false; + Ref->OnDestroy = [&] { IsDestroyed = true; }; + + CHECK(IsDestroyed == false); + CHECK(Ref->RefCount() == 1); + + RefPtr<TestRefClass> Ref2; + Ref2 = Ref; + + CHECK(IsDestroyed == false); + CHECK(Ref->RefCount() == 2); + + RefPtr<TestRefClass> Ref3; + Ref2 = Ref3; + + CHECK(IsDestroyed == false); + CHECK(Ref->RefCount() == 1); + Ref = Ref3; + + CHECK(IsDestroyed == true); +} + +#endif + +} // namespace zen diff --git a/src/zencore/session.cpp b/src/zencore/session.cpp new file mode 100644 index 000000000..ce4bfae1b --- /dev/null +++ b/src/zencore/session.cpp @@ -0,0 +1,35 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/session.h" + +#include <zencore/uid.h> + +#include <mutex> + +namespace zen { + +static Oid GlobalSessionId; +static char GlobalSessionString[Oid::StringLength]; +static std::once_flag SessionInitFlag; + +Oid +GetSessionId() +{ + std::call_once(SessionInitFlag, [&] { + GlobalSessionId.Generate(); + GlobalSessionId.ToString(GlobalSessionString); + }); + + return GlobalSessionId; +} + +std::string_view +GetSessionIdString() +{ + // Ensure we actually have a generated session identifier + std::ignore = GetSessionId(); + + return std::string_view(GlobalSessionString, Oid::StringLength); +} + +} // namespace zen
\ No newline at end of file diff --git a/src/zencore/sha1.cpp b/src/zencore/sha1.cpp new file mode 100644 index 000000000..3ee74d7d8 --- /dev/null +++ b/src/zencore/sha1.cpp @@ -0,0 +1,443 @@ +// ////////////////////////////////////////////////////////// +// sha1.cpp +// Copyright (c) 2014,2015 Stephan Brumme. All rights reserved. +// see http://create.stephan-brumme.com/disclaimer.html +// + +#include <zencore/sha1.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/zencore.h> + +#include <string.h> + +// big endian architectures need #define __BYTE_ORDER __BIG_ENDIAN +#if ZEN_PLATFORM_LINUX +# include <endian.h> +#endif + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +SHA1 SHA1::Zero; // Initialized to all zeroes + +////////////////////////////////////////////////////////////////////////// + +SHA1Stream::SHA1Stream() +{ + Reset(); +} + +void +SHA1Stream::Reset() +{ + m_NumBytes = 0; + m_BufferSize = 0; + + // according to RFC 1321 + m_Hash[0] = 0x67452301; + m_Hash[1] = 0xefcdab89; + m_Hash[2] = 0x98badcfe; + m_Hash[3] = 0x10325476; + m_Hash[4] = 0xc3d2e1f0; +} + +namespace { + // mix functions for processBlock() + inline uint32_t f1(uint32_t b, uint32_t c, uint32_t d) + { + return d ^ (b & (c ^ d)); // original: f = (b & c) | ((~b) & d); + } + + inline uint32_t f2(uint32_t b, uint32_t c, uint32_t d) { return b ^ c ^ d; } + + inline uint32_t f3(uint32_t b, uint32_t c, uint32_t d) { return (b & c) | (b & d) | (c & d); } + + inline uint32_t rotate(uint32_t a, uint32_t c) { return (a << c) | (a >> (32 - c)); } + + inline uint32_t swap(uint32_t x) + { +#if defined(__GNUC__) || defined(__clang__) + return __builtin_bswap32(x); +#endif +#ifdef MSC_VER + return _byteswap_ulong(x); +#endif + + return (x >> 24) | ((x >> 8) & 0x0000FF00) | ((x << 8) & 0x00FF0000) | (x << 24); + } +} // namespace + +/// process 64 bytes +void +SHA1Stream::ProcessBlock(const void* data) +{ + // get last hash + uint32_t a = m_Hash[0]; + uint32_t b = m_Hash[1]; + uint32_t c = m_Hash[2]; + uint32_t d = m_Hash[3]; + uint32_t e = m_Hash[4]; + + // data represented as 16x 32-bit words + const uint32_t* input = (uint32_t*)data; + // convert to big endian + uint32_t words[80]; + for (int i = 0; i < 16; i++) +#if defined(__BYTE_ORDER) && (__BYTE_ORDER != 0) && (__BYTE_ORDER == __BIG_ENDIAN) + words[i] = input[i]; +#else + words[i] = swap(input[i]); +#endif + + // extend to 80 words + for (int i = 16; i < 80; i++) + words[i] = rotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1); + + // first round + for (int i = 0; i < 4; i++) + { + int offset = 5 * i; + e += rotate(a, 5) + f1(b, c, d) + words[offset] + 0x5a827999; + b = rotate(b, 30); + d += rotate(e, 5) + f1(a, b, c) + words[offset + 1] + 0x5a827999; + a = rotate(a, 30); + c += rotate(d, 5) + f1(e, a, b) + words[offset + 2] + 0x5a827999; + e = rotate(e, 30); + b += rotate(c, 5) + f1(d, e, a) + words[offset + 3] + 0x5a827999; + d = rotate(d, 30); + a += rotate(b, 5) + f1(c, d, e) + words[offset + 4] + 0x5a827999; + c = rotate(c, 30); + } + + // second round + for (int i = 4; i < 8; i++) + { + int offset = 5 * i; + e += rotate(a, 5) + f2(b, c, d) + words[offset] + 0x6ed9eba1; + b = rotate(b, 30); + d += rotate(e, 5) + f2(a, b, c) + words[offset + 1] + 0x6ed9eba1; + a = rotate(a, 30); + c += rotate(d, 5) + f2(e, a, b) + words[offset + 2] + 0x6ed9eba1; + e = rotate(e, 30); + b += rotate(c, 5) + f2(d, e, a) + words[offset + 3] + 0x6ed9eba1; + d = rotate(d, 30); + a += rotate(b, 5) + f2(c, d, e) + words[offset + 4] + 0x6ed9eba1; + c = rotate(c, 30); + } + + // third round + for (int i = 8; i < 12; i++) + { + int offset = 5 * i; + e += rotate(a, 5) + f3(b, c, d) + words[offset] + 0x8f1bbcdc; + b = rotate(b, 30); + d += rotate(e, 5) + f3(a, b, c) + words[offset + 1] + 0x8f1bbcdc; + a = rotate(a, 30); + c += rotate(d, 5) + f3(e, a, b) + words[offset + 2] + 0x8f1bbcdc; + e = rotate(e, 30); + b += rotate(c, 5) + f3(d, e, a) + words[offset + 3] + 0x8f1bbcdc; + d = rotate(d, 30); + a += rotate(b, 5) + f3(c, d, e) + words[offset + 4] + 0x8f1bbcdc; + c = rotate(c, 30); + } + + // fourth round + for (int i = 12; i < 16; i++) + { + int offset = 5 * i; + e += rotate(a, 5) + f2(b, c, d) + words[offset] + 0xca62c1d6; + b = rotate(b, 30); + d += rotate(e, 5) + f2(a, b, c) + words[offset + 1] + 0xca62c1d6; + a = rotate(a, 30); + c += rotate(d, 5) + f2(e, a, b) + words[offset + 2] + 0xca62c1d6; + e = rotate(e, 30); + b += rotate(c, 5) + f2(d, e, a) + words[offset + 3] + 0xca62c1d6; + d = rotate(d, 30); + a += rotate(b, 5) + f2(c, d, e) + words[offset + 4] + 0xca62c1d6; + c = rotate(c, 30); + } + + // update hash + m_Hash[0] += a; + m_Hash[1] += b; + m_Hash[2] += c; + m_Hash[3] += d; + m_Hash[4] += e; +} + +/// add arbitrary number of bytes +SHA1Stream& +SHA1Stream::Append(const void* data, size_t byteCount) +{ + const uint8_t* current = (const uint8_t*)data; + + if (m_BufferSize > 0) + { + while (byteCount > 0 && m_BufferSize < BlockSize) + { + m_Buffer[m_BufferSize++] = *current++; + byteCount--; + } + } + + // full buffer + if (m_BufferSize == BlockSize) + { + ProcessBlock((void*)m_Buffer); + m_NumBytes += BlockSize; + m_BufferSize = 0; + } + + // no more data ? + if (byteCount == 0) + return *this; + + // process full blocks + while (byteCount >= BlockSize) + { + ProcessBlock(current); + current += BlockSize; + m_NumBytes += BlockSize; + byteCount -= BlockSize; + } + + // keep remaining bytes in buffer + while (byteCount > 0) + { + m_Buffer[m_BufferSize++] = *current++; + byteCount--; + } + + return *this; +} + +/// process final block, less than 64 bytes +void +SHA1Stream::ProcessBuffer() +{ + // the input bytes are considered as bits strings, where the first bit is the most significant bit of the byte + + // - append "1" bit to message + // - append "0" bits until message length in bit mod 512 is 448 + // - append length as 64 bit integer + + // number of bits + size_t paddedLength = m_BufferSize * 8; + + // plus one bit set to 1 (always appended) + paddedLength++; + + // number of bits must be (numBits % 512) = 448 + size_t lower11Bits = paddedLength & 511; + if (lower11Bits <= 448) + paddedLength += 448 - lower11Bits; + else + paddedLength += 512 + 448 - lower11Bits; + // convert from bits to bytes + paddedLength /= 8; + + // only needed if additional data flows over into a second block + unsigned char extra[BlockSize]; + + // append a "1" bit, 128 => binary 10000000 + if (m_BufferSize < BlockSize) + m_Buffer[m_BufferSize] = 128; + else + extra[0] = 128; + + size_t i; + for (i = m_BufferSize + 1; i < BlockSize; i++) + m_Buffer[i] = 0; + for (; i < paddedLength; i++) + extra[i - BlockSize] = 0; + + // add message length in bits as 64 bit number + uint64_t msgBits = 8 * (m_NumBytes + m_BufferSize); + // find right position + unsigned char* addLength; + if (paddedLength < BlockSize) + addLength = m_Buffer + paddedLength; + else + addLength = extra + paddedLength - BlockSize; + + // must be big endian + *addLength++ = (unsigned char)((msgBits >> 56) & 0xFF); + *addLength++ = (unsigned char)((msgBits >> 48) & 0xFF); + *addLength++ = (unsigned char)((msgBits >> 40) & 0xFF); + *addLength++ = (unsigned char)((msgBits >> 32) & 0xFF); + *addLength++ = (unsigned char)((msgBits >> 24) & 0xFF); + *addLength++ = (unsigned char)((msgBits >> 16) & 0xFF); + *addLength++ = (unsigned char)((msgBits >> 8) & 0xFF); + *addLength = (unsigned char)(msgBits & 0xFF); + + // process blocks + ProcessBlock(m_Buffer); + // flowed over into a second block ? + if (paddedLength > BlockSize) + ProcessBlock(extra); +} + +/// return latest hash as bytes +SHA1 +SHA1Stream::GetHash() +{ + SHA1 sha1; + // save old hash if buffer is partially filled + uint32_t oldHash[HashValues]; + for (int i = 0; i < HashValues; i++) + oldHash[i] = m_Hash[i]; + + // process remaining bytes + ProcessBuffer(); + + unsigned char* current = sha1.Hash; + for (int i = 0; i < HashValues; i++) + { + *current++ = (m_Hash[i] >> 24) & 0xFF; + *current++ = (m_Hash[i] >> 16) & 0xFF; + *current++ = (m_Hash[i] >> 8) & 0xFF; + *current++ = m_Hash[i] & 0xFF; + + // restore old hash + m_Hash[i] = oldHash[i]; + } + + return sha1; +} + +/// compute SHA1 of a memory block +SHA1 +SHA1Stream::Compute(const void* data, size_t byteCount) +{ + Reset(); + Append(data, byteCount); + return GetHash(); +} + +SHA1 +SHA1::HashMemory(const void* data, size_t byteCount) +{ + return SHA1Stream().Append(data, byteCount).GetHash(); +} + +SHA1 +SHA1::FromHexString(const char* string) +{ + SHA1 sha1; + + ParseHexBytes(string, 40, sha1.Hash); + + return sha1; +} + +const char* +SHA1::ToHexString(char* outString /* 40 characters + NUL terminator */) const +{ + ToHexBytes(Hash, sizeof(SHA1), outString); + outString[2 * sizeof(SHA1)] = '\0'; + + return outString; +} + +StringBuilderBase& +SHA1::ToHexString(StringBuilderBase& outBuilder) const +{ + char str[41]; + ToHexString(str); + + outBuilder.AppendRange(str, &str[40]); + + return outBuilder; +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +void +sha1_forcelink() +{ +} + +// doctest::String +// toString(const SHA1& value) +// { +// char sha1text[2 * sizeof(SHA1) + 1]; +// value.ToHexString(sha1text); + +// return sha1text; +// } + +TEST_CASE("SHA1") +{ + uint8_t sha1_empty[20] = {0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55, + 0xbf, 0xef, 0x95, 0x60, 0x18, 0x90, 0xaf, 0xd8, 0x07, 0x09}; + SHA1 sha1z; + memcpy(sha1z.Hash, sha1_empty, sizeof sha1z.Hash); + + SUBCASE("Empty string") + { + SHA1 sha1 = SHA1::HashMemory(nullptr, 0); + + CHECK(sha1 == sha1z); + } + + SUBCASE("Empty stream") + { + SHA1Stream sha1s; + sha1s.Append(nullptr, 0); + sha1s.Append(nullptr, 0); + sha1s.Append(nullptr, 0); + CHECK(sha1s.GetHash() == sha1z); + } + + SUBCASE("SHA1 from string") + { + const SHA1 sha1empty = SHA1::FromHexString("da39a3ee5e6b4b0d3255bfef95601890afd80709"); + + CHECK(sha1z == sha1empty); + } + + SUBCASE("SHA1 to string") + { + char sha1str[41]; + sha1z.ToHexString(sha1str); + + CHECK(StringEquals(sha1str, "da39a3ee5e6b4b0d3255bfef95601890afd80709")); + } + + SUBCASE("Hash ABC") + { + const SHA1 sha1abc = SHA1::FromHexString("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8"); + + SHA1Stream sha1s; + + sha1s.Append("A", 1); + sha1s.Append("B", 1); + sha1s.Append("C", 1); + CHECK(sha1s.GetHash() == sha1abc); + + sha1s.Reset(); + sha1s.Append("AB", 2); + sha1s.Append("C", 1); + CHECK(sha1s.GetHash() == sha1abc); + + sha1s.Reset(); + sha1s.Append("ABC", 3); + CHECK(sha1s.GetHash() == sha1abc); + + sha1s.Reset(); + sha1s.Append("A", 1); + sha1s.Append("BC", 2); + CHECK(sha1s.GetHash() == sha1abc); + } +} + +#endif + +} // namespace zen diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp new file mode 100644 index 000000000..200e06972 --- /dev/null +++ b/src/zencore/sharedbuffer.cpp @@ -0,0 +1,146 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/sharedbuffer.h> + +#include <zencore/testing.h> + +#include <memory.h> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +UniqueBuffer +UniqueBuffer::Alloc(uint64_t Size) +{ + void* Buffer = Memory::Alloc(Size, 16); + IoBufferCore* Owner = new IoBufferCore(Buffer, Size); + Owner->SetIsOwnedByThis(true); + Owner->SetIsImmutable(false); + + return UniqueBuffer(Owner); +} + +UniqueBuffer +UniqueBuffer::MakeMutableView(void* DataPtr, uint64_t Size) +{ + IoBufferCore* Owner = new IoBufferCore(DataPtr, Size); + Owner->SetIsImmutable(false); + return UniqueBuffer(Owner); +} + +UniqueBuffer::UniqueBuffer(IoBufferCore* Owner) : m_Buffer(Owner) +{ +} + +SharedBuffer +UniqueBuffer::MoveToShared() +{ + return SharedBuffer(std::move(m_Buffer)); +} + +void +UniqueBuffer::Reset() +{ + m_Buffer = nullptr; +} + +////////////////////////////////////////////////////////////////////////// + +SharedBuffer::SharedBuffer(UniqueBuffer&& InBuffer) : m_Buffer(std::move(InBuffer.m_Buffer)) +{ +} + +SharedBuffer +SharedBuffer::MakeOwned() const& +{ + if (IsOwned() || !m_Buffer) + { + return *this; + } + else + { + return Clone(GetView()); + } +} + +SharedBuffer +SharedBuffer::MakeOwned() && +{ + if (IsOwned()) + { + return std::move(*this); + } + else + { + return Clone(GetView()); + } +} + +SharedBuffer +SharedBuffer::MakeView(MemoryView View, SharedBuffer OuterBuffer) +{ + if (OuterBuffer) + { + ZEN_ASSERT(OuterBuffer.GetView().Contains(View)); + } + + if (View == OuterBuffer.GetView()) + { + // Reference to the full buffer contents, so just return the "outer" + return OuterBuffer; + } + + IoBufferCore* NewCore = new IoBufferCore(OuterBuffer.m_Buffer, View.GetData(), View.GetSize()); + NewCore->SetIsImmutable(true); + return SharedBuffer(NewCore); +} + +SharedBuffer +SharedBuffer::MakeView(const void* Data, uint64_t Size) +{ + return SharedBuffer(new IoBufferCore(const_cast<void*>(Data), Size)); +} + +SharedBuffer +SharedBuffer::Clone() +{ + const uint64_t Size = GetSize(); + void* Buffer = Memory::Alloc(Size, 16); + auto NewOwner = new IoBufferCore(Buffer, Size); + NewOwner->SetIsOwnedByThis(true); + memcpy(Buffer, m_Buffer->DataPointer(), Size); + + return SharedBuffer(NewOwner); +} + +SharedBuffer +SharedBuffer::Clone(MemoryView View) +{ + const uint64_t Size = View.GetSize(); + void* Buffer = Memory::Alloc(Size, 16); + auto NewOwner = new IoBufferCore(Buffer, Size); + NewOwner->SetIsOwnedByThis(true); + memcpy(Buffer, View.GetData(), Size); + + return SharedBuffer(NewOwner); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +void +sharedbuffer_forcelink() +{ +} + +TEST_CASE("SharedBuffer") +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/stats.cpp b/src/zencore/stats.cpp new file mode 100644 index 000000000..372bc42f8 --- /dev/null +++ b/src/zencore/stats.cpp @@ -0,0 +1,715 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/stats.h" + +#include <zencore/compactbinarybuilder.h> +#include "zencore/intmath.h" +#include "zencore/thread.h" +#include "zencore/timer.h" + +#include <cmath> +#include <gsl/gsl-lite.hpp> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif + +// +// Derived from https://github.com/dln/medida/blob/master/src/medida/stats/ewma.cc +// + +namespace zen::metrics { + +static constinit int kTickIntervalInSeconds = 5; +static constinit double kSecondsPerMinute = 60.0; +static constinit int kOneMinute = 1; +static constinit int kFiveMinutes = 5; +static constinit int kFifteenMinutes = 15; + +static const double kM1_ALPHA = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kOneMinute); +static const double kM5_ALPHA = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kFiveMinutes); +static const double kM15_ALPHA = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kFifteenMinutes); + +static const uint64_t CountPerTick = GetHifreqTimerFrequencySafe() * kTickIntervalInSeconds; +static const uint64_t CountPerSecond = GetHifreqTimerFrequencySafe(); + +////////////////////////////////////////////////////////////////////////// + +void +RawEWMA::Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate) +{ + const double InstantRate = double(Count) / Interval; + + if (IsInitialUpdate) + { + m_Rate.store(InstantRate, std::memory_order_release); + } + else + { + double Delta = Alpha * (InstantRate - m_Rate); + +#if defined(__cpp_lib_atomic_float) + m_Rate.fetch_add(Delta); +#else + double Value = m_Rate.load(std::memory_order_acquire); + double Next; + do + { + Next = Value + Delta; + } while (!m_Rate.compare_exchange_weak(Value, Next, std::memory_order_relaxed)); +#endif + } +} + +double +RawEWMA::Rate() const +{ + return m_Rate.load(std::memory_order_relaxed) * CountPerSecond; +} + +////////////////////////////////////////////////////////////////////////// + +Meter::Meter() : m_StartTick{GetHifreqTimerValue()}, m_LastTick(m_StartTick.load()) +{ +} + +Meter::~Meter() +{ +} + +void +Meter::TickIfNecessary() +{ + uint64_t OldTick = m_LastTick.load(); + const uint64_t NewTick = GetHifreqTimerValue(); + const uint64_t Age = NewTick - OldTick; + + if (Age > CountPerTick) + { + // Ensure only one thread at a time updates the time. This + // works because our tick interval should be sufficiently + // long to ensure two threads don't end up inside this block + + if (m_LastTick.compare_exchange_strong(OldTick, NewTick)) + { + m_Remainder.fetch_add(Age); + + do + { + int64_t Remain = m_Remainder.load(std::memory_order_relaxed); + + if (Remain < 0) + { + return; + } + + if (m_Remainder.compare_exchange_strong(Remain, Remain - CountPerTick)) + { + Tick(); + } + } while (true); + } + } +} + +void +Meter::Tick() +{ + const uint64_t PendingCount = m_PendingCount.exchange(0); + const bool IsFirstTick = m_IsFirstTick; + + if (IsFirstTick) + { + m_IsFirstTick = false; + } + + m_RateM1.Tick(kM1_ALPHA, CountPerTick, PendingCount, IsFirstTick); + m_RateM5.Tick(kM5_ALPHA, CountPerTick, PendingCount, IsFirstTick); + m_RateM15.Tick(kM15_ALPHA, CountPerTick, PendingCount, IsFirstTick); +} + +double +Meter::Rate1() +{ + TickIfNecessary(); + + return m_RateM1.Rate(); +} + +double +Meter::Rate5() +{ + TickIfNecessary(); + + return m_RateM5.Rate(); +} + +double +Meter::Rate15() +{ + TickIfNecessary(); + + return m_RateM15.Rate(); +} + +double +Meter::MeanRate() const +{ + const uint64_t Count = m_TotalCount.load(std::memory_order_relaxed); + + if (Count == 0) + { + return 0.0; + } + + const uint64_t Elapsed = GetHifreqTimerValue() - m_StartTick; + + return (double(Count) * GetHifreqTimerFrequency()) / Elapsed; +} + +void +Meter::Mark(uint64_t Count) +{ + TickIfNecessary(); + + m_TotalCount.fetch_add(Count); + m_PendingCount.fetch_add(Count); +} + +////////////////////////////////////////////////////////////////////////// + +// TODO: should consider a cheaper RNG here, this will run for every thread +// that gets created + +thread_local std::mt19937_64 ThreadLocalRng; + +UniformSample::UniformSample(uint32_t ReservoirSize) : m_Values(ReservoirSize) +{ +} + +UniformSample::~UniformSample() +{ +} + +void +UniformSample::Clear() +{ + for (auto& Value : m_Values) + { + Value.store(0); + } + m_SampleCounter = 0; +} + +uint32_t +UniformSample::Size() const +{ + return gsl::narrow_cast<uint32_t>(Min(m_SampleCounter.load(), m_Values.size())); +} + +void +UniformSample::Update(int64_t Value) +{ + const uint64_t Count = m_SampleCounter++; + const uint64_t Size = m_Values.size(); + + if (Count < Size) + { + m_Values[Count] = Value; + } + else + { + // Randomly choose an old entry to potentially replace (the probability + // of replacing an entry diminishes with time) + + std::uniform_int_distribution<uint64_t> UniformDist(0, Count); + uint64_t SampleIndex = UniformDist(ThreadLocalRng); + + if (SampleIndex < Size) + { + m_Values[SampleIndex].store(Value, std::memory_order_release); + } + } +} + +SampleSnapshot +UniformSample::Snapshot() const +{ + uint64_t ValuesSize = Size(); + std::vector<double> Values(ValuesSize); + + for (int i = 0, n = int(ValuesSize); i < n; ++i) + { + Values[i] = double(m_Values[i]); + } + + return SampleSnapshot(std::move(Values)); +} + +////////////////////////////////////////////////////////////////////////// + +Histogram::Histogram(int32_t SampleCount) : m_Sample(SampleCount) +{ +} + +Histogram::~Histogram() +{ +} + +void +Histogram::Clear() +{ + m_Min = m_Max = m_Sum = m_Count = 0; + m_Sample.Clear(); +} + +void +Histogram::Update(int64_t Value) +{ + m_Sample.Update(Value); + + if (m_Count == 0) + { + m_Min = m_Max = Value; + } + else + { + int64_t CurrentMax = m_Max.load(std::memory_order_relaxed); + + while ((CurrentMax < Value) && !m_Max.compare_exchange_weak(CurrentMax, Value)) + { + } + + int64_t CurrentMin = m_Min.load(std::memory_order_relaxed); + + while ((CurrentMin > Value) && !m_Min.compare_exchange_weak(CurrentMin, Value)) + { + } + } + + m_Sum += Value; + ++m_Count; +} + +int64_t +Histogram::Max() const +{ + return m_Max.load(std::memory_order_relaxed); +} + +int64_t +Histogram::Min() const +{ + return m_Min.load(std::memory_order_relaxed); +} + +double +Histogram::Mean() const +{ + if (m_Count) + { + return double(m_Sum.load(std::memory_order_relaxed)) / m_Count; + } + else + { + return 0.0; + } +} + +uint64_t +Histogram::Count() const +{ + return m_Count.load(std::memory_order_relaxed); +} + +////////////////////////////////////////////////////////////////////////// + +SampleSnapshot::SampleSnapshot(std::vector<double>&& Values) : m_Values(std::move(Values)) +{ + std::sort(begin(m_Values), end(m_Values)); +} + +SampleSnapshot::~SampleSnapshot() +{ +} + +double +SampleSnapshot::GetQuantileValue(double Quantile) +{ + ZEN_ASSERT((Quantile >= 0.0) && (Quantile <= 1.0)); + + if (m_Values.empty()) + { + return 0.0; + } + + const double Pos = Quantile * (m_Values.size() + 1); + + if (Pos < 1) + { + return m_Values.front(); + } + + if (Pos >= m_Values.size()) + { + return m_Values.back(); + } + + const int32_t Index = (int32_t)Pos; + const double Lower = m_Values[Index - 1]; + const double Upper = m_Values[Index]; + + // Lerp + return Lower + (Pos - std::floor(Pos)) * (Upper - Lower); +} + +const std::vector<double>& +SampleSnapshot::GetValues() const +{ + return m_Values; +} + +////////////////////////////////////////////////////////////////////////// + +OperationTiming::OperationTiming(int32_t SampleCount) : m_Histogram{SampleCount} +{ +} + +OperationTiming::~OperationTiming() +{ +} + +void +OperationTiming::Update(int64_t Duration) +{ + m_Meter.Mark(1); + m_Histogram.Update(Duration); +} + +int64_t +OperationTiming::Max() const +{ + return m_Histogram.Max(); +} + +int64_t +OperationTiming::Min() const +{ + return m_Histogram.Min(); +} + +double +OperationTiming::Mean() const +{ + return m_Histogram.Mean(); +} + +uint64_t +OperationTiming::Count() const +{ + return m_Meter.Count(); +} + +OperationTiming::Scope::Scope(OperationTiming& Outer) : m_Outer(Outer), m_StartTick(GetHifreqTimerValue()) +{ +} + +OperationTiming::Scope::~Scope() +{ + Stop(); +} + +void +OperationTiming::Scope::Stop() +{ + if (m_StartTick != 0) + { + m_Outer.Update(GetHifreqTimerValue() - m_StartTick); + m_StartTick = 0; + } +} + +void +OperationTiming::Scope::Cancel() +{ + m_StartTick = 0; +} + +////////////////////////////////////////////////////////////////////////// + +RequestStats::RequestStats(int32_t SampleCount) : m_RequestTimeHistogram{SampleCount}, m_BytesHistogram{SampleCount} +{ +} + +RequestStats::~RequestStats() +{ +} + +void +RequestStats::Update(int64_t Duration, int64_t Bytes) +{ + m_RequestMeter.Mark(1); + m_RequestTimeHistogram.Update(Duration); + + m_BytesMeter.Mark(Bytes); + m_BytesHistogram.Update(Bytes); +} + +uint64_t +RequestStats::Count() const +{ + return m_RequestMeter.Count(); +} + +////////////////////////////////////////////////////////////////////////// + +void +EmitSnapshot(Meter& Stat, CbObjectWriter& Cbo) +{ + Cbo << "count" << Stat.Count(); + Cbo << "rate_mean" << Stat.MeanRate(); + Cbo << "rate_1" << Stat.Rate1() << "rate_5" << Stat.Rate5() << "rate_15" << Stat.Rate15(); +} + +void +RequestStats::EmitSnapshot(std::string_view Tag, CbObjectWriter& Cbo) +{ + Cbo.BeginObject(Tag); + + Cbo.BeginObject("requests"); + metrics::EmitSnapshot(m_RequestMeter, Cbo); + metrics::EmitSnapshot(m_RequestTimeHistogram, Cbo, GetHifreqTimerToSeconds()); + Cbo.EndObject(); + + Cbo.BeginObject("bytes"); + metrics::EmitSnapshot(m_BytesMeter, Cbo); + metrics::EmitSnapshot(m_BytesHistogram, Cbo, 1.0); + Cbo.EndObject(); + + Cbo.EndObject(); +} + +void +EmitSnapshot(std::string_view Tag, OperationTiming& Stat, CbObjectWriter& Cbo) +{ + Cbo.BeginObject(Tag); + + SampleSnapshot Snap = Stat.Snapshot(); + + Cbo << "count" << Stat.Count(); + Cbo << "rate_mean" << Stat.MeanRate(); + Cbo << "rate_1" << Stat.Rate1() << "rate_5" << Stat.Rate5() << "rate_15" << Stat.Rate15(); + + const double ToSeconds = GetHifreqTimerToSeconds(); + + Cbo << "t_avg" << Stat.Mean() * ToSeconds; + Cbo << "t_min" << Stat.Min() * ToSeconds << "t_max" << Stat.Max() * ToSeconds; + Cbo << "t_p75" << Snap.Get75Percentile() * ToSeconds << "t_p95" << Snap.Get95Percentile() * ToSeconds << "t_p99" + << Snap.Get99Percentile() * ToSeconds << "t_p999" << Snap.Get999Percentile() * ToSeconds; + + Cbo.EndObject(); +} + +void +EmitSnapshot(std::string_view Tag, const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor) +{ + Cbo.BeginObject(Tag); + EmitSnapshot(Stat, Cbo, ConversionFactor); + Cbo.EndObject(); +} + +void +EmitSnapshot(const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor) +{ + SampleSnapshot Snap = Stat.Snapshot(); + + Cbo << "count" << Stat.Count() * ConversionFactor << "avg" << Stat.Mean() * ConversionFactor; + Cbo << "min" << Stat.Min() * ConversionFactor << "max" << Stat.Max() * ConversionFactor; + Cbo << "p75" << Snap.Get75Percentile() * ConversionFactor << "p95" << Snap.Get95Percentile() * ConversionFactor << "p99" + << Snap.Get99Percentile() * ConversionFactor << "p999" << Snap.Get999Percentile() * ConversionFactor; +} + +void +EmitSnapshot(std::string_view Tag, Meter& Stat, CbObjectWriter& Cbo) +{ + Cbo.BeginObject(Tag); + + Cbo << "count" << Stat.Count() << "rate_mean" << Stat.MeanRate(); + Cbo << "rate_1" << Stat.Rate1() << "rate_5" << Stat.Rate5() << "rate_15" << Stat.Rate15(); + + Cbo.EndObject(); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +TEST_CASE("Core.Stats.Histogram") +{ + Histogram Histo{258}; + + SampleSnapshot Snap1 = Histo.Snapshot(); + CHECK_EQ(Snap1.Size(), 0); + CHECK_EQ(Snap1.GetMedian(), 0); + + Histo.Update(1); + CHECK_EQ(Histo.Min(), 1); + CHECK_EQ(Histo.Max(), 1); + + SampleSnapshot Snap2 = Histo.Snapshot(); + CHECK_EQ(Snap2.Size(), 1); + + Histo.Update(2); + CHECK_EQ(Histo.Min(), 1); + CHECK_EQ(Histo.Max(), 2); + + SampleSnapshot Snap3 = Histo.Snapshot(); + CHECK_EQ(Snap3.Size(), 2); + + Histo.Update(-2); + CHECK_EQ(Histo.Min(), -2); + CHECK_EQ(Histo.Max(), 2); + CHECK_EQ(Histo.Mean(), 1 / 3.0); + + SampleSnapshot Snap4 = Histo.Snapshot(); + CHECK_EQ(Snap4.Size(), 3); + CHECK_EQ(Snap4.GetMedian(), 1); + CHECK_EQ(Snap4.Get999Percentile(), 2); + CHECK_EQ(Snap4.GetQuantileValue(0), -2); +} + +TEST_CASE("Core.Stats.UniformSample") +{ + UniformSample Sample1{100}; + + for (int i = 0; i < 100; ++i) + { + for (int j = 1; j <= 100; ++j) + { + Sample1.Update(j); + } + } + + int64_t Sum = 0; + int64_t Count = 0; + + Sample1.IterateValues([&](int64_t Value) { + ++Count; + Sum += Value; + }); + + double Average = double(Sum) / Count; + + CHECK(fabs(Average - 50) < 10); // What's the right test here? The result could vary massively and still be technically correct +} + +TEST_CASE("Core.Stats.EWMA") +{ + SUBCASE("Simple_1") + { + RawEWMA Ewma1; + Ewma1.Tick(kM1_ALPHA, CountPerSecond, 5, true); + + CHECK(fabs(Ewma1.Rate() - 5) < 0.1); + + for (int i = 0; i < 60; ++i) + { + Ewma1.Tick(kM1_ALPHA, CountPerSecond, 10, false); + } + + CHECK(fabs(Ewma1.Rate() - 10) < 0.1); + + for (int i = 0; i < 60; ++i) + { + Ewma1.Tick(kM1_ALPHA, CountPerSecond, 20, false); + } + + CHECK(fabs(Ewma1.Rate() - 20) < 0.1); + } + + SUBCASE("Simple_10") + { + RawEWMA Ewma1; + RawEWMA Ewma5; + RawEWMA Ewma15; + Ewma1.Tick(kM1_ALPHA, CountPerSecond, 5, true); + Ewma5.Tick(kM5_ALPHA, CountPerSecond, 5, true); + Ewma15.Tick(kM15_ALPHA, CountPerSecond, 5, true); + + CHECK(fabs(Ewma1.Rate() - 5) < 0.1); + CHECK(fabs(Ewma5.Rate() - 5) < 0.1); + CHECK(fabs(Ewma15.Rate() - 5) < 0.1); + + auto Tick1 = [&Ewma1](auto Value) { Ewma1.Tick(kM1_ALPHA, CountPerSecond, Value, false); }; + auto Tick5 = [&Ewma5](auto Value) { Ewma5.Tick(kM5_ALPHA, CountPerSecond, Value, false); }; + auto Tick15 = [&Ewma15](auto Value) { Ewma15.Tick(kM15_ALPHA, CountPerSecond, Value, false); }; + + for (int i = 0; i < 60; ++i) + { + Tick1(10); + Tick5(10); + Tick15(10); + } + + CHECK(fabs(Ewma1.Rate() - 10) < 0.1); + + for (int i = 0; i < 5 * 60; ++i) + { + Tick1(20); + Tick5(20); + Tick15(20); + } + + CHECK(fabs(Ewma1.Rate() - 20) < 0.1); + CHECK(fabs(Ewma5.Rate() - 20) < 0.1); + + for (int i = 0; i < 16 * 60; ++i) + { + Tick1(100); + Tick5(100); + Tick15(100); + } + + CHECK(fabs(Ewma1.Rate() - 100) < 0.1); + CHECK(fabs(Ewma5.Rate() - 100) < 0.1); + CHECK(fabs(Ewma15.Rate() - 100) < 0.5); + } +} + +# if 0 // This is not really a unit test, but mildly useful to exercise some code +TEST_CASE("Meter") +{ + Meter Meter1; + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + Meter1.Mark(1); + Sleep(1000); + [[maybe_unused]] double Rate = Meter1.MeanRate(); +} +# endif +} + +namespace zen { + +void +stats_forcelink() +{ +} + +#endif + +} // namespace zen::metrics diff --git a/src/zencore/stream.cpp b/src/zencore/stream.cpp new file mode 100644 index 000000000..3402e51be --- /dev/null +++ b/src/zencore/stream.cpp @@ -0,0 +1,79 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <stdarg.h> +#include <zencore/memory.h> +#include <zencore/stream.h> +#include <zencore/testing.h> + +#include <algorithm> +#include <stdexcept> + +namespace zen { + +void +BinaryWriter::Write(std::initializer_list<const MemoryView> Buffers) +{ + size_t TotalByteCount = 0; + for (const MemoryView& View : Buffers) + { + TotalByteCount += View.GetSize(); + } + const size_t NeedEnd = m_Offset + TotalByteCount; + if (NeedEnd > m_Buffer.size()) + { + m_Buffer.resize(NeedEnd); + } + for (const MemoryView& View : Buffers) + { + memcpy(m_Buffer.data() + m_Offset, View.GetData(), View.GetSize()); + m_Offset += View.GetSize(); + } +} + +void +BinaryWriter::Write(const void* data, size_t ByteCount, uint64_t Offset) +{ + const size_t NeedEnd = Offset + ByteCount; + + if (NeedEnd > m_Buffer.size()) + { + m_Buffer.resize(NeedEnd); + } + + memcpy(m_Buffer.data() + Offset, data, ByteCount); +} + +void +BinaryWriter::Reset() +{ + m_Buffer.clear(); + m_Offset = 0; +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +TEST_CASE("binary.writer.span") +{ + BinaryWriter Writer; + const MemoryView View1("apa", 3); + const MemoryView View2(" ", 1); + const MemoryView View3("banan", 5); + Writer.Write({View1, View2, View3}); + MemoryView Result = Writer.GetView(); + CHECK(Result.GetSize() == 9); + CHECK(memcmp(Result.GetData(), "apa banan", 9) == 0); +} + +void +stream_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp new file mode 100644 index 000000000..ad6ee78fc --- /dev/null +++ b/src/zencore/string.cpp @@ -0,0 +1,1004 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/memory.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <inttypes.h> +#include <math.h> +#include <stdio.h> +#include <exception> +#include <ostream> +#include <stdexcept> + +#include <utf8.h> + +template<typename u16bit_iterator> +void +utf16to8_impl(u16bit_iterator StartIt, u16bit_iterator EndIt, ::zen::StringBuilderBase& OutString) +{ + while (StartIt != EndIt) + { + uint32_t cp = utf8::internal::mask16(*StartIt++); + // Take care of surrogate pairs first + if (utf8::internal::is_lead_surrogate(cp)) + { + uint32_t trail_surrogate = utf8::internal::mask16(*StartIt++); + cp = (cp << 10) + trail_surrogate + utf8::internal::SURROGATE_OFFSET; + } + OutString.AppendCodepoint(cp); + } +} + +template<typename u32bit_iterator> +void +utf32to8_impl(u32bit_iterator StartIt, u32bit_iterator EndIt, ::zen::StringBuilderBase& OutString) +{ + for (; StartIt != EndIt; ++StartIt) + { + wchar_t cp = *StartIt; + OutString.AppendCodepoint(cp); + } +} + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +bool +ToString(std::span<char> Buffer, uint64_t Num) +{ + snprintf(Buffer.data(), Buffer.size(), "%" PRIu64, Num); + + return true; +} +bool +ToString(std::span<char> Buffer, int64_t Num) +{ + snprintf(Buffer.data(), Buffer.size(), "%" PRId64, Num); + + return true; +} + +////////////////////////////////////////////////////////////////////////// + +const char* +FilepathFindExtension(const std::string_view& Path, const char* ExtensionToMatch) +{ + const size_t PathLen = Path.size(); + + if (ExtensionToMatch) + { + size_t ExtLen = strlen(ExtensionToMatch); + + if (ExtLen > PathLen) + return nullptr; + + const char* PathExtension = Path.data() + PathLen - ExtLen; + + if (StringEquals(PathExtension, ExtensionToMatch)) + return PathExtension; + + return nullptr; + } + + if (PathLen == 0) + return nullptr; + + // Look for extension introducer ('.') + + for (int64_t i = PathLen - 1; i >= 0; --i) + { + if (Path[i] == '.') + return Path.data() + i; + } + + return nullptr; +} + +////////////////////////////////////////////////////////////////////////// + +void +Utf8ToWide(const char8_t* Str8, WideStringBuilderBase& OutString) +{ + Utf8ToWide(std::u8string_view(Str8), OutString); +} + +void +Utf8ToWide(const std::string_view& Str8, WideStringBuilderBase& OutString) +{ + Utf8ToWide(std::u8string_view{reinterpret_cast<const char8_t*>(Str8.data()), Str8.size()}, OutString); +} + +std::wstring +Utf8ToWide(const std::string_view& Wstr) +{ + ExtendableWideStringBuilder<128> String; + Utf8ToWide(Wstr, String); + + return String.c_str(); +} + +void +Utf8ToWide(const std::u8string_view& Str8, WideStringBuilderBase& OutString) +{ + const char* str = (const char*)Str8.data(); + const size_t strLen = Str8.size(); + + const char* endStr = str + strLen; + size_t ByteCount = 0; + size_t CurrentOutChar = 0; + + for (; str != endStr; ++str) + { + unsigned char Data = static_cast<unsigned char>(*str); + + if (!(Data & 0x80)) + { + // ASCII + OutString.Append(wchar_t(Data)); + continue; + } + else if (!ByteCount) + { + // Start of multi-byte sequence. Figure out how + // many bytes we're going to consume + + size_t Count = 0; + + for (size_t Temp = Data; Temp & 0x80; Temp <<= 1) + ++Count; + + ByteCount = Count - 1; + CurrentOutChar = Data & (0xff >> (Count + 1)); + } + else + { + --ByteCount; + + if ((Data & 0xc0) != 0x80) + { + break; + } + + CurrentOutChar = (CurrentOutChar << 6) | (Data & 0x3f); + + if (!ByteCount) + { + OutString.Append(wchar_t(CurrentOutChar)); + CurrentOutChar = 0; + } + } + } +} + +void +WideToUtf8(const wchar_t* Wstr, StringBuilderBase& OutString) +{ + WideToUtf8(std::wstring_view{Wstr}, OutString); +} + +void +WideToUtf8(const std::wstring_view& Wstr, StringBuilderBase& OutString) +{ +#if ZEN_SIZEOF_WCHAR_T == 2 + utf16to8_impl(begin(Wstr), end(Wstr), OutString); +#else + utf32to8_impl(begin(Wstr), end(Wstr), OutString); +#endif +} + +std::string +WideToUtf8(const wchar_t* Wstr) +{ + ExtendableStringBuilder<128> String; + WideToUtf8(std::wstring_view{Wstr}, String); + + return String.c_str(); +} + +std::string +WideToUtf8(const std::wstring_view Wstr) +{ + ExtendableStringBuilder<128> String; + WideToUtf8(std::wstring_view{Wstr.data(), Wstr.size()}, String); + + return String.c_str(); +} + +////////////////////////////////////////////////////////////////////////// + +enum NicenumFormat +{ + kNicenum1024 = 0, // Print kilo, mega, tera, peta, exa.. + kNicenumBytes = 1, // Print single bytes ("13B"), kilo, mega, tera... + kNicenumTime = 2, // Print nanosecs, microsecs, millisecs, seconds... + kNicenumRaw = 3, // Print the raw number without any formatting + kNicenumRawTime = 4 // Same as RAW, but print dashes ('-') for zero. +}; + +namespace { + static const char* UnitStrings[3][7] = { + /* kNicenum1024 */ {"", "K", "M", "G", "T", "P", "E"}, + /* kNicenumBytes */ {"B", "K", "M", "G", "T", "P", "E"}, + /* kNicenumTime */ {"ns", "us", "ms", "s", "?", "?", "?"}}; + + static const int UnitsLen[] = { + /* kNicenum1024 */ 6, + /* kNicenumBytes */ 6, + /* kNicenumTime */ 3}; + + static const uint64_t KiloUnit[] = { + /* kNicenum1024 */ 1024, + /* kNicenumBytes */ 1024, + /* kNicenumTime */ 1000}; +} // namespace + +/* + * Convert a number to an appropriately human-readable output. + */ +int +NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format) +{ + switch (Format) + { + case kNicenumRaw: + return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64, (uint64_t)Num); + + case kNicenumRawTime: + if (Num > 0) + { + return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64, (uint64_t)Num); + } + else + { + return snprintf(Buffer.data(), Buffer.size(), "%s", "-"); + } + break; + + case kNicenum1024: + case kNicenumBytes: + case kNicenumTime: + default: + break; + } + + // Bring into range and select unit + + int Index = 0; + uint64_t n = Num; + + { + const uint64_t Unit = KiloUnit[Format]; + const int maxIndex = UnitsLen[Format]; + + while (n >= Unit && Index < maxIndex) + { + n /= Unit; + Index++; + } + } + + const char* u = UnitStrings[Format][Index]; + + if ((Index == 0) || ((Num % (uint64_t)powl((int)KiloUnit[Format], Index)) == 0)) + { + /* + * If this is an even multiple of the base, always display + * without any decimal precision. + */ + return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "%s", (uint64_t)n, u); + } + else + { + /* + * We want to choose a precision that reflects the best choice + * for fitting in 5 characters. This can get rather tricky when + * we have numbers that are very close to an order of magnitude. + * For example, when displaying 10239 (which is really 9.999K), + * we want only a single place of precision for 10.0K. We could + * develop some complex heuristics for this, but it's much + * easier just to try each combination in turn. + */ + + int StrLen = 0; + + for (int i = 2; i >= 0; i--) + { + double Value = (double)Num / (uint64_t)powl((int)KiloUnit[Format], Index); + + /* + * Don't print floating point values for time. Note, + * we use floor() instead of round() here, since + * round can result in undesirable results. For + * example, if "num" is in the range of + * 999500-999999, it will print out "1000us". This + * doesn't happen if we use floor(). + */ + if (Format == kNicenumTime) + { + StrLen = snprintf(Buffer.data(), Buffer.size(), "%d%s", (unsigned int)floor(Value), u); + + if (StrLen <= 5) + break; + } + else + { + StrLen = snprintf(Buffer.data(), Buffer.size(), "%.*f%s", i, Value, u); + + if (StrLen <= 5) + break; + } + } + + return StrLen; + } +} + +size_t +NiceNumToBuffer(uint64_t Num, std::span<char> Buffer) +{ + return NiceNumGeneral(Num, Buffer, kNicenum1024); +} + +size_t +NiceBytesToBuffer(uint64_t Num, std::span<char> Buffer) +{ + return NiceNumGeneral(Num, Buffer, kNicenumBytes); +} + +size_t +NiceByteRateToBuffer(uint64_t Num, uint64_t ElapsedMs, std::span<char> Buffer) +{ + size_t n = 0; + + if (ElapsedMs) + { + n = NiceNumGeneral(Num * 1000 / ElapsedMs, Buffer, kNicenumBytes); + } + else + { + Buffer[n++] = '0'; + Buffer[n++] = 'B'; + } + + Buffer[n++] = '/'; + Buffer[n++] = 's'; + Buffer[n++] = '\0'; + + return n; +} + +size_t +NiceLatencyNsToBuffer(uint64_t Nanos, std::span<char> Buffer) +{ + return NiceNumGeneral(Nanos, Buffer, kNicenumTime); +} + +size_t +NiceTimeSpanMsToBuffer(uint64_t Millis, std::span<char> Buffer) +{ + if (Millis < 1000) + { + return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "ms", Millis); + } + else if (Millis < 10000) + { + return snprintf(Buffer.data(), Buffer.size(), "%.2fs", Millis / 1000.0); + } + else if (Millis < 60000) + { + return snprintf(Buffer.data(), Buffer.size(), "%.1fs", Millis / 1000.0); + } + else if (Millis < 60 * 60000) + { + return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "m%02" PRIu64 "s", Millis / 60000, (Millis / 1000) % 60); + } + else + { + return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "h%02" PRIu64 "m", Millis / 3600000, (Millis / 60000) % 60); + } +} + +////////////////////////////////////////////////////////////////////////// + +template<typename C> +StringBuilderImpl<C>::~StringBuilderImpl() +{ + if (m_IsDynamic) + { + FreeBuffer(m_Base, m_End - m_Base); + } +} + +template<typename C> +void +StringBuilderImpl<C>::Extend(size_t extraCapacity) +{ + if (!m_IsExtendable) + { + Fail("exceeded capacity"); + } + + const size_t oldCapacity = m_End - m_Base; + const size_t newCapacity = NextPow2(oldCapacity + extraCapacity); + + C* newBase = (C*)AllocBuffer(newCapacity); + + size_t pos = m_CurPos - m_Base; + memcpy(newBase, m_Base, pos * sizeof(C)); + + if (m_IsDynamic) + { + FreeBuffer(m_Base, oldCapacity); + } + + m_Base = newBase; + m_CurPos = newBase + pos; + m_End = newBase + newCapacity; + m_IsDynamic = true; +} + +template<typename C> +void* +StringBuilderImpl<C>::AllocBuffer(size_t byteCount) +{ + return Memory::Alloc(byteCount * sizeof(C)); +} + +template<typename C> +void +StringBuilderImpl<C>::FreeBuffer(void* buffer, size_t byteCount) +{ + ZEN_UNUSED(byteCount); + + Memory::Free(buffer); +} + +template<typename C> +[[noreturn]] void +StringBuilderImpl<C>::Fail(const char* reason) +{ + throw std::runtime_error(reason); +} + +// Instantiate templates once + +template class StringBuilderImpl<char>; +template class StringBuilderImpl<wchar_t>; + +////////////////////////////////////////////////////////////////////////// +// +// Unit tests +// + +#if ZEN_WITH_TESTS + +TEST_CASE("niceNum") +{ + char Buffer[16]; + + SUBCASE("raw") + { + NiceNumGeneral(1, Buffer, kNicenumRaw); + CHECK(StringEquals(Buffer, "1")); + + NiceNumGeneral(10, Buffer, kNicenumRaw); + CHECK(StringEquals(Buffer, "10")); + + NiceNumGeneral(100, Buffer, kNicenumRaw); + CHECK(StringEquals(Buffer, "100")); + + NiceNumGeneral(1000, Buffer, kNicenumRaw); + CHECK(StringEquals(Buffer, "1000")); + + NiceNumGeneral(10000, Buffer, kNicenumRaw); + CHECK(StringEquals(Buffer, "10000")); + + NiceNumGeneral(100000, Buffer, kNicenumRaw); + CHECK(StringEquals(Buffer, "100000")); + } + + SUBCASE("1024") + { + NiceNumGeneral(1, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1")); + + NiceNumGeneral(10, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "10")); + + NiceNumGeneral(100, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "100")); + + NiceNumGeneral(1000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1000")); + + NiceNumGeneral(10000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "9.77K")); + + NiceNumGeneral(100000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "97.7K")); + + NiceNumGeneral(1000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "977K")); + + NiceNumGeneral(10000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "9.54M")); + + NiceNumGeneral(100000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "95.4M")); + + NiceNumGeneral(1000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "954M")); + + NiceNumGeneral(10000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "9.31G")); + + NiceNumGeneral(100000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "93.1G")); + + NiceNumGeneral(1000000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "931G")); + + NiceNumGeneral(10000000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "9.09T")); + + NiceNumGeneral(100000000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "90.9T")); + + NiceNumGeneral(1000000000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "909T")); + + NiceNumGeneral(10000000000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "8.88P")); + + NiceNumGeneral(100000000000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "88.8P")); + + NiceNumGeneral(1000000000000000000, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "888P")); + + NiceNumGeneral(10000000000000000000ull, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "8.67E")); + + // pow2 + + NiceNumGeneral(0, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "0")); + + NiceNumGeneral(1, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1")); + + NiceNumGeneral(1024, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1K")); + + NiceNumGeneral(1024 * 1024, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1M")); + + NiceNumGeneral(1024 * 1024 * 1024, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1G")); + + NiceNumGeneral(1024llu * 1024 * 1024 * 1024, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1T")); + + NiceNumGeneral(1024llu * 1024 * 1024 * 1024 * 1024, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1P")); + + NiceNumGeneral(1024llu * 1024 * 1024 * 1024 * 1024 * 1024, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1E")); + + // pow2-1 + + NiceNumGeneral(1023, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "1023")); + + NiceNumGeneral(2047, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "2.00K")); + + NiceNumGeneral(9 * 1024 - 1, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "9.00K")); + + NiceNumGeneral(10 * 1024 - 1, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "10.0K")); + + NiceNumGeneral(10 * 1024 - 5, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "10.0K")); + + NiceNumGeneral(10 * 1024 - 6, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "9.99K")); + + NiceNumGeneral(10 * 1024 - 10, Buffer, kNicenum1024); + CHECK(StringEquals(Buffer, "9.99K")); + } + + SUBCASE("time") + { + NiceNumGeneral(1, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "1ns")); + + NiceNumGeneral(100, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "100ns")); + + NiceNumGeneral(1000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "1us")); + + NiceNumGeneral(10000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "10us")); + + NiceNumGeneral(100000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "100us")); + + NiceNumGeneral(1000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "1ms")); + + NiceNumGeneral(10000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "10ms")); + + NiceNumGeneral(100000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "100ms")); + + NiceNumGeneral(1000000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "1s")); + + NiceNumGeneral(10000000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "10s")); + + NiceNumGeneral(100000000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "100s")); + + NiceNumGeneral(1000000000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "1000s")); + + NiceNumGeneral(10000000000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "10000s")); + + NiceNumGeneral(100000000000000, Buffer, kNicenumTime); + CHECK(StringEquals(Buffer, "100000s")); + } + + SUBCASE("bytes") + { + NiceNumGeneral(1, Buffer, kNicenumBytes); + CHECK(StringEquals(Buffer, "1B")); + + NiceNumGeneral(10, Buffer, kNicenumBytes); + CHECK(StringEquals(Buffer, "10B")); + + NiceNumGeneral(100, Buffer, kNicenumBytes); + CHECK(StringEquals(Buffer, "100B")); + + NiceNumGeneral(1000, Buffer, kNicenumBytes); + CHECK(StringEquals(Buffer, "1000B")); + + NiceNumGeneral(10000, Buffer, kNicenumBytes); + CHECK(StringEquals(Buffer, "9.77K")); + } + + SUBCASE("byteRate") + { + NiceByteRateToBuffer(1, 1, Buffer); + CHECK(StringEquals(Buffer, "1000B/s")); + + NiceByteRateToBuffer(1000, 1000, Buffer); + CHECK(StringEquals(Buffer, "1000B/s")); + + NiceByteRateToBuffer(1024, 1, Buffer); + CHECK(StringEquals(Buffer, "1000K/s")); + + NiceByteRateToBuffer(1024, 1000, Buffer); + CHECK(StringEquals(Buffer, "1K/s")); + } + + SUBCASE("timespan") + { + NiceTimeSpanMsToBuffer(1, Buffer); + CHECK(StringEquals(Buffer, "1ms")); + + NiceTimeSpanMsToBuffer(900, Buffer); + CHECK(StringEquals(Buffer, "900ms")); + + NiceTimeSpanMsToBuffer(1000, Buffer); + CHECK(StringEquals(Buffer, "1.00s")); + + NiceTimeSpanMsToBuffer(1900, Buffer); + CHECK(StringEquals(Buffer, "1.90s")); + + NiceTimeSpanMsToBuffer(19000, Buffer); + CHECK(StringEquals(Buffer, "19.0s")); + + NiceTimeSpanMsToBuffer(60000, Buffer); + CHECK(StringEquals(Buffer, "1m00s")); + + NiceTimeSpanMsToBuffer(600000, Buffer); + CHECK(StringEquals(Buffer, "10m00s")); + + NiceTimeSpanMsToBuffer(3600000, Buffer); + CHECK(StringEquals(Buffer, "1h00m")); + + NiceTimeSpanMsToBuffer(36000000, Buffer); + CHECK(StringEquals(Buffer, "10h00m")); + + NiceTimeSpanMsToBuffer(360000000, Buffer); + CHECK(StringEquals(Buffer, "100h00m")); + } +} + +void +string_forcelink() +{ +} + +TEST_CASE("StringBuilder") +{ + StringBuilder<64> sb; + + SUBCASE("Empty init") + { + const char* str = sb.c_str(); + + CHECK(StringLength(str) == 0); + } + + SUBCASE("Append single character") + { + sb.Append('a'); + + const char* str = sb.c_str(); + CHECK(StringLength(str) == 1); + CHECK(str[0] == 'a'); + + sb.Append('b'); + str = sb.c_str(); + CHECK(StringLength(str) == 2); + CHECK(str[0] == 'a'); + CHECK(str[1] == 'b'); + } + + SUBCASE("Append string") + { + sb.Append("a"); + + const char* str = sb.c_str(); + CHECK(StringLength(str) == 1); + CHECK(str[0] == 'a'); + + sb.Append("b"); + str = sb.c_str(); + CHECK(StringLength(str) == 2); + CHECK(str[0] == 'a'); + CHECK(str[1] == 'b'); + + sb.Append("cdefghijklmnopqrstuvwxyz"); + CHECK(sb.Size() == 26); + + sb.Append("abcdefghijklmnopqrstuvwxyz"); + CHECK(sb.Size() == 52); + + sb.Append("abcdefghijk"); + CHECK(sb.Size() == 63); + } +} + +TEST_CASE("ExtendableStringBuilder") +{ + ExtendableStringBuilder<16> sb; + + SUBCASE("Empty init") + { + const char* str = sb.c_str(); + + CHECK(StringLength(str) == 0); + } + + SUBCASE("Short append") + { + sb.Append("abcd"); + CHECK(sb.IsDynamic() == false); + } + + SUBCASE("Short+long append") + { + sb.Append("abcd"); + CHECK(sb.IsDynamic() == false); + // This should trigger a dynamic buffer allocation since the required + // capacity exceeds the internal fixed buffer. + sb.Append("abcdefghijklmnopqrstuvwxyz"); + CHECK(sb.IsDynamic() == true); + CHECK(sb.Size() == 30); + CHECK(sb.Size() == StringLength(sb.c_str())); + } +} + +TEST_CASE("WideStringBuilder") +{ + WideStringBuilder<64> sb; + + SUBCASE("Empty init") + { + const wchar_t* str = sb.c_str(); + + CHECK(StringLength(str) == 0); + } + + SUBCASE("Append single character") + { + sb.Append(L'a'); + + const wchar_t* str = sb.c_str(); + CHECK(StringLength(str) == 1); + CHECK(str[0] == L'a'); + + sb.Append(L'b'); + str = sb.c_str(); + CHECK(StringLength(str) == 2); + CHECK(str[0] == L'a'); + CHECK(str[1] == L'b'); + } + + SUBCASE("Append string") + { + sb.Append(L"a"); + + const wchar_t* str = sb.c_str(); + CHECK(StringLength(str) == 1); + CHECK(str[0] == L'a'); + + sb.Append(L"b"); + str = sb.c_str(); + CHECK(StringLength(str) == 2); + CHECK(str[0] == L'a'); + CHECK(str[1] == L'b'); + + sb.Append(L"cdefghijklmnopqrstuvwxyz"); + CHECK(sb.Size() == 26); + + sb.Append(L"abcdefghijklmnopqrstuvwxyz"); + CHECK(sb.Size() == 52); + + sb.Append(L"abcdefghijk"); + CHECK(sb.Size() == 63); + } +} + +TEST_CASE("ExtendableWideStringBuilder") +{ + ExtendableWideStringBuilder<16> sb; + + SUBCASE("Empty init") + { + CHECK(sb.Size() == 0); + + const wchar_t* str = sb.c_str(); + CHECK(StringLength(str) == 0); + } + + SUBCASE("Short append") + { + sb.Append(L"abcd"); + CHECK(sb.IsDynamic() == false); + } + + SUBCASE("Short+long append") + { + sb.Append(L"abcd"); + CHECK(sb.IsDynamic() == false); + // This should trigger a dynamic buffer allocation since the required + // capacity exceeds the internal fixed buffer. + sb.Append(L"abcdefghijklmnopqrstuvwxyz"); + CHECK(sb.IsDynamic() == true); + CHECK(sb.Size() == 30); + CHECK(sb.Size() == StringLength(sb.c_str())); + } +} + +TEST_CASE("utf8") +{ + SUBCASE("utf8towide") + { + // TODO: add more extensive testing here - this covers a very small space + + WideStringBuilder<32> wout; + Utf8ToWide(u8"abcdefghi", wout); + CHECK(StringEquals(L"abcdefghi", wout.c_str())); + + wout.Reset(); + + Utf8ToWide(u8"abc���", wout); + CHECK(StringEquals(L"abc���", wout.c_str())); + } + + SUBCASE("widetoutf8") + { + // TODO: add more extensive testing here - this covers a very small space + + StringBuilder<32> out; + + WideToUtf8(L"abcdefghi", out); + CHECK(StringEquals("abcdefghi", out.c_str())); + + out.Reset(); + + WideToUtf8(L"abc���", out); + CHECK(StringEquals(u8"abc���", out.c_str())); + } +} + +TEST_CASE("filepath") +{ + CHECK(FilepathFindExtension("foo\\bar\\baz.txt", ".txt") != nullptr); + CHECK(FilepathFindExtension("foo\\bar\\baz.txt", ".zap") == nullptr); + + CHECK(FilepathFindExtension("foo\\bar\\baz.txt") != nullptr); + CHECK(FilepathFindExtension("foo\\bar\\baz.txt") == std::string_view(".txt")); + + CHECK(FilepathFindExtension(".txt") == std::string_view(".txt")); +} + +TEST_CASE("string") +{ + using namespace std::literals; + + SUBCASE("hash_djb2") + { + CHECK(HashStringAsLowerDjb2("AbcdZ"sv) == HashStringDjb2("abcdz"sv)); + CHECK(HashStringAsLowerDjb2("aBCd"sv) == HashStringDjb2("abcd"sv)); + CHECK(HashStringAsLowerDjb2("aBCd"sv) == HashStringDjb2(ToLower("aBCd"sv))); + } + + SUBCASE("tolower") + { + CHECK_EQ(ToLower("te!st"sv), "te!st"sv); + CHECK_EQ(ToLower("TE%St"sv), "te%st"sv); + } + + SUBCASE("StrCaseCompare") + { + CHECK(StrCaseCompare("foo", "FoO") == 0); + CHECK(StrCaseCompare("Bar", "bAs") < 0); + CHECK(StrCaseCompare("bAr", "Bas") < 0); + CHECK(StrCaseCompare("BBr", "Bar") > 0); + CHECK(StrCaseCompare("Bbr", "BAr") > 0); + CHECK(StrCaseCompare("foo", "FoO", 3) == 0); + CHECK(StrCaseCompare("Bar", "bAs", 3) < 0); + CHECK(StrCaseCompare("BBr", "Bar", 2) > 0); + } + + SUBCASE("ForEachStrTok") + { + const auto Tokens = "here,is,my,different,tokens"sv; + int32_t ExpectedTokenCount = 5; + int32_t TokenCount = 0; + StringBuilder<512> Sb; + + TokenCount = ForEachStrTok(Tokens, ',', [&Sb](const std::string_view& Token) { + if (Sb.Size()) + { + Sb << ","; + } + Sb << Token; + return true; + }); + + CHECK(TokenCount == ExpectedTokenCount); + CHECK(Sb.ToString() == Tokens); + + ExpectedTokenCount = 1; + const auto Str = "mosdef"sv; + + Sb.Reset(); + TokenCount = ForEachStrTok(Str, ' ', [&Sb](const std::string_view& Token) { + Sb << Token; + return true; + }); + CHECK(Sb.ToString() == Str); + CHECK(TokenCount == ExpectedTokenCount); + + ExpectedTokenCount = 0; + TokenCount = ForEachStrTok(""sv, ',', [](const std::string_view&) { return true; }); + CHECK(TokenCount == ExpectedTokenCount); + } +} + +#endif + +} // namespace zen diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp new file mode 100644 index 000000000..1599e9d1f --- /dev/null +++ b/src/zencore/testing.cpp @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/testing.h" +#include "zencore/logging.h" + +#if ZEN_WITH_TESTS + +namespace zen::testing { + +using namespace std::literals; + +struct TestRunner::Impl +{ + doctest::Context Session; +}; + +TestRunner::TestRunner() +{ + m_Impl = std::make_unique<Impl>(); +} + +TestRunner::~TestRunner() +{ +} + +int +TestRunner::ApplyCommandLine(int argc, char const* const* argv) +{ + m_Impl->Session.applyCommandLine(argc, argv); + + for (int i = 1; i < argc; ++i) + { + if (argv[i] == "--debug"sv) + { + spdlog::set_level(spdlog::level::debug); + } + } + + return 0; +} + +int +TestRunner::Run() +{ + int Rv = 0; + + m_Impl->Session.run(); + + return Rv; +} + +} // namespace zen::testing + +#endif diff --git a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp new file mode 100644 index 000000000..dbc3ab5af --- /dev/null +++ b/src/zencore/testutils.cpp @@ -0,0 +1,42 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zencore/testutils.h" +#include <zencore/session.h> +#include "zencore/string.h" + +#include <atomic> + +namespace zen { + +static std::atomic<int> Sequence{0}; + +std::filesystem::path +CreateTemporaryDirectory() +{ + std::error_code Ec; + + std::filesystem::path DirPath = std::filesystem::temp_directory_path() / GetSessionIdString() / IntNum(++Sequence).c_str(); + std::filesystem::remove_all(DirPath, Ec); + std::filesystem::create_directories(DirPath); + + return DirPath; +} + +ScopedTemporaryDirectory::ScopedTemporaryDirectory() : m_RootPath(CreateTemporaryDirectory()) +{ +} + +ScopedTemporaryDirectory::ScopedTemporaryDirectory(std::filesystem::path Directory) : m_RootPath(Directory) +{ + std::error_code Ec; + std::filesystem::remove_all(Directory, Ec); + std::filesystem::create_directories(Directory); +} + +ScopedTemporaryDirectory::~ScopedTemporaryDirectory() +{ + std::error_code Ec; + std::filesystem::remove_all(m_RootPath, Ec); +} + +} // namespace zen
\ No newline at end of file diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp new file mode 100644 index 000000000..1597a7dd9 --- /dev/null +++ b/src/zencore/thread.cpp @@ -0,0 +1,1212 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/thread.h> + +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#if ZEN_PLATFORM_LINUX +# if !defined(_GNU_SOURCE) +# define _GNU_SOURCE // for semtimedop() +# endif +#endif + +#if ZEN_PLATFORM_WINDOWS +# include <shellapi.h> +# include <Shlobj.h> +# include <zencore/windows.h> +#else +# include <chrono> +# include <condition_variable> +# include <mutex> + +# include <fcntl.h> +# include <pthread.h> +# include <signal.h> +# include <sys/file.h> +# include <sys/sem.h> +# include <sys/stat.h> +# include <sys/syscall.h> +# include <sys/wait.h> +# include <time.h> +# include <unistd.h> +#endif + +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +#if ZEN_PLATFORM_WINDOWS +// The information on how to set the thread name comes from +// a MSDN article: http://msdn2.microsoft.com/en-us/library/xcb2z8hs.aspx +const DWORD kVCThreadNameException = 0x406D1388; +typedef struct tagTHREADNAME_INFO +{ + DWORD dwType; // Must be 0x1000. + LPCSTR szName; // Pointer to name (in user addr space). + DWORD dwThreadID; // Thread ID (-1=caller thread). + DWORD dwFlags; // Reserved for future use, must be zero. +} THREADNAME_INFO; +// The SetThreadDescription API was brought in version 1607 of Windows 10. +typedef HRESULT(WINAPI* SetThreadDescription)(HANDLE hThread, PCWSTR lpThreadDescription); +// This function has try handling, so it is separated out of its caller. +void +SetNameInternal(DWORD thread_id, const char* name) +{ + THREADNAME_INFO info; + info.dwType = 0x1000; + info.szName = name; + info.dwThreadID = thread_id; + info.dwFlags = 0; + __try + { + RaiseException(kVCThreadNameException, 0, sizeof(info) / sizeof(DWORD), reinterpret_cast<DWORD_PTR*>(&info)); + } + __except (EXCEPTION_CONTINUE_EXECUTION) + { + } +} +#endif + +#if ZEN_PLATFORM_LINUX +const bool bNoZombieChildren = []() { + // When a child process exits it is put into a zombie state until the parent + // collects its result. This doesn't fit the Windows-like model that Zen uses + // where there is a less strict familial model and no zombification. Ignoring + // SIGCHLD siganals removes the need to call wait() on zombies. Another option + // would be for the child to call setsid() but that would detatch the child + // from the terminal. + struct sigaction Action = {}; + sigemptyset(&Action.sa_mask); + Action.sa_handler = SIG_IGN; + sigaction(SIGCHLD, &Action, nullptr); + return true; +}(); +#endif + +void +SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName) +{ +#if ZEN_PLATFORM_WINDOWS + // The SetThreadDescription API works even if no debugger is attached. + static auto SetThreadDescriptionFunc = + reinterpret_cast<SetThreadDescription>(::GetProcAddress(::GetModuleHandle(L"Kernel32.dll"), "SetThreadDescription")); + + if (SetThreadDescriptionFunc) + { + SetThreadDescriptionFunc(::GetCurrentThread(), Utf8ToWide(ThreadName).c_str()); + } + // The debugger needs to be around to catch the name in the exception. If + // there isn't a debugger, we are just needlessly throwing an exception. + if (!::IsDebuggerPresent()) + return; + + std::string ThreadNameZ{ThreadName}; + SetNameInternal(GetCurrentThreadId(), ThreadNameZ.c_str()); +#else + std::string ThreadNameZ{ThreadName}; +# if ZEN_PLATFORM_MAC + pthread_setname_np(ThreadNameZ.c_str()); +# else + pthread_setname_np(pthread_self(), ThreadNameZ.c_str()); +# endif +#endif +} // namespace zen + +void +RwLock::AcquireShared() +{ + m_Mutex.lock_shared(); +} + +void +RwLock::ReleaseShared() +{ + m_Mutex.unlock_shared(); +} + +void +RwLock::AcquireExclusive() +{ + m_Mutex.lock(); +} + +void +RwLock::ReleaseExclusive() +{ + m_Mutex.unlock(); +} + +////////////////////////////////////////////////////////////////////////// + +#if !ZEN_PLATFORM_WINDOWS +struct EventInner +{ + std::mutex Mutex; + std::condition_variable CondVar; + bool volatile bSet = false; +}; +#endif // !ZEN_PLATFORM_WINDOWS + +Event::Event() +{ + bool bManualReset = true; + bool bInitialState = false; + +#if ZEN_PLATFORM_WINDOWS + m_EventHandle = CreateEvent(nullptr, bManualReset, bInitialState, nullptr); +#else + ZEN_UNUSED(bManualReset); + auto* Inner = new EventInner(); + Inner->bSet = bInitialState; + m_EventHandle = Inner; +#endif +} + +Event::~Event() +{ + Close(); +} + +void +Event::Set() +{ +#if ZEN_PLATFORM_WINDOWS + SetEvent(m_EventHandle); +#else + auto* Inner = (EventInner*)m_EventHandle; + { + std::unique_lock Lock(Inner->Mutex); + Inner->bSet = true; + } + Inner->CondVar.notify_all(); +#endif +} + +void +Event::Reset() +{ +#if ZEN_PLATFORM_WINDOWS + ResetEvent(m_EventHandle); +#else + auto* Inner = (EventInner*)m_EventHandle; + { + std::unique_lock Lock(Inner->Mutex); + Inner->bSet = false; + } +#endif +} + +void +Event::Close() +{ +#if ZEN_PLATFORM_WINDOWS + CloseHandle(m_EventHandle); +#else + auto* Inner = (EventInner*)m_EventHandle; + delete Inner; +#endif + m_EventHandle = nullptr; +} + +bool +Event::Wait(int TimeoutMs) +{ +#if ZEN_PLATFORM_WINDOWS + using namespace std::literals; + + const DWORD Timeout = (TimeoutMs < 0) ? INFINITE : TimeoutMs; + + DWORD Result = WaitForSingleObject(m_EventHandle, Timeout); + + if (Result == WAIT_FAILED) + { + zen::ThrowLastError("Event wait failed"sv); + } + + return (Result == WAIT_OBJECT_0); +#else + auto* Inner = (EventInner*)m_EventHandle; + + if (TimeoutMs >= 0) + { + std::unique_lock Lock(Inner->Mutex); + + if (Inner->bSet) + { + return true; + } + + return Inner->CondVar.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), [&] { return Inner->bSet; }); + } + + std::unique_lock Lock(Inner->Mutex); + + if (!Inner->bSet) + { + Inner->CondVar.wait(Lock, [&] { return Inner->bSet; }); + } + + return true; +#endif +} + +////////////////////////////////////////////////////////////////////////// + +NamedEvent::NamedEvent(std::string_view EventName) +{ +#if ZEN_PLATFORM_WINDOWS + using namespace std::literals; + + ExtendableStringBuilder<64> Name; + Name << "Local\\"sv; + Name << EventName; + + m_EventHandle = CreateEventA(nullptr, true, false, Name.c_str()); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + // Create a file to back the semaphore + ExtendableStringBuilder<64> EventPath; + EventPath << "/tmp/" << EventName; + + int Fd = open(EventPath.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Fd < 0) + { + ThrowLastError(fmt::format("Failed to create '{}' for named event", EventPath)); + } + fchmod(Fd, 0666); + + // Use the file path to generate an IPC key + key_t IpcKey = ftok(EventPath.c_str(), 1); + if (IpcKey < 0) + { + close(Fd); + ThrowLastError("Failed to create an SysV IPC key"); + } + + // Use the key to create/open the semaphore + int Sem = semget(IpcKey, 1, 0600 | IPC_CREAT); + if (Sem < 0) + { + close(Fd); + ThrowLastError("Failed creating an SysV semaphore"); + } + + // Atomically claim ownership of the semaphore's key. The owner initialises + // the semaphore to 1 so we can use the wait-for-zero op as that does not + // modify the semaphore's value on a successful wait. + int LockResult = flock(Fd, LOCK_EX | LOCK_NB); + if (LockResult == 0) + { + // This isn't thread safe really. Another thread could open the same + // semaphore and successfully wait on it in the period of time where + // this comment is but before the semaphore's initialised. + semctl(Sem, 0, SETVAL, 1); + } + + // Pack into the handle + static_assert(sizeof(Sem) + sizeof(Fd) <= sizeof(void*), "Semaphore packing assumptions not met"); + intptr_t Packed; + Packed = intptr_t(Sem) << 32; + Packed |= intptr_t(Fd) & 0xffff'ffff; + m_EventHandle = (void*)Packed; +#endif +} + +NamedEvent::~NamedEvent() +{ + Close(); +} + +void +NamedEvent::Close() +{ + if (m_EventHandle == nullptr) + { + return; + } + +#if ZEN_PLATFORM_WINDOWS + CloseHandle(m_EventHandle); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = int(intptr_t(m_EventHandle) & 0xffff'ffff); + + if (flock(Fd, LOCK_EX | LOCK_NB) == 0) + { + std::filesystem::path Name = PathFromHandle((void*)(intptr_t(Fd))); + unlink(Name.c_str()); + + flock(Fd, LOCK_UN | LOCK_NB); + close(Fd); + + int Sem = int(intptr_t(m_EventHandle) >> 32); + semctl(Sem, 0, IPC_RMID); + } +#endif + + m_EventHandle = nullptr; +} + +void +NamedEvent::Set() +{ +#if ZEN_PLATFORM_WINDOWS + SetEvent(m_EventHandle); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Sem = int(intptr_t(m_EventHandle) >> 32); + semctl(Sem, 0, SETVAL, 0); +#endif +} + +bool +NamedEvent::Wait(int TimeoutMs) +{ +#if ZEN_PLATFORM_WINDOWS + const DWORD Timeout = (TimeoutMs < 0) ? INFINITE : TimeoutMs; + + DWORD Result = WaitForSingleObject(m_EventHandle, Timeout); + + if (Result == WAIT_FAILED) + { + using namespace std::literals; + zen::ThrowLastError("Event wait failed"sv); + } + + return (Result == WAIT_OBJECT_0); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Sem = int(intptr_t(m_EventHandle) >> 32); + + int Result; + struct sembuf SemOp = {}; + + if (TimeoutMs < 0) + { + Result = semop(Sem, &SemOp, 1); + return Result == 0; + } + +# if defined(_GNU_SOURCE) + struct timespec TimeoutValue = { + .tv_sec = TimeoutMs >> 10, + .tv_nsec = (TimeoutMs & 0x3ff) << 20, + }; + Result = semtimedop(Sem, &SemOp, 1, &TimeoutValue); +# else + const int SleepTimeMs = 10; + SemOp.sem_flg = IPC_NOWAIT; + do + { + Result = semop(Sem, &SemOp, 1); + if (Result == 0 || errno != EAGAIN) + { + break; + } + + Sleep(SleepTimeMs); + TimeoutMs -= SleepTimeMs; + } while (TimeoutMs > 0); +# endif // _GNU_SOURCE + + return Result == 0; +#endif +} + +////////////////////////////////////////////////////////////////////////// + +NamedMutex::~NamedMutex() +{ +#if ZEN_PLATFORM_WINDOWS + if (m_MutexHandle) + { + CloseHandle(m_MutexHandle); + } +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Inner = int(intptr_t(m_MutexHandle)); + flock(Inner, LOCK_UN); + close(Inner); +#endif +} + +bool +NamedMutex::Create(std::string_view MutexName) +{ +#if ZEN_PLATFORM_WINDOWS + ZEN_ASSERT(m_MutexHandle == nullptr); + + using namespace std::literals; + + ExtendableStringBuilder<64> Name; + Name << "Global\\"sv; + Name << MutexName; + + m_MutexHandle = CreateMutexA(nullptr, /* InitialOwner */ TRUE, Name.c_str()); + + return !!m_MutexHandle; +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + ExtendableStringBuilder<64> Name; + Name << "/tmp/" << MutexName; + + int Inner = open(Name.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Inner < 0) + { + return false; + } + fchmod(Inner, 0666); + + if (flock(Inner, LOCK_EX) != 0) + { + close(Inner); + Inner = 0; + return false; + } + + m_MutexHandle = (void*)(intptr_t(Inner)); + return true; +#endif // ZEN_PLATFORM_WINDOWS +} + +bool +NamedMutex::Exists(std::string_view MutexName) +{ +#if ZEN_PLATFORM_WINDOWS + using namespace std::literals; + + ExtendableStringBuilder<64> Name; + Name << "Global\\"sv; + Name << MutexName; + + void* MutexHandle = OpenMutexA(SYNCHRONIZE, /* InheritHandle */ FALSE, Name.c_str()); + + if (MutexHandle == nullptr) + { + return false; + } + + CloseHandle(MutexHandle); + + return true; +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + ExtendableStringBuilder<64> Name; + Name << "/tmp/" << MutexName; + + bool bExists = false; + int Fd = open(Name.c_str(), O_RDWR | O_CLOEXEC); + if (Fd >= 0) + { + if (flock(Fd, LOCK_EX | LOCK_NB) == 0) + { + flock(Fd, LOCK_UN | LOCK_NB); + } + else + { + bExists = true; + } + close(Fd); + } + + return bExists; +#endif // ZEN_PLATFORM_WINDOWS +} + +////////////////////////////////////////////////////////////////////////// + +ProcessHandle::ProcessHandle() = default; + +#if ZEN_PLATFORM_WINDOWS +void +ProcessHandle::Initialize(void* ProcessHandle) +{ + ZEN_ASSERT(m_ProcessHandle == nullptr); + + if (ProcessHandle == INVALID_HANDLE_VALUE) + { + ProcessHandle = nullptr; + } + + // TODO: perform some debug verification here to verify it's a valid handle? + m_ProcessHandle = ProcessHandle; + m_Pid = GetProcessId(m_ProcessHandle); +} +#endif // ZEN_PLATFORM_WINDOWS + +ProcessHandle::~ProcessHandle() +{ + Reset(); +} + +void +ProcessHandle::Initialize(int Pid) +{ + ZEN_ASSERT(m_ProcessHandle == nullptr); + +#if ZEN_PLATFORM_WINDOWS + m_ProcessHandle = OpenProcess(PROCESS_QUERY_INFORMATION | SYNCHRONIZE, FALSE, Pid); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + if (Pid > 0) + { + m_ProcessHandle = (void*)(intptr_t(Pid)); + } +#endif + + if (!m_ProcessHandle) + { + ThrowLastError(fmt::format("ProcessHandle::Initialize(pid: {}) failed", Pid)); + } + + m_Pid = Pid; +} + +bool +ProcessHandle::IsRunning() const +{ + bool bActive = false; + +#if ZEN_PLATFORM_WINDOWS + DWORD ExitCode = 0; + GetExitCodeProcess(m_ProcessHandle, &ExitCode); + bActive = (ExitCode == STILL_ACTIVE); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + bActive = (kill(pid_t(m_Pid), 0) == 0); +#endif + + return bActive; +} + +bool +ProcessHandle::IsValid() const +{ + return (m_ProcessHandle != nullptr); +} + +void +ProcessHandle::Terminate(int ExitCode) +{ + if (!IsRunning()) + { + return; + } + + bool bSuccess = false; + +#if ZEN_PLATFORM_WINDOWS + TerminateProcess(m_ProcessHandle, ExitCode); + DWORD WaitResult = WaitForSingleObject(m_ProcessHandle, INFINITE); + bSuccess = (WaitResult != WAIT_OBJECT_0); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + ZEN_UNUSED(ExitCode); + bSuccess = (kill(m_Pid, SIGKILL) == 0); +#endif + + if (!bSuccess) + { + // What might go wrong here, and what is meaningful to act on? + } +} + +void +ProcessHandle::Reset() +{ + if (IsValid()) + { +#if ZEN_PLATFORM_WINDOWS + CloseHandle(m_ProcessHandle); +#endif + m_ProcessHandle = nullptr; + m_Pid = 0; + } +} + +bool +ProcessHandle::Wait(int TimeoutMs) +{ + using namespace std::literals; + +#if ZEN_PLATFORM_WINDOWS + const DWORD Timeout = (TimeoutMs < 0) ? INFINITE : TimeoutMs; + + const DWORD WaitResult = WaitForSingleObject(m_ProcessHandle, Timeout); + + switch (WaitResult) + { + case WAIT_OBJECT_0: + return true; + + case WAIT_TIMEOUT: + return false; + + case WAIT_FAILED: + break; + } +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + const int SleepMs = 20; + timespec SleepTime = {0, SleepMs * 1000 * 1000}; + for (int i = 0;; i += SleepMs) + { +# if ZEN_PLATFORM_MAC + int WaitState = 0; + waitpid(m_Pid, &WaitState, WNOHANG | WCONTINUED | WUNTRACED); +# endif + + if (kill(m_Pid, 0) < 0) + { + if (zen::GetLastError() == ESRCH) + { + return true; + } + break; + } + + if (TimeoutMs >= 0 && i >= TimeoutMs) + { + return false; + } + + nanosleep(&SleepTime, nullptr); + } +#endif + + // What might go wrong here, and what is meaningful to act on? + ThrowLastError("Process::Wait failed"sv); +} + +////////////////////////////////////////////////////////////////////////// + +#if !ZEN_PLATFORM_WINDOWS || ZEN_WITH_TESTS +static void +BuildArgV(std::vector<char*>& Out, char* CommandLine) +{ + char* Cursor = CommandLine; + while (true) + { + // Skip leading whitespace + for (; *Cursor == ' '; ++Cursor) + ; + + // Check for nullp terminator + if (*Cursor == '\0') + { + break; + } + + Out.push_back(Cursor); + + // Extract word + int QuoteCount = 0; + do + { + QuoteCount += (*Cursor == '\"'); + if (*Cursor == ' ' && !(QuoteCount & 1)) + { + break; + } + ++Cursor; + } while (*Cursor != '\0'); + + if (*Cursor == '\0') + { + break; + } + + *Cursor = '\0'; + ++Cursor; + } +} +#endif // !WINDOWS || TESTS + +#if ZEN_PLATFORM_WINDOWS +static CreateProcResult +CreateProcNormal(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options) +{ + PROCESS_INFORMATION ProcessInfo{}; + STARTUPINFO StartupInfo{.cb = sizeof(STARTUPINFO)}; + + const bool InheritHandles = false; + void* Environment = nullptr; + LPSECURITY_ATTRIBUTES ProcessAttributes = nullptr; + LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr; + + DWORD CreationFlags = 0; + if (Options.Flags & CreateProcOptions::Flag_NewConsole) + { + CreationFlags |= CREATE_NEW_CONSOLE; + } + + const wchar_t* WorkingDir = nullptr; + if (Options.WorkingDirectory != nullptr) + { + WorkingDir = Options.WorkingDirectory->c_str(); + } + + ExtendableWideStringBuilder<256> CommandLineZ; + CommandLineZ << CommandLine; + + BOOL Success = CreateProcessW(Executable.c_str(), + CommandLineZ.Data(), + ProcessAttributes, + ThreadAttributes, + InheritHandles, + CreationFlags, + Environment, + WorkingDir, + &StartupInfo, + &ProcessInfo); + + if (!Success) + { + return nullptr; + } + + CloseHandle(ProcessInfo.hThread); + return ProcessInfo.hProcess; +} + +static CreateProcResult +CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options) +{ + /* Launches a binary with the shell as its parent. The shell (such as + Explorer) should be an unelevated process. */ + + // No sense in using this route if we are not elevated in the first place + if (IsUserAnAdmin() == FALSE) + { + return CreateProcNormal(Executable, CommandLine, Options); + } + + // Get the users' shell process and open it for process creation + HWND ShellWnd = GetShellWindow(); + if (ShellWnd == nullptr) + { + return nullptr; + } + + DWORD ShellPid; + GetWindowThreadProcessId(ShellWnd, &ShellPid); + + HANDLE Process = OpenProcess(PROCESS_CREATE_PROCESS, FALSE, ShellPid); + if (Process == nullptr) + { + return nullptr; + } + auto $0 = MakeGuard([&] { CloseHandle(Process); }); + + // Creating a process as a child of another process is done by setting a + // thread-attribute list on the startup info passed to CreateProcess() + SIZE_T AttrListSize; + InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize); + + auto AttrList = (PPROC_THREAD_ATTRIBUTE_LIST)malloc(AttrListSize); + auto $1 = MakeGuard([&] { free(AttrList); }); + + if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize)) + { + return nullptr; + } + + BOOL bOk = + UpdateProcThreadAttribute(AttrList, 0, PROC_THREAD_ATTRIBUTE_PARENT_PROCESS, (HANDLE*)&Process, sizeof(Process), nullptr, nullptr); + if (!bOk) + { + return nullptr; + } + + // By this point we know we are an elevated process. It is not allowed to + // create a process as a child of another unelevated process that share our + // elevated console window if we have one. So we'll need to create a new one. + uint32_t CreateProcFlags = EXTENDED_STARTUPINFO_PRESENT; + if (GetConsoleWindow() != nullptr) + { + CreateProcFlags |= CREATE_NEW_CONSOLE; + } + else + { + CreateProcFlags |= DETACHED_PROCESS; + } + + // Everything is set up now so we can proceed and launch the process + STARTUPINFOEXW StartupInfo = { + .StartupInfo = {.cb = sizeof(STARTUPINFOEXW)}, + .lpAttributeList = AttrList, + }; + PROCESS_INFORMATION ProcessInfo = {}; + + if (Options.Flags & CreateProcOptions::Flag_NewConsole) + { + CreateProcFlags |= CREATE_NEW_CONSOLE; + } + + ExtendableWideStringBuilder<256> CommandLineZ; + CommandLineZ << CommandLine; + + bOk = CreateProcessW(Executable.c_str(), + CommandLineZ.Data(), + nullptr, + nullptr, + FALSE, + CreateProcFlags, + nullptr, + nullptr, + &StartupInfo.StartupInfo, + &ProcessInfo); + if (bOk == FALSE) + { + return nullptr; + } + + CloseHandle(ProcessInfo.hThread); + return ProcessInfo.hProcess; +} + +static CreateProcResult +CreateProcElevated(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options) +{ + ExtendableWideStringBuilder<256> CommandLineZ; + CommandLineZ << CommandLine; + + SHELLEXECUTEINFO ShellExecuteInfo; + ZeroMemory(&ShellExecuteInfo, sizeof(ShellExecuteInfo)); + ShellExecuteInfo.cbSize = sizeof(ShellExecuteInfo); + ShellExecuteInfo.fMask = SEE_MASK_UNICODE | SEE_MASK_NOCLOSEPROCESS; + ShellExecuteInfo.lpFile = Executable.c_str(); + ShellExecuteInfo.lpVerb = TEXT("runas"); + ShellExecuteInfo.nShow = SW_SHOW; + ShellExecuteInfo.lpParameters = CommandLineZ.c_str(); + + if (Options.WorkingDirectory != nullptr) + { + ShellExecuteInfo.lpDirectory = Options.WorkingDirectory->c_str(); + } + + if (::ShellExecuteEx(&ShellExecuteInfo)) + { + return ShellExecuteInfo.hProcess; + } + + return nullptr; +} +#endif // ZEN_PLATFORM_WINDOWS + +CreateProcResult +CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options) +{ +#if ZEN_PLATFORM_WINDOWS + if (Options.Flags & CreateProcOptions::Flag_Unelevated) + { + return CreateProcUnelevated(Executable, CommandLine, Options); + } + + if (Options.Flags & CreateProcOptions::Flag_Elevated) + { + return CreateProcElevated(Executable, CommandLine, Options); + } + + return CreateProcNormal(Executable, CommandLine, Options); +#else + std::vector<char*> ArgV; + std::string CommandLineZ(CommandLine); + BuildArgV(ArgV, CommandLineZ.data()); + ArgV.push_back(nullptr); + + int ChildPid = fork(); + if (ChildPid < 0) + { + ThrowLastError("Failed to fork a new child process"); + } + else if (ChildPid == 0) + { + if (Options.WorkingDirectory != nullptr) + { + int Result = chdir(Options.WorkingDirectory->c_str()); + ZEN_UNUSED(Result); + } + + if (execv(Executable.c_str(), ArgV.data()) < 0) + { + ThrowLastError("Failed to exec() a new process image"); + } + } + + return ChildPid; +#endif +} + +////////////////////////////////////////////////////////////////////////// + +ProcessMonitor::ProcessMonitor() +{ +} + +ProcessMonitor::~ProcessMonitor() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (HandleType& Proc : m_ProcessHandles) + { +#if ZEN_PLATFORM_WINDOWS + CloseHandle(Proc); +#endif + Proc = 0; + } +} + +bool +ProcessMonitor::IsRunning() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + bool FoundOne = false; + + for (HandleType& Proc : m_ProcessHandles) + { + bool ProcIsActive; + +#if ZEN_PLATFORM_WINDOWS + DWORD ExitCode = 0; + GetExitCodeProcess(Proc, &ExitCode); + + ProcIsActive = (ExitCode == STILL_ACTIVE); + if (!ProcIsActive) + { + CloseHandle(Proc); + } +#else + int Pid = int(intptr_t(Proc)); + ProcIsActive = IsProcessRunning(Pid); +#endif + + if (!ProcIsActive) + { + Proc = 0; + } + + // Still alive + FoundOne |= ProcIsActive; + } + + std::erase_if(m_ProcessHandles, [](HandleType Handle) { return Handle == 0; }); + + return FoundOne; +} + +void +ProcessMonitor::AddPid(int Pid) +{ + HandleType ProcessHandle; + +#if ZEN_PLATFORM_WINDOWS + ProcessHandle = OpenProcess(PROCESS_QUERY_INFORMATION | SYNCHRONIZE, FALSE, Pid); +#else + ProcessHandle = HandleType(intptr_t(Pid)); +#endif + + if (ProcessHandle) + { + RwLock::ExclusiveLockScope _(m_Lock); + m_ProcessHandles.push_back(ProcessHandle); + } +} + +bool +ProcessMonitor::IsActive() const +{ + RwLock::SharedLockScope _(m_Lock); + return m_ProcessHandles.empty() == false; +} + +////////////////////////////////////////////////////////////////////////// + +bool +IsProcessRunning(int pid) +{ + // This function is arguably not super useful, a pid can be re-used + // by the OS so holding on to a pid and polling it over some time + // period will not necessarily tell you what you probably want to know. + +#if ZEN_PLATFORM_WINDOWS + HANDLE hProc = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid); + + if (!hProc) + { + DWORD Error = zen::GetLastError(); + + if (Error == ERROR_INVALID_PARAMETER) + { + return false; + } + + ThrowSystemError(Error, fmt::format("failed to open process with pid {}", pid)); + } + + CloseHandle(hProc); + + return true; +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + return (kill(pid_t(pid), 0) == 0); +#endif +} + +int +GetCurrentProcessId() +{ +#if ZEN_PLATFORM_WINDOWS + return ::GetCurrentProcessId(); +#else + return int(getpid()); +#endif +} + +int +GetCurrentThreadId() +{ +#if ZEN_PLATFORM_WINDOWS + return ::GetCurrentThreadId(); +#elif ZEN_PLATFORM_LINUX + return int(syscall(SYS_gettid)); +#elif ZEN_PLATFORM_MAC + return int(pthread_mach_thread_np(pthread_self())); +#endif +} + +void +Sleep(int ms) +{ +#if ZEN_PLATFORM_WINDOWS + ::Sleep(ms); +#else + usleep(ms * 1000U); +#endif +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +void +thread_forcelink() +{ +} + +TEST_CASE("Thread") +{ + int Pid = GetCurrentProcessId(); + CHECK(Pid > 0); + CHECK(IsProcessRunning(Pid)); + + CHECK_FALSE(GetCurrentThreadId() == 0); +} + +TEST_CASE("BuildArgV") +{ + const char* Words[] = {"one", "two", "three", "four", "five"}; + struct + { + int WordCount; + const char* Input; + } Cases[] = { + {0, ""}, + {0, " "}, + {1, "one"}, + {1, " one"}, + {1, "one "}, + {2, "one two"}, + {2, " one two"}, + {2, "one two "}, + {2, " one two"}, + {2, "one two "}, + {2, "one two "}, + {3, "one two three"}, + {3, "\"one\" two \"three\""}, + {5, "one two three four five"}, + }; + + for (const auto& Case : Cases) + { + std::vector<char*> OutArgs; + StringBuilder<64> Mutable; + Mutable << Case.Input; + BuildArgV(OutArgs, Mutable.Data()); + + CHECK_EQ(OutArgs.size(), Case.WordCount); + + for (int i = 0, n = int(OutArgs.size()); i < n; ++i) + { + const char* Truth = Words[i]; + size_t TruthLen = strlen(Truth); + + const char* Candidate = OutArgs[i]; + bool bQuoted = (Candidate[0] == '\"'); + Candidate += bQuoted; + + CHECK(strncmp(Truth, Candidate, TruthLen) == 0); + + if (bQuoted) + { + CHECK_EQ(Candidate[TruthLen], '\"'); + } + } + } +} + +TEST_CASE("NamedEvent") +{ + std::string Name = "zencore_test_event"; + NamedEvent TestEvent(Name); + + // Timeout test + for (uint32_t i = 0; i < 8; ++i) + { + bool bEventSet = TestEvent.Wait(100); + CHECK(!bEventSet); + } + + // Thread check + std::thread Waiter = std::thread([Name]() { + NamedEvent ReadyEvent(Name + "_ready"); + ReadyEvent.Set(); + + NamedEvent TestEvent(Name); + TestEvent.Wait(1000); + }); + + NamedEvent ReadyEvent(Name + "_ready"); + ReadyEvent.Wait(); + + zen::Sleep(500); + TestEvent.Set(); + + Waiter.join(); + + // Manual reset property + for (uint32_t i = 0; i < 8; ++i) + { + bool bEventSet = TestEvent.Wait(100); + CHECK(bEventSet); + } +} + +TEST_CASE("NamedMutex") +{ + static const char* Name = "zen_test_mutex"; + + CHECK(!NamedMutex::Exists(Name)); + + { + NamedMutex TestMutex; + CHECK(TestMutex.Create(Name)); + CHECK(NamedMutex::Exists(Name)); + } + + CHECK(!NamedMutex::Exists(Name)); +} + +#endif // ZEN_WITH_TESTS + +} // namespace zen diff --git a/src/zencore/timer.cpp b/src/zencore/timer.cpp new file mode 100644 index 000000000..1655e912d --- /dev/null +++ b/src/zencore/timer.cpp @@ -0,0 +1,105 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/thread.h> +#include <zencore/timer.h> + +#include <zencore/testing.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#elif ZEN_PLATFORM_LINUX +# include <time.h> +# include <unistd.h> +#endif + +namespace zen { + +uint64_t +GetHifreqTimerValue() +{ + uint64_t Timestamp; + +#if ZEN_PLATFORM_WINDOWS + LARGE_INTEGER li; + QueryPerformanceCounter(&li); + + Timestamp = li.QuadPart; +#else + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + Timestamp = (uint64_t(ts.tv_sec) * 1000000ull) + (uint64_t(ts.tv_nsec) / 1000ull); +#endif + + return Timestamp; +} + +uint64_t +InternalGetHifreqTimerFrequency() +{ +#if ZEN_PLATFORM_WINDOWS + LARGE_INTEGER li; + QueryPerformanceFrequency(&li); + + return li.QuadPart; +#else + return 1000000ull; +#endif +} + +uint64_t QpcFreq = InternalGetHifreqTimerFrequency(); +static const double QpcFactor = 1.0 / InternalGetHifreqTimerFrequency(); + +uint64_t +GetHifreqTimerFrequency() +{ + return QpcFreq; +} + +double +GetHifreqTimerToSeconds() +{ + return QpcFactor; +} + +uint64_t +GetHifreqTimerFrequencySafe() +{ + if (!QpcFreq) + { + QpcFreq = InternalGetHifreqTimerFrequency(); + } + + return QpcFreq; +} + +////////////////////////////////////////////////////////////////////////// + +uint64_t detail::g_LofreqTimerValue = GetHifreqTimerValue(); + +void +UpdateLofreqTimerValue() +{ + detail::g_LofreqTimerValue = GetHifreqTimerValue(); +} + +uint64_t +GetLofreqTimerFrequency() +{ + return GetHifreqTimerFrequencySafe(); +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +void +timer_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp new file mode 100644 index 000000000..788dcec07 --- /dev/null +++ b/src/zencore/trace.cpp @@ -0,0 +1,45 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +/* clang-format off */ + +#if ZEN_WITH_TRACE + +#include <zencore/zencore.h> + +#define TRACE_IMPLEMENT 1 +#include <zencore/trace.h> + +void +TraceInit(const char* HostOrPath, TraceType Type) +{ + bool EnableEvents = true; + + switch (Type) + { + case TraceType::Network: + trace::SendTo(HostOrPath); + break; + + case TraceType::File: + trace::WriteTo(HostOrPath); + break; + + case TraceType::None: + EnableEvents = false; + break; + } + + trace::FInitializeDesc Desc = { + .bUseImportantCache = false, + }; + trace::Initialize(Desc); + + if (EnableEvents) + { + trace::ToggleChannel("cpu", true); + } +} + +#endif // ZEN_WITH_TRACE + +/* clang-format on */ diff --git a/src/zencore/uid.cpp b/src/zencore/uid.cpp new file mode 100644 index 000000000..86cdfae3a --- /dev/null +++ b/src/zencore/uid.cpp @@ -0,0 +1,148 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/uid.h> + +#include <zencore/endian.h> +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <atomic> +#include <bit> +#include <chrono> +#include <random> +#include <set> +#include <unordered_map> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +namespace detail { + static bool OidInitialised; + static uint32_t RunId; + static std::atomic_uint32_t Serial; +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +const Oid Oid::Zero = {{0u, 0u, 0u}}; +const Oid Oid::Max = {{~0u, ~0u, ~0u}}; + +void +Oid::Initialize() +{ + if (!detail::OidInitialised) + { + std::random_device Rng; + detail::RunId = Rng(); + detail::Serial = Rng(); + + detail::OidInitialised = true; + } +} + +const Oid& +Oid::Generate() +{ + if (!detail::OidInitialised) + { + Oid::Initialize(); + } + + const uint64_t kOffset = 1'609'459'200; // Seconds from 1970 -> 2021 + const uint64_t Time = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) - kOffset; + + OidBits[0] = ToNetworkOrder(uint32_t(Time)); + OidBits[1] = ToNetworkOrder(uint32_t(++detail::Serial)); + OidBits[2] = detail::RunId; + + return *this; +} + +Oid +Oid::NewOid() +{ + return Oid().Generate(); +} + +Oid +Oid::FromHexString(const std::string_view String) +{ + ZEN_ASSERT(String.size() == 2 * sizeof(Oid::OidBits)); + + Oid Id; + + if (ParseHexBytes(String.data(), String.size(), reinterpret_cast<uint8_t*>(Id.OidBits))) + { + return Id; + } + else + { + return Oid::Zero; + } +} + +Oid +Oid::FromMemory(const void* Ptr) +{ + Oid Id; + memcpy(Id.OidBits, Ptr, sizeof Id); + return Id; +} + +void +Oid::ToString(char OutString[StringLength]) +{ + ToHexBytes(reinterpret_cast<const uint8_t*>(OidBits), sizeof(Oid::OidBits), OutString); +} + +StringBuilderBase& +Oid::ToString(StringBuilderBase& OutString) const +{ + String_t Str; + ToHexBytes(reinterpret_cast<const uint8_t*>(OidBits), sizeof(Oid::OidBits), Str); + + OutString.AppendRange(Str, &Str[StringLength]); + + return OutString; +} + +#if ZEN_WITH_TESTS + +TEST_CASE("Oid") +{ + SUBCASE("Basic") + { + Oid id1 = Oid::NewOid(); + ZEN_UNUSED(id1); + + std::vector<Oid> ids; + std::set<Oid> idset; + std::unordered_map<Oid, int, Oid::Hasher> idmap; + + const int Count = 1000; + + for (int i = 0; i < Count; ++i) + { + Oid id; + id.Generate(); + + ids.emplace_back(id); + idset.insert(id); + idmap.insert({id, i}); + } + + CHECK(ids.size() == Count); + CHECK(idset.size() == Count); // All ids should be unique + CHECK(idmap.size() == Count); // Ditto + } +} + +void +uid_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp new file mode 100644 index 000000000..b4328cdbd --- /dev/null +++ b/src/zencore/workthreadpool.cpp @@ -0,0 +1,83 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/workthreadpool.h> + +#include <zencore/logging.h> + +namespace zen { + +namespace detail { + struct LambdaWork : IWork + { + LambdaWork(auto Work) : WorkFunction(Work) {} + virtual void Execute() override { WorkFunction(); } + + std::function<void()> WorkFunction; + }; +} // namespace detail + +WorkerThreadPool::WorkerThreadPool(int InThreadCount) +{ + for (int i = 0; i < InThreadCount; ++i) + { + m_WorkerThreads.emplace_back(&WorkerThreadPool::WorkerThreadFunction, this); + } +} + +WorkerThreadPool::~WorkerThreadPool() +{ + m_WorkQueue.CompleteAdding(); + + for (std::thread& Thread : m_WorkerThreads) + { + Thread.join(); + } + + m_WorkerThreads.clear(); +} + +void +WorkerThreadPool::ScheduleWork(Ref<IWork> Work) +{ + m_WorkQueue.Enqueue(std::move(Work)); +} + +void +WorkerThreadPool::ScheduleWork(std::function<void()>&& Work) +{ + m_WorkQueue.Enqueue(Ref<IWork>(new detail::LambdaWork(Work))); +} + +[[nodiscard]] size_t +WorkerThreadPool::PendingWork() const +{ + return m_WorkQueue.Size(); +} + +void +WorkerThreadPool::WorkerThreadFunction() +{ + do + { + Ref<IWork> Work; + if (m_WorkQueue.WaitAndDequeue(Work)) + { + try + { + Work->Execute(); + } + catch (std::exception& e) + { + Work->m_Exception = std::current_exception(); + + ZEN_WARN("Caught exception in worker thread: {}", e.what()); + } + } + else + { + return; + } + } while (true); +} + +} // namespace zen diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua new file mode 100644 index 000000000..e1e649c1d --- /dev/null +++ b/src/zencore/xmake.lua @@ -0,0 +1,61 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zencore') + set_kind("static") + add_headerfiles("**.h") + add_configfiles("include/zencore/config.h.in") + on_load(function (target) + local version = io.readfile("VERSION.txt") + version = string.gsub(version,"%-pre.*", "") + target:set("version", version:trim(), {build = "%Y%m%d%H%M"}) + end) + set_configdir("include/zencore") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_includedirs("$(projectdir)/thirdparty/utfcpp/source") + add_includedirs("$(projectdir)/thirdparty/trace", {public=true}) + if is_os("windows") then + add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Win64") + elseif is_os("linux") then + add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Linux_x64") + add_links("oo2corelinux64") + add_syslinks("pthread") + elseif is_os("macosx") then + add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Mac_x64") + add_links("oo2coremac64") + end + add_options("zentrace") + add_packages( + "vcpkg::blake3", + "vcpkg::cpr", + "vcpkg::curl", -- required by cpr + "vcpkg::doctest", + "vcpkg::fmt", + "vcpkg::gsl-lite", + "vcpkg::json11", + "vcpkg::lz4", + "vcpkg::mimalloc", + "vcpkg::openssl", -- required by curl + "vcpkg::spdlog", + "vcpkg::zlib", -- required by curl + "vcpkg::xxhash") + + if is_plat("linux") then + -- The 'vcpkg::openssl' package is two libraries; ssl and crypto, with + -- ssl being dependent on symbols in crypto. When GCC-like linkers read + -- object files from their command line, those object files only resolve + -- symbols of objects previously encountered. Thus crypto must appear + -- after ssl so it can fill out ssl's unresolved symbol table. Xmake's + -- vcpkg support is basic and works by parsing .list files. Openssl's + -- archives are listed alphabetically causing crypto to be _before_ ssl + -- and resulting in link errors. The links are restated here to force + -- xmake to use the correct order, and "syslinks" is used to force the + -- arguments to the end of the line (otherwise they can appear before + -- curl and cause more errors). + add_syslinks("crypto") + add_syslinks("dl") + end + + if is_plat("linux") then + add_syslinks("rt") + end diff --git a/src/zencore/xxhash.cpp b/src/zencore/xxhash.cpp new file mode 100644 index 000000000..450131d19 --- /dev/null +++ b/src/zencore/xxhash.cpp @@ -0,0 +1,50 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/xxhash.h> + +#include <zencore/string.h> +#include <zencore/testing.h> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +XXH3_128 XXH3_128::Zero; // Initialized to all zeros + +XXH3_128 +XXH3_128::FromHexString(const char* InString) +{ + return FromHexString({InString, sizeof(XXH3_128::Hash) * 2}); +} + +XXH3_128 +XXH3_128::FromHexString(std::string_view InString) +{ + ZEN_ASSERT(InString.size() == 2 * sizeof(XXH3_128::Hash)); + + XXH3_128 Xx; + ParseHexBytes(InString.data(), InString.size(), Xx.Hash); + return Xx; +} + +const char* +XXH3_128::ToHexString(char* OutString /* 40 characters + NUL terminator */) const +{ + ToHexBytes(Hash, sizeof(XXH3_128), OutString); + OutString[2 * sizeof(XXH3_128)] = '\0'; + + return OutString; +} + +StringBuilderBase& +XXH3_128::ToHexString(StringBuilderBase& OutBuilder) const +{ + String_t str; + ToHexString(str); + + OutBuilder.AppendRange(str, &str[StringLength]); + + return OutBuilder; +} + +} // namespace zen diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp new file mode 100644 index 000000000..2a7c5755e --- /dev/null +++ b/src/zencore/zencore.cpp @@ -0,0 +1,175 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +#if ZEN_PLATFORM_LINUX +# include <pthread.h> +#endif + +#include <zencore/blake3.h> +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compositebuffer.h> +#include <zencore/compress.h> +#include <zencore/crypto.h> +#include <zencore/filesystem.h> +#include <zencore/intmath.h> +#include <zencore/iobuffer.h> +#include <zencore/memory.h> +#include <zencore/mpscqueue.h> +#include <zencore/refcount.h> +#include <zencore/sha1.h> +#include <zencore/stats.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/uid.h> + +namespace zen { + +AssertImpl AssertImpl::DefaultAssertImpl; +AssertImpl* AssertImpl::CurrentAssertImpl = &AssertImpl::DefaultAssertImpl; + +////////////////////////////////////////////////////////////////////////// + +bool +IsDebuggerPresent() +{ +#if ZEN_PLATFORM_WINDOWS + return ::IsDebuggerPresent(); +#else + return false; +#endif +} + +std::optional<bool> InteractiveSessionFlag; + +void +SetIsInteractiveSession(bool Value) +{ + InteractiveSessionFlag = Value; +} + +bool +IsInteractiveSession() +{ + if (!InteractiveSessionFlag.has_value()) + { +#if ZEN_PLATFORM_WINDOWS + DWORD dwSessionId = 0; + if (ProcessIdToSessionId(GetCurrentProcessId(), &dwSessionId)) + { + InteractiveSessionFlag = (dwSessionId != 0); + } + else + { + InteractiveSessionFlag = false; + } +#else + // TODO: figure out what actually makes sense here + InteractiveSessionFlag = true; +#endif + } + + return InteractiveSessionFlag.value(); +} + +////////////////////////////////////////////////////////////////////////// + +static int s_ApplicationExitCode = 0; +static bool s_ApplicationExitRequested; + +bool +IsApplicationExitRequested() +{ + return s_ApplicationExitRequested; +} + +void +RequestApplicationExit(int ExitCode) +{ + s_ApplicationExitCode = ExitCode; + s_ApplicationExitRequested = true; +} + +#if ZEN_WITH_TESTS +void +zencore_forcelinktests() +{ + zen::blake3_forcelink(); + zen::compositebuffer_forcelink(); + zen::compress_forcelink(); + zen::filesystem_forcelink(); + zen::intmath_forcelink(); + zen::iobuffer_forcelink(); + zen::memory_forcelink(); + zen::mpscqueue_forcelink(); + zen::refcount_forcelink(); + zen::sha1_forcelink(); + zen::stats_forcelink(); + zen::stream_forcelink(); + zen::string_forcelink(); + zen::thread_forcelink(); + zen::timer_forcelink(); + zen::uid_forcelink(); + zen::uson_forcelink(); + zen::usonbuilder_forcelink(); + zen::usonpackage_forcelink(); + zen::crypto_forcelink(); +} +} // namespace zen + +# include <zencore/testing.h> + +namespace zen { + +TEST_CASE("Assert.Default") +{ + bool A = true; + bool B = false; + CHECK_THROWS_WITH(ZEN_ASSERT(A == B), "A == B"); +} + +TEST_CASE("Assert.Custom") +{ + struct MyAssertImpl : AssertImpl + { + ZEN_FORCENOINLINE ZEN_DEBUG_SECTION MyAssertImpl() : PrevAssertImpl(CurrentAssertImpl) { CurrentAssertImpl = this; } + virtual ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ~MyAssertImpl() { CurrentAssertImpl = PrevAssertImpl; } + virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename, + int LineNumber, + const char* FunctionName, + const char* Msg) + { + AssertFileName = Filename; + Line = LineNumber; + FuncName = FunctionName; + Message = Msg; + } + AssertImpl* PrevAssertImpl; + + const char* AssertFileName = nullptr; + int Line = -1; + const char* FuncName = nullptr; + const char* Message = nullptr; + }; + + MyAssertImpl MyAssert; + bool A = true; + bool B = false; + CHECK_THROWS_WITH(ZEN_ASSERT(A == B), "A == B"); + CHECK(MyAssert.AssertFileName != nullptr); + CHECK(MyAssert.Line != -1); + CHECK(MyAssert.FuncName != nullptr); + CHECK(strcmp(MyAssert.Message, "A == B") == 0); +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp new file mode 100644 index 000000000..79b2c0a3d --- /dev/null +++ b/src/zenhttp/httpasio.cpp @@ -0,0 +1,1372 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpasio.h" + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +#include <deque> +#include <memory> +#include <string_view> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +# include <mstcpip.h> +#endif +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#define ASIO_VERBOSE_TRACE 0 + +#if ASIO_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +#else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +#endif + +namespace zen::asio_http { + +using namespace std::literals; + +struct HttpAcceptor; +struct HttpRequest; +struct HttpResponse; +struct HttpServerConnection; + +static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv); +static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv); +static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv); +static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv); +static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv); +static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv); +static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv); + +inline spdlog::logger& +InitLogger() +{ + spdlog::logger& Logger = logging::Get("asio"); + // Logger.set_level(spdlog::level::trace); + return Logger; +} + +inline spdlog::logger& +Log() +{ + static spdlog::logger& g_Logger = InitLogger(); + return g_Logger; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAsioServerImpl +{ +public: + HttpAsioServerImpl(); + ~HttpAsioServerImpl(); + + int Start(uint16_t Port, int ThreadCount); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; + std::vector<std::thread> m_ThreadPool; + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpAsioServerRequest : public HttpServerRequest +{ +public: + HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpAsioServerRequest(); + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpAsioServerRequest(const HttpAsioServerRequest&) = delete; + HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete; + + asio_http::HttpRequest& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr<HttpResponse> m_Response; +}; + +struct HttpRequest +{ + explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {} + + void Initialize(); + size_t ConsumeData(const char* InputData, size_t DataSize); + void ResetState(); + + HttpVerb RequestVerb() const { return m_RequestVerb; } + bool IsKeepAlive() const { return m_KeepAlive; } + std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; } + std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); } + IoBuffer Body() { return m_BodyBuffer; } + + inline HttpContentType ContentType() + { + if (m_ContentTypeHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value); + } + + inline HttpContentType AcceptType() + { + if (m_AcceptHeaderIndex < 0) + { + return HttpContentType::kUnknownContentType; + } + + return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value); + } + + Oid SessionId() const { return m_SessionId; } + int RequestId() const { return m_RequestId; } + + std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); } + +private: + struct HeaderEntry + { + HeaderEntry() = default; + + HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {} + + std::string_view Name; + std::string_view Value; + }; + + HttpServerConnection& m_Connection; + char* m_HeaderCursor = m_HeaderBuffer; + char* m_Url = nullptr; + size_t m_UrlLength = 0; + char* m_QueryString = nullptr; + size_t m_QueryLength = 0; + char* m_CurrentHeaderName = nullptr; // Used while parsing headers + size_t m_CurrentHeaderNameLength = 0; + char* m_CurrentHeaderValue = nullptr; // Used while parsing headers + size_t m_CurrentHeaderValueLength = 0; + std::vector<HeaderEntry> m_Headers; + int8_t m_ContentLengthHeaderIndex; + int8_t m_AcceptHeaderIndex; + int8_t m_ContentTypeHeaderIndex; + int8_t m_RangeHeaderIndex; + HttpVerb m_RequestVerb; + bool m_KeepAlive = false; + bool m_Expect100Continue = false; + int m_RequestId = -1; + Oid m_SessionId{}; + IoBuffer m_BodyBuffer; + uint64_t m_BodyPosition = 0; + http_parser m_Parser; + char m_HeaderBuffer[1024]; + std::string m_NormalizedUrl; + + void AppendCurrentHeader(); + + int OnMessageBegin(); + int OnUrl(const char* Data, size_t Bytes); + int OnHeader(const char* Data, size_t Bytes); + int OnHeaderValue(const char* Data, size_t Bytes); + int OnHeadersComplete(); + int OnBody(const char* Data, size_t Bytes); + int OnMessageComplete(); + + static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); } + static http_parser_settings s_ParserSettings; +}; + +struct HttpResponse +{ +public: + HttpResponse() = default; + explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) + { + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { +#if 1 + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); +#else + IoBuffer TempBuffer = std::move(Buffer); + TempBuffer.MakeOwned(); + m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer)); +#endif + } + + uint64_t LocalDataSize = 0; + + m_AsioBuffers.push_back({}); // Placeholder for header + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + // + // this looks like it requires some custom asio plumbing however + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + else + { + // Send from memory + + m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()}); + } + } + m_ContentLength = LocalDataSize; + + auto Headers = GetHeaders(); + m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size()); + } + + uint16_t ResponseCode() const { return m_ResponseCode; } + uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; } + + std::string_view GetHeaders() + { + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; + } + + void SuppressPayload() { m_AsioBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_DataBuffers; + std::vector<asio::const_buffer> m_AsioBuffers; + ExtendableStringBuilder<160> m_Headers; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpServerConnection : std::enable_shared_from_this<HttpServerConnection> +{ + HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); + ~HttpServerConnection(); + + void HandleNewRequest(); + void TerminateConnection(); + void HandleRequest(); + + std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); } + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, + kWritingFinal, + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequest m_RequestData{*this}; + + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false); + void OnError(); + + HttpAsioServerImpl& m_Server; + asio::streambuf m_RequestBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::atomic<uint32_t> m_RequestCounter{0}; + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; + + RwLock m_ResponsesLock; + std::deque<std::unique_ptr<HttpResponse>> m_Responses; +}; + +std::atomic<uint32_t> g_ConnectionIdCounter{0}; + +HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket) +: m_Server(Server) +, m_Socket(std::move(Socket)) +, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1)) +{ + ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId); +} + +HttpServerConnection::~HttpServerConnection() +{ + ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId); +} + +void +HttpServerConnection::HandleNewRequest() +{ + m_RequestData.Initialize(); + + EnqueueRead(); +} + +void +HttpServerConnection::TerminateConnection() +{ + m_RequestState = RequestState::kTerminated; + + std::error_code Ec; + m_Socket->close(Ec); +} + +void +HttpServerConnection::EnqueueRead() +{ + if (m_RequestState == RequestState::kInitialRead) + { + m_RequestState = RequestState::kReadingMore; + } + else + { + m_RequestState = RequestState::kInitialRead; + } + + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read(*m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(1), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); +} + +void +HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount) +{ + if (Ec) + { + if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead) + { + ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message()); + return; + } + else + { + ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); + return OnError(); + } + } + + ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + + size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size()); + if (Result == ~0ull) + { + return OnError(); + } + + m_RequestBuffer.consume(Result); + } + + switch (m_RequestState) + { + case RequestState::kDone: + case RequestState::kWritingFinal: + case RequestState::kTerminated: + break; + + default: + EnqueueRead(); + break; + } +} + +void +HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop) +{ + if (Ec) + { + ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message()); + OnError(); + } + else + { + ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'", + m_ConnectionId, + m_RequestCounter.load(std::memory_order_relaxed), + zen::GetCurrentThreadId(), + NiceBytes(ByteCount)); + + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kDone; + + m_Socket->close(); + } + else + { + if (Pop) + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.pop_front(); + } + + m_RequestCounter.fetch_add(1); + } + } +} + +void +HttpServerConnection::OnError() +{ + m_Socket->close(); +} + +void +HttpServerConnection::HandleRequest() +{ + if (!m_RequestData.IsKeepAlive()) + { + m_RequestState = RequestState::kWritingFinal; + + std::error_code Ec; + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + + if (Ec) + { + ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message()); + } + } + else + { + m_RequestState = RequestState::kWriting; + } + + if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url())) + { + HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body()); + + ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed)); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::exception& ex) + { + ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + auto ResponseBuffers = Response->AsioBuffers(); + + uint64_t ResponseLength = 0; + + for (auto& Buffer : ResponseBuffers) + { + ResponseLength += Buffer.size(); + } + + { + RwLock::ExclusiveLockScope _(m_ResponsesLock); + m_Responses.push_back(std::move(Response)); + } + + // TODO: should cork/uncork for Linux? + + asio::async_write(*m_Socket.get(), + ResponseBuffers, + asio::transfer_exactly(ResponseLength), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { + Conn->OnResponseDataSent(Ec, ByteCount, true); + }); + + return; + } + } + + if (m_RequestData.RequestVerb() == HttpVerb::kHead) + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } + else + { + std::string_view Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + + if (!m_RequestData.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + + asio::async_write( + *m_Socket.get(), + asio::buffer(Response), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); }); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpRequest +// + +http_parser_settings HttpRequest::s_ParserSettings{ + .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); }, + .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); }, + .on_status = + [](http_parser* p, const char* Data, size_t ByteCount) { + ZEN_UNUSED(p, Data, ByteCount); + return 0; + }, + .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); }, + .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); }, + .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); }, + .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); }, + .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); }, + .on_chunk_header{}, + .on_chunk_complete{}}; + +void +HttpRequest::Initialize() +{ + http_parser_init(&m_Parser, HTTP_REQUEST); + m_Parser.data = this; + + ResetState(); +} + +size_t +HttpRequest::ConsumeData(const char* InputData, size_t DataSize) +{ + const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize); + + http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser)); + + if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE) + { + ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno)); + return ~0ull; + } + + return ConsumedBytes; +} + +int +HttpRequest::OnUrl(const char* Data, size_t Bytes) +{ + if (!m_Url) + { + ZEN_ASSERT_SLOW(m_UrlLength == 0); + m_Url = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_UrlLength += Bytes; + + return 0; +} + +int +HttpRequest::OnHeader(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = m_HeaderCursor; + } + else if (m_CurrentHeaderName == nullptr) + { + m_CurrentHeaderName = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderNameLength += Bytes; + + return 0; +} + +void +HttpRequest::AppendCurrentHeader() +{ + std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength); + std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength); + + const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName); + + if (HeaderHash == HashContentLength) + { + m_ContentLengthHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashAccept) + { + m_AcceptHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashContentType) + { + m_ContentTypeHeaderIndex = (int8_t)m_Headers.size(); + } + else if (HeaderHash == HashSession) + { + m_SessionId = Oid::FromHexString(HeaderValue); + } + else if (HeaderHash == HashRequest) + { + std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId); + } + else if (HeaderHash == HashExpect) + { + if (HeaderValue == "100-continue"sv) + { + // We don't currently do anything with this + m_Expect100Continue = true; + } + else + { + ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue); + } + } + else if (HeaderHash == HashRange) + { + m_RangeHeaderIndex = (int8_t)m_Headers.size(); + } + + m_Headers.emplace_back(HeaderName, HeaderValue); +} + +int +HttpRequest::OnHeaderValue(const char* Data, size_t Bytes) +{ + if (m_CurrentHeaderValueLength == 0) + { + m_CurrentHeaderValue = m_HeaderCursor; + } + + const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor; + if (RemainingBufferSpace < Bytes) + { + ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace); + return 1; + } + + memcpy(m_HeaderCursor, Data, Bytes); + m_HeaderCursor += Bytes; + m_CurrentHeaderValueLength += Bytes; + + return 0; +} + +static void +NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl) +{ + bool LastCharWasSeparator = false; + for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex) + { + const char UrlChar = Url[UrlIndex]; + const bool IsSeparator = (UrlChar == '/'); + + if (IsSeparator && LastCharWasSeparator) + { + if (NormalizedUrl.empty()) + { + NormalizedUrl.reserve(UrlLength); + NormalizedUrl.append(Url, UrlIndex); + } + + if (!LastCharWasSeparator) + { + NormalizedUrl.push_back('/'); + } + } + else if (!NormalizedUrl.empty()) + { + NormalizedUrl.push_back(UrlChar); + } + + LastCharWasSeparator = IsSeparator; + } +} + +int +HttpRequest::OnHeadersComplete() +{ + if (m_CurrentHeaderValueLength) + { + AppendCurrentHeader(); + } + + if (m_ContentLengthHeaderIndex >= 0) + { + std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value; + uint64_t ContentLength = 0; + std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength); + + if (ContentLength) + { + m_BodyBuffer = IoBuffer(ContentLength); + } + + m_BodyBuffer.SetContentType(ContentType()); + + m_BodyPosition = 0; + } + + m_KeepAlive = !!http_should_keep_alive(&m_Parser); + + switch (m_Parser.method) + { + case HTTP_GET: + m_RequestVerb = HttpVerb::kGet; + break; + + case HTTP_POST: + m_RequestVerb = HttpVerb::kPost; + break; + + case HTTP_PUT: + m_RequestVerb = HttpVerb::kPut; + break; + + case HTTP_DELETE: + m_RequestVerb = HttpVerb::kDelete; + break; + + case HTTP_HEAD: + m_RequestVerb = HttpVerb::kHead; + break; + + case HTTP_COPY: + m_RequestVerb = HttpVerb::kCopy; + break; + + case HTTP_OPTIONS: + m_RequestVerb = HttpVerb::kOptions; + break; + + default: + ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method)); + break; + } + + std::string_view Url(m_Url, m_UrlLength); + + if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos) + { + m_UrlLength = QuerySplit; + m_QueryString = m_Url + QuerySplit + 1; + m_QueryLength = Url.size() - QuerySplit - 1; + } + + NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl); + + return 0; +} + +int +HttpRequest::OnBody(const char* Data, size_t Bytes) +{ + if (m_BodyPosition + Bytes > m_BodyBuffer.Size()) + { + ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes", + (m_BodyPosition + Bytes) - m_BodyBuffer.Size()); + return 1; + } + memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes); + m_BodyPosition += Bytes; + + if (http_body_is_final(&m_Parser)) + { + if (m_BodyPosition != m_BodyBuffer.Size()) + { + ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size()); + return 1; + } + } + + return 0; +} + +void +HttpRequest::ResetState() +{ + m_HeaderCursor = m_HeaderBuffer; + m_CurrentHeaderName = nullptr; + m_CurrentHeaderNameLength = 0; + m_CurrentHeaderValue = nullptr; + m_CurrentHeaderValueLength = 0; + m_CurrentHeaderName = nullptr; + m_Url = nullptr; + m_UrlLength = 0; + m_QueryString = nullptr; + m_QueryLength = 0; + m_ContentLengthHeaderIndex = -1; + m_AcceptHeaderIndex = -1; + m_ContentTypeHeaderIndex = -1; + m_RangeHeaderIndex = -1; + m_Expect100Continue = false; + m_BodyBuffer = {}; + m_BodyPosition = 0; + m_Headers.clear(); + m_NormalizedUrl.clear(); +} + +int +HttpRequest::OnMessageBegin() +{ + return 0; +} + +int +HttpRequest::OnMessageComplete() +{ + m_Connection.HandleRequest(); + + ResetState(); + + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +struct HttpAcceptor +{ + HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort) + : m_Server(Server) + , m_IoService(IoService) + , m_Acceptor(m_IoService, asio::ip::tcp::v6()) + { + m_Acceptor.set_option(asio::ip::v6_only(false)); +#if ZEN_PLATFORM_WINDOWS + // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms + typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address; + m_Acceptor.set_option(excluse_address(true)); +#else // ZEN_PLATFORM_WINDOWS + m_Acceptor.set_option(asio::socket_base::reuse_address(false)); +#endif // ZEN_PLATFORM_WINDOWS + + m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + uint16_t EffectivePort = BasePort; + + asio::error_code BindErrorCode; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + // Sharing violation implies the port is being used by another process + for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::access_denied) + { + EffectivePort = 0; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode) + { + ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); + } + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + m_Acceptor.listen(); + + ZEN_INFO("Started asio server at port '{}'", EffectivePort); + } + + void Start() + { + m_Acceptor.listen(); + InitAccept(); + } + + void Stop() { m_IsStopped = true; } + + void InitAccept() + { + auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'", + m_Acceptor.local_endpoint().address().to_string(), + m_Acceptor.local_endpoint().port(), + Ec.message()); + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop. The connection lifetime is + // managed by the async read/write loop by passing the shared + // reference to the callbacks. + + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket)); + Conn->HandleNewRequest(); + } + + if (!m_IsStopped.load()) + { + InitAccept(); + } + else + { + m_Acceptor.close(); + } + }); + } + + int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + +private: + HttpAsioServerImpl& m_Server; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic<bool> m_IsStopped{false}; +}; + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpAsioServerRequest::~HttpAsioServerRequest() +{ +} + +Oid +HttpAsioServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpAsioServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpAsioServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpAsioServerImpl::HttpAsioServerImpl() +{ +} + +HttpAsioServerImpl::~HttpAsioServerImpl() +{ +} + +int +HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount) +{ + ZEN_ASSERT(ThreadCount > 0); + + ZEN_INFO("starting asio http with {} service threads", ThreadCount); + + m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port)); + m_Acceptor->Start(); + + for (int i = 0; i < ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, Index = i + 1] { + SetCurrentThreadName(fmt::format("asio worker {}", Index)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what()); + } + }); + } + + ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort()); + + return m_Acceptor->GetAcceptPort(); +} + +void +HttpAsioServerImpl::Stop() +{ + m_Acceptor->Stop(); + m_IoService.stop(); + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +void +HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpAsioServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +} // namespace zen::asio_http + +////////////////////////////////////////////////////////////////////////// + +namespace zen { +HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>()) +{ + ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest)); +} + +HttpAsioServer::~HttpAsioServer() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http asio server: {}", ex.what()); + } +} + +void +HttpAsioServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +int +HttpAsioServer::Initialize(int BasePort) +{ + m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Max(std::thread::hardware_concurrency(), 8u)); + return m_BasePort; +} + +void +HttpAsioServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpAsioServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +} // namespace zen diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h new file mode 100644 index 000000000..716145955 --- /dev/null +++ b/src/zenhttp/httpasio.h @@ -0,0 +1,36 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +#include <memory> + +namespace zen { + +namespace asio_http { + struct HttpServerConnection; + struct HttpAcceptor; + struct HttpAsioServerImpl; +} // namespace asio_http + +class HttpAsioServer : public HttpServer +{ +public: + HttpAsioServer(); + ~HttpAsioServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + +private: + Event m_ShutdownEvent; + int m_BasePort = 0; + + std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp new file mode 100644 index 000000000..e6813d407 --- /dev/null +++ b/src/zenhttp/httpclient.cpp @@ -0,0 +1,176 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpclient.h> +#include <zenhttp/httpserver.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/sharedbuffer.h> +#include <zencore/stream.h> +#include <zencore/testing.h> +#include <zenhttp/httpshared.h> + +static std::atomic<uint32_t> HttpClientRequestIdCounter{0}; + +namespace zen { + +using namespace std::literals; + +HttpClient::Response +FromCprResponse(cpr::Response& InResponse) +{ + return {.StatusCode = int(InResponse.status_code)}; +} + +////////////////////////////////////////////////////////////////////////// + +HttpClient::HttpClient(std::string_view BaseUri) : m_BaseUri(BaseUri) +{ + StringBuilder<32> SessionId; + GetSessionId().ToString(SessionId); + m_SessionId = SessionId; +} + +HttpClient::~HttpClient() +{ +} + +HttpClient::Response +HttpClient::TransactPackage(std::string_view Url, CbPackage Package) +{ + cpr::Session Sess; + Sess.SetUrl(m_BaseUri + std::string(Url)); + + // First, list of offered chunks for filtering on the server end + + std::vector<IoHash> AttachmentsToSend; + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + + const uint32_t RequestId = ++HttpClientRequestIdCounter; + auto RequestIdString = fmt::to_string(RequestId); + + if (Attachments.empty() == false) + { + CbObjectWriter Writer; + Writer.BeginArray("offer"); + + for (const CbAttachment& Attachment : Attachments) + { + IoHash Hash = Attachment.GetHash(); + + Writer.AddHash(Hash); + } + + Writer.EndArray(); + + BinaryWriter MemWriter; + Writer.Save(MemWriter); + + Sess.SetHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); + Sess.SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (FilterResponse.status_code == 200) + { + IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size()); + CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer); + + for (auto& Entry : ResponseObject["need"]) + { + ZEN_ASSERT(Entry.IsHash()); + AttachmentsToSend.push_back(Entry.AsHash()); + } + } + } + + // Prepare package for send + + CbPackage SendPackage; + SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash()); + + for (const IoHash& AttachmentCid : AttachmentsToSend) + { + const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid); + + if (Attachment) + { + SendPackage.AddAttachment(*Attachment); + } + else + { + // This should be an error -- server asked to have something we can't find + } + } + + // Transmit package payload + + CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage); + SharedBuffer FlatMessage = Message.Flatten(); + + Sess.SetHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}}); + Sess.SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()}); + + cpr::Response FilterResponse = Sess.Post(); + + if (!IsHttpSuccessCode(FilterResponse.status_code)) + { + return FromCprResponse(FilterResponse); + } + + IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size()); + + if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end()) + { + HttpContentType ContentType = ParseContentType(It->second); + + ResponseBuffer.SetContentType(ContentType); + } + + return {.StatusCode = int(FilterResponse.status_code), .ResponsePayload = ResponseBuffer}; +} + +HttpClient::Response +HttpClient::Put(std::string_view Url, IoBuffer Payload) +{ + ZEN_UNUSED(Url); + ZEN_UNUSED(Payload); + return {}; +} + +HttpClient::Response +HttpClient::Get(std::string_view Url) +{ + ZEN_UNUSED(Url); + return {}; +} + +HttpClient::Response +HttpClient::Delete(std::string_view Url) +{ + ZEN_UNUSED(Url); + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +TEST_CASE("httpclient") +{ + using namespace std::literals; + + SUBCASE("client") {} +} + +void +httpclient_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/httpnull.cpp b/src/zenhttp/httpnull.cpp new file mode 100644 index 000000000..a6e1d3567 --- /dev/null +++ b/src/zenhttp/httpnull.cpp @@ -0,0 +1,83 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpnull.h" + +#include <zencore/logging.h> + +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +#endif + +namespace zen { + +HttpNullServer::HttpNullServer() +{ +} + +HttpNullServer::~HttpNullServer() +{ +} + +void +HttpNullServer::RegisterService(HttpService& Service) +{ + ZEN_UNUSED(Service); +} + +int +HttpNullServer::Initialize(int BasePort) +{ + return BasePort; +} + +void +HttpNullServer::Run(bool IsInteractiveSession) +{ + const bool TestMode = !IsInteractiveSession; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +#if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +#endif +} + +void +HttpNullServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +} // namespace zen diff --git a/src/zenhttp/httpnull.h b/src/zenhttp/httpnull.h new file mode 100644 index 000000000..74f021f6b --- /dev/null +++ b/src/zenhttp/httpnull.h @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +namespace zen { + +/** + * @brief Null implementation of "http" server. Does nothing + */ + +class HttpNullServer : public HttpServer +{ +public: + HttpNullServer(); + ~HttpNullServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + +private: + Event m_ShutdownEvent; +}; + +} // namespace zen diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp new file mode 100644 index 000000000..671cbd319 --- /dev/null +++ b/src/zenhttp/httpserver.cpp @@ -0,0 +1,885 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpserver.h> + +#include "httpasio.h" +#include "httpnull.h" +#include "httpsys.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/thread.h> +#include <zenhttp/httpshared.h> + +#include <charconv> +#include <mutex> +#include <span> +#include <string_view> + +namespace zen { + +using namespace std::literals; + +std::string_view +MapContentTypeToString(HttpContentType ContentType) +{ + switch (ContentType) + { + default: + case HttpContentType::kUnknownContentType: + case HttpContentType::kBinary: + return "application/octet-stream"sv; + + case HttpContentType::kText: + return "text/plain"sv; + + case HttpContentType::kJSON: + return "application/json"sv; + + case HttpContentType::kCbObject: + return "application/x-ue-cb"sv; + + case HttpContentType::kCbPackage: + return "application/x-ue-cbpkg"sv; + + case HttpContentType::kCbPackageOffer: + return "application/x-ue-offer"sv; + + case HttpContentType::kCompressedBinary: + return "application/x-ue-comp"sv; + + case HttpContentType::kYAML: + return "text/yaml"sv; + + case HttpContentType::kHTML: + return "text/html"sv; + + case HttpContentType::kJavaScript: + return "application/javascript"sv; + + case HttpContentType::kCSS: + return "text/css"sv; + + case HttpContentType::kPNG: + return "image/png"sv; + + case HttpContentType::kIcon: + return "image/x-icon"sv; + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Note that in addition to MIME types we accept abbreviated versions, for +// use in suffix parsing as well as for convenience when using curl + +static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv); +static constinit uint32_t HashJson = HashStringDjb2("json"sv); +static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv); +static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv); +static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv); +static constinit uint32_t HashText = HashStringDjb2("text/plain"sv); +static constinit uint32_t HashApplicationCompactBinary = HashStringDjb2("application/x-ue-cb"sv); +static constinit uint32_t HashCompactBinary = HashStringDjb2("ucb"sv); +static constinit uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv); +static constinit uint32_t HashCompactBinaryPackageShort = HashStringDjb2("cbpkg"sv); +static constinit uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv); +static constinit uint32_t HashCompressedBinary = HashStringDjb2("application/x-ue-comp"sv); +static constinit uint32_t HashHtml = HashStringDjb2("html"sv); +static constinit uint32_t HashTextHtml = HashStringDjb2("text/html"sv); +static constinit uint32_t HashJavaScript = HashStringDjb2("js"sv); +static constinit uint32_t HashApplicationJavaScript = HashStringDjb2("application/javascript"sv); +static constinit uint32_t HashCss = HashStringDjb2("css"sv); +static constinit uint32_t HashTextCss = HashStringDjb2("text/css"sv); +static constinit uint32_t HashPng = HashStringDjb2("png"sv); +static constinit uint32_t HashImagePng = HashStringDjb2("image/png"sv); +static constinit uint32_t HashIcon = HashStringDjb2("ico"sv); +static constinit uint32_t HashImageIcon = HashStringDjb2("image/x-icon"sv); + +std::once_flag InitContentTypeLookup; + +struct HashedTypeEntry +{ + uint32_t Hash; + HttpContentType Type; +} TypeHashTable[] = { + // clang-format off + {HashBinary, HttpContentType::kBinary}, + {HashApplicationCompactBinary, HttpContentType::kCbObject}, + {HashCompactBinary, HttpContentType::kCbObject}, + {HashCompactBinaryPackage, HttpContentType::kCbPackage}, + {HashCompactBinaryPackageShort, HttpContentType::kCbPackage}, + {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer}, + {HashJson, HttpContentType::kJSON}, + {HashApplicationJson, HttpContentType::kJSON}, + {HashYaml, HttpContentType::kYAML}, + {HashTextYaml, HttpContentType::kYAML}, + {HashText, HttpContentType::kText}, + {HashCompressedBinary, HttpContentType::kCompressedBinary}, + {HashHtml, HttpContentType::kHTML}, + {HashTextHtml, HttpContentType::kHTML}, + {HashJavaScript, HttpContentType::kJavaScript}, + {HashApplicationJavaScript, HttpContentType::kJavaScript}, + {HashCss, HttpContentType::kCSS}, + {HashTextCss, HttpContentType::kCSS}, + {HashPng, HttpContentType::kPNG}, + {HashImagePng, HttpContentType::kPNG}, + {HashIcon, HttpContentType::kIcon}, + {HashImageIcon, HttpContentType::kIcon}, + // clang-format on +}; + +HttpContentType +ParseContentTypeImpl(const std::string_view& ContentTypeString) +{ + if (!ContentTypeString.empty()) + { + const uint32_t CtHash = HashStringDjb2(ContentTypeString); + + if (auto It = std::lower_bound(std::begin(TypeHashTable), + std::end(TypeHashTable), + CtHash, + [](const HashedTypeEntry& Lhs, const uint32_t Rhs) { return Lhs.Hash < Rhs; }); + It != std::end(TypeHashTable)) + { + if (It->Hash == CtHash) + { + return It->Type; + } + } + } + + return HttpContentType::kUnknownContentType; +} + +HttpContentType +ParseContentTypeInit(const std::string_view& ContentTypeString) +{ + std::call_once(InitContentTypeLookup, [] { + std::sort(std::begin(TypeHashTable), std::end(TypeHashTable), [](const HashedTypeEntry& Lhs, const HashedTypeEntry& Rhs) { + return Lhs.Hash < Rhs.Hash; + }); + + // validate that there are no hash collisions + + uint32_t LastHash = 0; + + for (const auto& Item : TypeHashTable) + { + ZEN_ASSERT(LastHash != Item.Hash); + LastHash = Item.Hash; + } + }); + + ParseContentType = ParseContentTypeImpl; + + return ParseContentTypeImpl(ContentTypeString); +} + +HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString) = &ParseContentTypeInit; + +bool +TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges) +{ + if (RangeHeader.empty()) + { + return false; + } + + const size_t Count = Ranges.size(); + + std::size_t UnitDelim = RangeHeader.find_first_of('='); + if (UnitDelim == std::string_view::npos) + { + return false; + } + + // only bytes for now + std::string_view Unit = RangeHeader.substr(0, UnitDelim); + if (Unit != "bytes"sv) + { + return false; + } + + std::string_view Tokens = RangeHeader.substr(UnitDelim); + while (!Tokens.empty()) + { + // Skip =, + Tokens = Tokens.substr(1); + + size_t Delim = Tokens.find_first_of(','); + if (Delim == std::string_view::npos) + { + Delim = Tokens.length(); + } + + std::string_view Token = Tokens.substr(0, Delim); + Tokens = Tokens.substr(Delim); + + Delim = Token.find_first_of('-'); + if (Delim == std::string_view::npos) + { + return false; + } + + const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim)); + const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1)); + + if (Start.has_value() && End.has_value() && End.value() > Start.value()) + { + Ranges.push_back({.Start = Start.value(), .End = End.value()}); + } + else if (Start) + { + Ranges.push_back({.Start = Start.value()}); + } + else if (End) + { + Ranges.push_back({.End = End.value()}); + } + } + + return Count != Ranges.size(); +} + +////////////////////////////////////////////////////////////////////////// + +const std::string_view +ToString(HttpVerb Verb) +{ + switch (Verb) + { + case HttpVerb::kGet: + return "GET"sv; + case HttpVerb::kPut: + return "PUT"sv; + case HttpVerb::kPost: + return "POST"sv; + case HttpVerb::kDelete: + return "DELETE"sv; + case HttpVerb::kHead: + return "HEAD"sv; + case HttpVerb::kCopy: + return "COPY"sv; + case HttpVerb::kOptions: + return "OPTIONS"sv; + default: + return "???"sv; + } +} + +std::string_view +ReasonStringForHttpResultCode(int HttpCode) +{ + switch (HttpCode) + { + // 1xx Informational + + case 100: + return "Continue"sv; + case 101: + return "Switching Protocols"sv; + + // 2xx Success + + case 200: + return "OK"sv; + case 201: + return "Created"sv; + case 202: + return "Accepted"sv; + case 204: + return "No Content"sv; + case 205: + return "Reset Content"sv; + case 206: + return "Partial Content"sv; + + // 3xx Redirection + + case 300: + return "Multiple Choices"sv; + case 301: + return "Moved Permanently"sv; + case 302: + return "Found"sv; + case 303: + return "See Other"sv; + case 304: + return "Not Modified"sv; + case 305: + return "Use Proxy"sv; + case 306: + return "Switch Proxy"sv; + case 307: + return "Temporary Redirect"sv; + case 308: + return "Permanent Redirect"sv; + + // 4xx Client errors + + case 400: + return "Bad Request"sv; + case 401: + return "Unauthorized"sv; + case 402: + return "Payment Required"sv; + case 403: + return "Forbidden"sv; + case 404: + return "Not Found"sv; + case 405: + return "Method Not Allowed"sv; + case 406: + return "Not Acceptable"sv; + case 407: + return "Proxy Authentication Required"sv; + case 408: + return "Request Timeout"sv; + case 409: + return "Conflict"sv; + case 410: + return "Gone"sv; + case 411: + return "Length Required"sv; + case 412: + return "Precondition Failed"sv; + case 413: + return "Payload Too Large"sv; + case 414: + return "URI Too Long"sv; + case 415: + return "Unsupported Media Type"sv; + case 416: + return "Range Not Satisifiable"sv; + case 417: + return "Expectation Failed"sv; + case 418: + return "I'm a teapot"sv; + case 421: + return "Misdirected Request"sv; + case 422: + return "Unprocessable Entity"sv; + case 423: + return "Locked"sv; + case 424: + return "Failed Dependency"sv; + case 425: + return "Too Early"sv; + case 426: + return "Upgrade Required"sv; + case 428: + return "Precondition Required"sv; + case 429: + return "Too Many Requests"sv; + case 431: + return "Request Header Fields Too Large"sv; + + // 5xx Server errors + + case 500: + return "Internal Server Error"sv; + case 501: + return "Not Implemented"sv; + case 502: + return "Bad Gateway"sv; + case 503: + return "Service Unavailable"sv; + case 504: + return "Gateway Timeout"sv; + case 505: + return "HTTP Version Not Supported"sv; + case 506: + return "Variant Also Negotiates"sv; + case 507: + return "Insufficient Storage"sv; + case 508: + return "Loop Detected"sv; + case 510: + return "Not Extended"sv; + case 511: + return "Network Authentication Required"sv; + + default: + return "Unknown Result"sv; + } +} + +////////////////////////////////////////////////////////////////////////// + +Ref<IHttpPackageHandler> +HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) +{ + ZEN_UNUSED(HttpServiceRequest); + + return Ref<IHttpPackageHandler>(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpServerRequest::HttpServerRequest() +{ +} + +HttpServerRequest::~HttpServerRequest() +{ +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data) +{ + std::vector<IoBuffer> ResponseBuffers = FormatPackageMessage(Data); + return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data) +{ + if (m_AcceptType == HttpContentType::kJSON) + { + ExtendableStringBuilder<1024> Sb; + WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView()); + } + else + { + SharedBuffer Buf = Data.GetBuffer(); + std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; + return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); + } +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array) +{ + if (m_AcceptType == HttpContentType::kJSON) + { + ExtendableStringBuilder<1024> Sb; + WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView()); + } + else + { + SharedBuffer Buf = Array.GetBuffer(); + std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())}; + return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers); + } +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString) +{ + return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()}); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob) +{ + std::array<IoBuffer, 1> Buffers{Blob}; + return WriteResponse(ResponseCode, ContentType, Buffers); +} + +void +HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload) +{ + std::span<const SharedBuffer> Segments = Payload.GetSegments(); + + std::vector<IoBuffer> Buffers; + + for (auto& Segment : Segments) + { + Buffers.push_back(Segment.AsIoBuffer()); + } + + WriteResponse(ResponseCode, ContentType, Buffers); +} + +HttpServerRequest::QueryParams +HttpServerRequest::GetQueryParams() +{ + QueryParams Params; + + const std::string_view QStr = QueryString(); + + const char* QueryIt = QStr.data(); + const char* QueryEnd = QueryIt + QStr.size(); + + while (QueryIt != QueryEnd) + { + if (*QueryIt == '&') + { + ++QueryIt; + continue; + } + + size_t QueryLen = ptrdiff_t(QueryEnd - QueryIt); + const std::string_view Query{QueryIt, QueryLen}; + + size_t DelimIndex = Query.find('&', 0); + + if (DelimIndex == std::string_view::npos) + { + DelimIndex = Query.size(); + } + + std::string_view ThisQuery{QueryIt, DelimIndex}; + + size_t EqIndex = ThisQuery.find('=', 0); + + if (EqIndex != std::string_view::npos) + { + std::string_view Param{ThisQuery.data(), EqIndex}; + ThisQuery.remove_prefix(EqIndex + 1); + + Params.KvPairs.emplace_back(Param, ThisQuery); + } + + QueryIt += DelimIndex; + } + + return Params; +} + +Oid +HttpServerRequest::SessionId() const +{ + if (m_Flags & kHaveSessionId) + { + return m_SessionId; + } + + m_SessionId = ParseSessionId(); + m_Flags |= kHaveSessionId; + return m_SessionId; +} + +uint32_t +HttpServerRequest::RequestId() const +{ + if (m_Flags & kHaveRequestId) + { + return m_RequestId; + } + + m_RequestId = ParseRequestId(); + m_Flags |= kHaveRequestId; + return m_RequestId; +} + +CbObject +HttpServerRequest::ReadPayloadObject() +{ + if (IoBuffer Payload = ReadPayload()) + { + return LoadCompactBinaryObject(std::move(Payload)); + } + + return {}; +} + +CbPackage +HttpServerRequest::ReadPayloadPackage() +{ + if (IoBuffer Payload = ReadPayload()) + { + return ParsePackageMessage(std::move(Payload)); + } + + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpRequestRouter::AddPattern(const char* Id, const char* Regex) +{ + ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); + + m_PatternMap.insert({Id, Regex}); +} + +void +HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) +{ + ExtendableStringBuilder<128> ExpandedRegex; + ProcessRegexSubstitutions(Regex, ExpandedRegex); + + m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); +} + +void +HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex) +{ + size_t RegexLen = strlen(Regex); + + for (size_t i = 0; i < RegexLen;) + { + bool matched = false; + + if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\'))) + { + // Might have a pattern reference - find closing brace + + for (size_t j = i + 1; j < RegexLen; ++j) + { + if (Regex[j] == '}') + { + std::string Pattern(&Regex[i + 1], j - i - 1); + + if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) + { + OutExpandedRegex.Append(it->second.c_str()); + } + else + { + // Default to anything goes (or should this just be an error?) + + OutExpandedRegex.Append("(.+?)"); + } + + // skip ahead + i = j + 1; + + matched = true; + + break; + } + } + } + + if (!matched) + { + OutExpandedRegex.Append(Regex[i++]); + } + } +} + +bool +HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) +{ + const HttpVerb Verb = Request.RequestVerb(); + + std::string_view Uri = Request.RelativeUri(); + HttpRouterRequest RouterRequest(Request); + + for (const auto& Handler : m_Handlers) + { + if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) + { + Handler.Handler(RouterRequest); + + return true; // Route matched + } + } + + return false; // No route matched +} + +////////////////////////////////////////////////////////////////////////// + +HttpRpcHandler::HttpRpcHandler() +{ +} + +HttpRpcHandler::~HttpRpcHandler() +{ +} + +void +HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction) +{ + ZEN_UNUSED(RpcId, HandlerFunction); +} + +////////////////////////////////////////////////////////////////////////// + +enum class HttpServerClass +{ + kHttpAsio, + kHttpSys, + kHttpNull +}; + +// Implemented in httpsys.cpp +Ref<HttpServer> CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads); + +Ref<HttpServer> +CreateHttpServer(std::string_view ServerClass) +{ + using namespace std::literals; + + HttpServerClass Class = HttpServerClass::kHttpNull; + +#if ZEN_WITH_HTTPSYS + Class = HttpServerClass::kHttpSys; +#elif 1 + Class = HttpServerClass::kHttpAsio; +#endif + + if (ServerClass == "asio"sv) + { + Class = HttpServerClass::kHttpAsio; + } + else if (ServerClass == "httpsys"sv) + { + Class = HttpServerClass::kHttpSys; + } + else if (ServerClass == "null"sv) + { + Class = HttpServerClass::kHttpNull; + } + + switch (Class) + { + default: + case HttpServerClass::kHttpAsio: + ZEN_INFO("using asio HTTP server implementation"); + return Ref<HttpServer>(new HttpAsioServer()); + +#if ZEN_WITH_HTTPSYS + case HttpServerClass::kHttpSys: + ZEN_INFO("using http.sys server implementation"); + return Ref<HttpServer>(new HttpSysServer(std::thread::hardware_concurrency(), /* background worker threads */ 16)); +#endif + + case HttpServerClass::kHttpNull: + ZEN_INFO("using null HTTP server implementation"); + return Ref<HttpServer>(new HttpNullServer); + } +} + +////////////////////////////////////////////////////////////////////////// + +bool +HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef) +{ + if (Request.RequestVerb() == HttpVerb::kPost) + { + if (Request.RequestContentType() == HttpContentType::kCbPackageOffer) + { + // The client is presenting us with a package attachments offer, we need + // to filter it down to the list of attachments we need them to send in + // the follow-up request + + PackageHandlerRef = Service.HandlePackageRequest(Request); + + if (PackageHandlerRef) + { + CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload()); + + std::vector<IoHash> OfferCids; + + for (auto& CidEntry : OfferMessage["offer"]) + { + if (!CidEntry.IsHash()) + { + // Should yield bad request response? + + ZEN_WARN("found invalid entry in offer"); + + continue; + } + + OfferCids.push_back(CidEntry.AsHash()); + } + + ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size()); + + PackageHandlerRef->FilterOffer(OfferCids); + + ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size()); + + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + + for (const IoHash& Cid : OfferCids) + { + ResponseWriter.AddHash(Cid); + } + + ResponseWriter.EndArray(); + + // Emit filter response + Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + return true; + } + } + else if (Request.RequestContentType() == HttpContentType::kCbPackage) + { + // Process chunks in package request + + PackageHandlerRef = Service.HandlePackageRequest(Request); + + // TODO: this should really be done in a streaming fashion, currently this emulates + // the intended flow from an API perspective + + if (PackageHandlerRef) + { + PackageHandlerRef->OnRequestBegin(); + + auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer { + return PackageHandlerRef->CreateTarget(Cid, Size); + }; + + CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer); + + PackageHandlerRef->OnRequestComplete(); + } + } + } + return false; +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +TEST_CASE("http.common") +{ + using namespace std::literals; + + SUBCASE("router") + { + HttpRequestRouter r; + r.AddPattern("a", "[[:alpha:]]+"); + r.RegisterRoute( + "{a}", + [&](auto) {}, + HttpVerb::kGet); + + // struct TestHttpServerRequest : public HttpServerRequest + //{ + // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {} + //}; + + // TestHttpServerRequest req{}; + // r.HandleRequest(req); + } + + SUBCASE("content-type") + { + for (uint8_t i = 0; i < uint8_t(HttpContentType::kCOUNT); ++i) + { + HttpContentType Ct{i}; + + if (Ct != HttpContentType::kUnknownContentType) + { + CHECK_EQ(Ct, ParseContentType(MapContentTypeToString(Ct))); + } + } + } +} + +void +http_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/httpshared.cpp b/src/zenhttp/httpshared.cpp new file mode 100644 index 000000000..7aade56d2 --- /dev/null +++ b/src/zenhttp/httpshared.cpp @@ -0,0 +1,809 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpshared.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compositebuffer.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> + +#include <span> +#include <vector> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +const std::string_view HandlePrefix(":?#:"); + +std::vector<IoBuffer> +FormatPackageMessage(const CbPackage& Data, int TargetProcessPid) +{ + return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessPid); +} +CompositeBuffer +FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid) +{ + return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessPid); +} + +CompositeBuffer +FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid) +{ + std::vector<IoBuffer> Message = FormatPackageMessage(Data, Flags, TargetProcessPid); + + std::vector<SharedBuffer> Buffers; + + for (IoBuffer& Buf : Message) + { + Buffers.push_back(SharedBuffer(Buf)); + } + + return CompositeBuffer(std::move(Buffers)); +} + +std::vector<IoBuffer> +FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid) +{ + void* TargetProcessHandle = nullptr; +#if ZEN_PLATFORM_WINDOWS + std::vector<HANDLE> DuplicatedHandles; + auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() { + if (TargetProcessHandle == nullptr) + { + return; + } + + for (HANDLE DuplicatedHandle : DuplicatedHandles) + { + HANDLE ClosingHandle; + if (::DuplicateHandle((HANDLE)TargetProcessHandle, + DuplicatedHandle, + GetCurrentProcess(), + &ClosingHandle, + 0, + FALSE, + DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE) + { + ::CloseHandle(ClosingHandle); + } + } + ::CloseHandle((HANDLE)TargetProcessHandle); + TargetProcessHandle = nullptr; + }); + + if (EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && TargetProcessPid != 0) + { + TargetProcessHandle = OpenProcess(PROCESS_DUP_HANDLE, FALSE, TargetProcessPid); + } +#else + ZEN_UNUSED(TargetProcessPid); + void* DuplicatedHandles = nullptr; +#endif // ZEN_PLATFORM_WINDOWS + + const std::span<const CbAttachment>& Attachments = Data.GetAttachments(); + std::vector<IoBuffer> ResponseBuffers; + + ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each + // attachment is likely to consist of several buffers + + // Fixed size header + + CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())}; + + ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr)); + + // Attachment metadata array + + IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)}; + CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData()); + + ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata + + // Root object + + IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer(); + ResponseBuffers.push_back(RootIoBuffer); // Root object + + *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()}; + + // Attachment payloads + + auto MarshalLocal = [&AttachmentInfo, &ResponseBuffers](const std::string& Path8, + CbAttachmentReferenceHeader& LocalRef, + const IoHash& AttachmentHash, + bool IsCompressed) { + IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size()); + + CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>(); + *RefHdr++ = LocalRef; + memcpy(RefHdr, Path8.data(), Path8.size()); + + *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(), + .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef, + .AttachmentHash = AttachmentHash}; + + ResponseBuffers.push_back(std::move(RefBuffer)); + }; + + tsl::robin_map<void*, std::string> FileNameMap; + + auto IsLocalRef = [&FileNameMap, &DuplicatedHandles](const CompositeBuffer& AttachmentBinary, + bool DenyPartialLocalReferences, + void* TargetProcessHandle, + CbAttachmentReferenceHeader& LocalRef, + std::string& Path8) -> bool { + const SharedBuffer& Segment = AttachmentBinary.GetSegments().front(); + IoBufferFileReference Ref; + const IoBuffer& SegmentBuffer = Segment.AsIoBuffer(); + + if (!SegmentBuffer.GetFileReference(Ref)) + { + return false; + } + + if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile()) + { + return false; + } + + if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end()) + { + Path8 = It->second; + } + else + { + bool UseFilePath = true; +#if ZEN_PLATFORM_WINDOWS + if (TargetProcessHandle != nullptr) + { + HANDLE TargetHandle = INVALID_HANDLE_VALUE; + BOOL OK = ::DuplicateHandle(GetCurrentProcess(), + Ref.FileHandle, + (HANDLE)TargetProcessHandle, + &TargetHandle, + FILE_GENERIC_READ, + FALSE, + 0); + if (OK) + { + DuplicatedHandles.push_back(TargetHandle); + Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle)); + UseFilePath = false; + } + } +#else // ZEN_PLATFORM_WINDOWS + ZEN_UNUSED(TargetProcessHandle); + // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to + // deal with acceess rights etc. +#endif // ZEN_PLATFORM_WINDOWS + if (UseFilePath) + { + ExtendablePathBuilder<256> LocalRefFile; + LocalRefFile.Append(std::filesystem::absolute(PathFromHandle(Ref.FileHandle))); + Path8 = LocalRefFile.ToUtf8(); + } + FileNameMap.insert_or_assign(Ref.FileHandle, Path8); + } + + LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size()); + LocalRef.PayloadByteOffset = Ref.FileChunkOffset; + LocalRef.PayloadByteSize = Ref.FileChunkSize; + + return true; + }; + + for (const CbAttachment& Attachment : Attachments) + { + if (Attachment.IsNull()) + { + ZEN_NOT_IMPLEMENTED("Null attachments are not supported"); + } + else if (CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary()) + { + CompositeBuffer Compressed = AttachmentBuffer.GetCompressed(); + IoHash AttachmentHash = Attachment.GetHash(); + + // If the data is either not backed by a file, or there are multiple + // fragments then we cannot marshal it by local reference. We might + // want/need to extend this in the future to allow multiple chunk + // segments to be marshaled at once + + bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1); + bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); + CbAttachmentReferenceHeader LocalRef; + std::string Path8; + + if (MarshalByLocalRef) + { + MarshalByLocalRef = IsLocalRef(Compressed, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8); + } + + if (MarshalByLocalRef) + { + const bool IsCompressed = true; + bool IsHandle = false; +#if ZEN_PLATFORM_WINDOWS + IsHandle = Path8.starts_with(HandlePrefix); +#endif + MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed); + ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize()); + } + else + { + *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(), + .Flags = CbAttachmentEntry::kIsCompressed, + .AttachmentHash = AttachmentHash}; + + for (const SharedBuffer& Segment : Compressed.GetSegments()) + { + ResponseBuffers.push_back(Segment.AsIoBuffer()); + } + } + } + else if (CbObject AttachmentObject = Attachment.AsObject()) + { + IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer(); + ResponseBuffers.push_back(ObjIoBuffer); + + *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(), + .Flags = CbAttachmentEntry::kIsObject, + .AttachmentHash = Attachment.GetHash()}; + } + else if (CompositeBuffer AttachmentBinary = Attachment.AsCompositeBinary()) + { + IoHash AttachmentHash = Attachment.GetHash(); + bool MarshalByLocalRef = + EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1); + bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences); + + CbAttachmentReferenceHeader LocalRef; + std::string Path8; + + if (MarshalByLocalRef) + { + MarshalByLocalRef = IsLocalRef(AttachmentBinary, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8); + } + + if (MarshalByLocalRef) + { + const bool IsCompressed = false; + bool IsHandle = false; +#if ZEN_PLATFORM_WINDOWS + IsHandle = Path8.starts_with(HandlePrefix); +#endif + MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed); + ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize()); + } + else + { + *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()}; + + for (const SharedBuffer& Segment : AttachmentBinary.GetSegments()) + { + ResponseBuffers.push_back(Segment.AsIoBuffer()); + } + } + } + else + { + ZEN_NOT_IMPLEMENTED("Unknown attachment kind"); + } + } + FileNameMap.clear(); +#if ZEN_PLATFORM_WINDOWS + DuplicatedHandles.clear(); +#endif // ZEN_PLATFORM_WINDOWS + + return ResponseBuffers; +} + +bool +IsPackageMessage(IoBuffer Payload) +{ + if (!Payload) + { + return false; + } + + BinaryReader Reader(Payload); + + CbPackageHeader Hdr; + Reader.Read(&Hdr, sizeof Hdr); + + if (Hdr.HeaderMagic != kCbPkgMagic) + { + return false; + } + + return true; +} + +CbPackage +ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer) +{ + if (!Payload) + { + return {}; + } + + BinaryReader Reader(Payload); + + CbPackageHeader Hdr; + Reader.Read(&Hdr, sizeof Hdr); + + if (Hdr.HeaderMagic != kCbPkgMagic) + { + throw std::runtime_error("invalid CbPackage header magic"); + } + + const uint32_t ChunkCount = Hdr.AttachmentCount + 1; + + std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]}; + + Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount); + + CbPackage Package; + + std::vector<CbAttachment> Attachments; + Attachments.reserve(ChunkCount); // Guessing here... + + tsl::robin_map<std::string, IoBuffer> PartialFileBuffers; + + // TODO: Throwing before this loop completes could result in leaking handles as we might not have picked up all the handles in the + // message + for (uint32_t i = 0; i < ChunkCount; ++i) + { + const CbAttachmentEntry& Entry = AttachmentEntries[i]; + const uint64_t AttachmentSize = Entry.PayloadSize; + + const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize); + Reader.Skip(AttachmentSize); + + if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + // Marshal local reference - a "pointer" to the chunk backing file + + ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1); + + ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength); + + IoBuffer FullFileBuffer; + + std::filesystem::path Path(Utf8ToWide(PathView)); + if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end()) + { + FullFileBuffer = It->second; + } + else + { + if (PathView.starts_with(HandlePrefix)) + { +#if ZEN_PLATFORM_WINDOWS + std::string_view HandleString(PathView.substr(HandlePrefix.length())); + std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString)); + if (HandleNumber.has_value()) + { + HANDLE FileHandle = HANDLE(HandleNumber.value()); + ULARGE_INTEGER liFileSize; + liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart); + if (liFileSize.LowPart != INVALID_FILE_SIZE) + { + FullFileBuffer = IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart)); + PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer); + } + } +#else // ZEN_PLATFORM_WINDOWS + // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes + // and to deal with acceess rights etc. + ZEN_ASSERT(false); +#endif // ZEN_PLATFORM_WINDOWS + } + else + { + FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second; + } + } + + if (!FullFileBuffer) + { + // Unable to open chunk reference + throw std::runtime_error(fmt::format("unable to resolve chunk #{} at '{}' (offset {}, size {})", + i, + Path, + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize)); + } + + IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize() + ? FullFileBuffer + : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); + + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference))); + if (!CompBuf) + { + throw std::runtime_error(fmt::format("invalid format for chunk #{} at '{}' (offset {}, size {})", + i, + Path, + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize)); + } + Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); + } + else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + if (i == 0) + { + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); + if (!CompBuf) + { + throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for CbObject", i)); + } + // First payload is always a compact binary object + Package.SetObject(LoadCompactBinaryObject(std::move(CompBuf))); + } + else + { + ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported"); + } + } + else + { + CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer))); + if (!CompBuf) + { + throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for attachment", i)); + } + Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash)); + } + } + else /* not compressed */ + { + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + if (i == 0) + { + Package.SetObject(LoadCompactBinaryObject(AttachmentBuffer)); + } + else + { + ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported"); + } + } + else + { + // Make a copy of the buffer so we attachements don't reference the entire payload + IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize); + ZEN_ASSERT(AttachmentBufferCopy); + ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize); + AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView()); + + CbAttachment Attachment(SharedBuffer{AttachmentBufferCopy}); + Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy}); + } + } + } + PartialFileBuffers.clear(); + + Package.AddAttachments(Attachments); + + return Package; +} + +bool +ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage) +{ + if (IsPackageMessage(Response)) + { + OutPackage = ParsePackageMessage(Response); + return true; + } + return OutPackage.TryLoad(Response); +} + +CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; }) +{ +} + +CbPackageReader::~CbPackageReader() +{ +} + +void +CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer) +{ + m_CreateBuffer = CreateBuffer; +} + +uint64_t +CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes) +{ + ZEN_ASSERT(m_CurrentState != State::kReadingBuffers); + + switch (m_CurrentState) + { + case State::kInitialState: + ZEN_ASSERT(Data == nullptr); + m_CurrentState = State::kReadingHeader; + return sizeof m_PackageHeader; + + case State::kReadingHeader: + ZEN_ASSERT(DataBytes == sizeof m_PackageHeader); + memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader); + ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic); + m_CurrentState = State::kReadingAttachmentEntries; + m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1); + return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry); + + case State::kReadingAttachmentEntries: + ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry))); + memcpy(m_AttachmentEntries.data(), Data, DataBytes); + + for (CbAttachmentEntry& Entry : m_AttachmentEntries) + { + // This preallocates memory for payloads but note that for the local references + // the caller will need to handle the payload differently (i.e it's a + // CbAttachmentReferenceHeader not the actual payload) + + m_PayloadBuffers.push_back(IoBuffer{Entry.PayloadSize}); + } + + m_CurrentState = State::kReadingBuffers; + return 0; + + default: + ZEN_ASSERT(false); + return 0; + } +} + +IoBuffer +CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer) +{ + // Marshal local reference - a "pointer" to the chunk backing file + + ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader)); + + const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>(); + const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1); + + ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength)); + + std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength}; + + std::filesystem::path Path{PathView}; + + IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize); + + if (!ChunkReference) + { + // Unable to open chunk reference + + throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})", + PathToUtf8(Path), + AttachRefHdr->PayloadByteOffset, + AttachRefHdr->PayloadByteSize)); + } + + return ChunkReference; +}; + +void +CbPackageReader::Finalize() +{ + if (m_AttachmentEntries.empty()) + { + return; + } + + m_Attachments.reserve(m_AttachmentEntries.size() - 1); + + int CurrentAttachmentIndex = 0; + for (CbAttachmentEntry& Entry : m_AttachmentEntries) + { + IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex]; + + if (CurrentAttachmentIndex == 0) + { + // Root object + if (Entry.Flags & CbAttachmentEntry::kIsObject) + { + if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + m_RootObject = LoadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer)); + } + else if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize); + if (RawHash == Entry.AttachmentHash) + { + m_RootObject = LoadCompactBinaryObject(Compressed); + } + } + else + { + m_RootObject = LoadCompactBinaryObject(std::move(AttachmentBuffer)); + } + } + else + { + throw std::runtime_error("missing or invalid root object"); + } + } + else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef) + { + IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer); + + if (Entry.Flags & CbAttachmentEntry::kIsCompressed) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize); + if (RawHash == Entry.AttachmentHash) + { + m_Attachments.push_back(CbAttachment(Compressed, Entry.AttachmentHash)); + } + } + else + { + CompressedBuffer Compressed = + CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None); + m_Attachments.push_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash())); + } + } + + ++CurrentAttachmentIndex; + } +} + +/** + ______________________ _____________________________ + \__ ___/\_ _____// _____/\__ ___/ _____/ + | | | __)_ \_____ \ | | \_____ \ + | | | \/ \ | | / \ + |____| /_______ /_______ / |____| /_______ / + \/ \/ \/ + */ + +#if ZEN_WITH_TESTS + +TEST_CASE("CbPackage.Serialization") +{ + // Make a test package + + CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))}; + CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +TEST_CASE("CbPackage.LocalRef") +{ + ScopedTemporaryDirectory TempDir; + + auto Path1 = TempDir.Path() / "abcd"; + auto Path2 = TempDir.Path() / "efgh"; + + { + IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd")); + IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh")); + + WriteFile(Path1, Buffer1); + WriteFile(Path2, Buffer2); + } + + // Make a test package + + IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1); + IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2); + + CbAttachment Attach1{SharedBuffer(FileBuffer1)}; + CbAttachment Attach2{SharedBuffer(FileBuffer2)}; + + CbObjectWriter Cbo; + Cbo.AddAttachment("abcd", Attach1); + Cbo.AddAttachment("efgh", Attach2); + + CbPackage Pkg; + Pkg.AddAttachment(Attach1); + Pkg.AddAttachment(Attach2); + Pkg.SetObject(Cbo.Save()); + + SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten(); + const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData()); + uint64_t RemainingBytes = Buffer.GetSize(); + + auto ConsumeBytes = [&](uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + void* ReturnPtr = (void*)CursorPtr; + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + return ReturnPtr; + }; + + auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) { + ZEN_ASSERT(ByteCount <= RemainingBytes); + memcpy(TargetBuffer, CursorPtr, ByteCount); + CursorPtr += ByteCount; + RemainingBytes -= ByteCount; + }; + + CbPackageReader Reader; + uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0); + uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead); + NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes); + auto Buffers = Reader.GetPayloadBuffers(); + + for (auto& PayloadBuffer : Buffers) + { + CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize()); + } + + Reader.Finalize(); +} + +void +forcelink_httpshared() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenhttp/httpsys.cpp b/src/zenhttp/httpsys.cpp new file mode 100644 index 000000000..c733d618d --- /dev/null +++ b/src/zenhttp/httpsys.cpp @@ -0,0 +1,1674 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpsys.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/timer.h> +#include <zenhttp/httpshared.h> + +#if ZEN_WITH_HTTPSYS + +# include <conio.h> +# include <mstcpip.h> +# pragma comment(lib, "httpapi.lib") + +std::wstring +UTF8_to_UTF16(const char* InPtr) +{ + std::wstring OutString; + unsigned int Codepoint; + + while (*InPtr != 0) + { + unsigned char InChar = static_cast<unsigned char>(*InPtr); + + if (InChar <= 0x7f) + Codepoint = InChar; + else if (InChar <= 0xbf) + Codepoint = (Codepoint << 6) | (InChar & 0x3f); + else if (InChar <= 0xdf) + Codepoint = InChar & 0x1f; + else if (InChar <= 0xef) + Codepoint = InChar & 0x0f; + else + Codepoint = InChar & 0x07; + + ++InPtr; + + if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff)) + { + if (Codepoint > 0xffff) + { + OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10))); + OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff))); + } + else if (Codepoint < 0xd800 || Codepoint >= 0xe000) + { + OutString.append(1, static_cast<wchar_t>(Codepoint)); + } + } + } + + return OutString; +} + +namespace zen { + +using namespace std::literals; + +class HttpSysServer; +class HttpSysTransaction; +class HttpMessageResponseRequest; + +////////////////////////////////////////////////////////////////////////// + +HttpVerb +TranslateHttpVerb(HTTP_VERB ReqVerb) +{ + switch (ReqVerb) + { + case HttpVerbOPTIONS: + return HttpVerb::kOptions; + + case HttpVerbGET: + return HttpVerb::kGet; + + case HttpVerbHEAD: + return HttpVerb::kHead; + + case HttpVerbPOST: + return HttpVerb::kPost; + + case HttpVerbPUT: + return HttpVerb::kPut; + + case HttpVerbDELETE: + return HttpVerb::kDelete; + + case HttpVerbCOPY: + return HttpVerb::kCopy; + + default: + // TODO: invalid request? + return (HttpVerb)0; + } +} + +uint64_t +GetContentLength(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength]; + std::string_view cl(clh.pRawValue, clh.RawValueLength); + uint64_t ContentLength = 0; + std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength); + return ContentLength; +}; + +HttpContentType +GetContentType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +HttpContentType +GetAcceptType(const HTTP_REQUEST* HttpRequest) +{ + const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept]; + return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength}); +}; + +/** + * @brief Base class for any pending or active HTTP transactions + */ +class HttpSysRequestHandler +{ +public: + explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {} + virtual ~HttpSysRequestHandler() = default; + + virtual void IssueRequest(std::error_code& ErrorCode) = 0; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0; + HttpSysTransaction& Transaction() { return m_Transaction; } + + HttpSysRequestHandler(const HttpSysRequestHandler&) = delete; + HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete; + +private: + HttpSysTransaction& m_Transaction; +}; + +/** + * This is the handler for the initial HTTP I/O request which will receive the headers + * and however much of the remaining payload might fit in the embedded request buffer. + * + * It is also used to receive any entity body data relating to the request + * + */ +struct InitialRequestHandler : public HttpSysRequestHandler +{ + inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; } + inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; } + inline bool IsInitialRequest() const { return m_IsInitialRequest; } + + InitialRequestHandler(HttpSysTransaction& InRequest); + ~InitialRequestHandler(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + + bool m_IsInitialRequest = true; + uint64_t m_CurrentPayloadOffset = 0; + uint64_t m_ContentLength = ~uint64_t(0); + IoBuffer m_PayloadBuffer; + UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)]; +}; + +/** + * This is the class which request handlers use to interact with the server instance + */ + +class HttpSysServerRequest : public HttpServerRequest +{ +public: + HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpSysServerRequest() = default; + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpSysServerRequest(const HttpSysServerRequest&) = delete; + HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete; + + HttpSysTransaction& m_HttpTx; + HttpSysRequestHandler* m_NextCompletionHandler = nullptr; + IoBuffer m_PayloadBuffer; + ExtendableStringBuilder<128> m_UriUtf8; + ExtendableStringBuilder<128> m_QueryStringUtf8; +}; + +/** HTTP transaction + + There will be an instance of this per pending and in-flight HTTP transaction + + */ +class HttpSysTransaction final +{ +public: + HttpSysTransaction(HttpSysServer& Server); + virtual ~HttpSysTransaction(); + + enum class Status + { + kDone, + kRequestPending + }; + + Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred); + + static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io); + + void IssueInitialRequest(std::error_code& ErrorCode); + bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler); + + PTP_IO Iocp(); + HANDLE RequestQueueHandle(); + inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; } + inline HttpSysServer& Server() { return m_HttpServer; } + inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); } + + HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload); + + HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); } + +private: + OVERLAPPED m_HttpOverlapped{}; + HttpSysServer& m_HttpServer; + + // Tracks which handler is due to handle the next I/O completion event + HttpSysRequestHandler* m_CompletionHandler = nullptr; + RwLock m_CompletionMutex; + InitialRequestHandler m_InitialHttpHandler{*this}; + std::optional<HttpSysServerRequest> m_HandlerRequest; + Ref<IHttpPackageHandler> m_PackageHandler; +}; + +/** + * @brief HTTP request response I/O request handler + * + * Asynchronously streams out a response to an HTTP request via compound + * responses from memory or directly from file + */ + +class HttpMessageResponseRequest : public HttpSysRequestHandler +{ +public: + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize); + HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> Blobs); + ~HttpMessageResponseRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + void SuppressResponseBody(); // typically used for HEAD requests + +private: + std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks; + uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes + uint16_t m_ResponseCode = 0; + uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists + uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends + bool m_IsInitialResponse = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + std::vector<IoBuffer> m_DataBuffers; + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs); +}; + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode) +: HttpSysRequestHandler(InRequest) +{ + std::array<IoBuffer, 0> EmptyBufferList; + + InitializeForPayload(ResponseCode, EmptyBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message) +: HttpSysRequestHandler(InRequest) +, m_ContentType(HttpContentType::kText) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + const void* Payload, + size_t PayloadSize) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + InitializeForPayload(ResponseCode, SingleBufferList); +} + +HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, + uint16_t ResponseCode, + HttpContentType ContentType, + std::span<IoBuffer> BlobList) +: HttpSysRequestHandler(InRequest) +, m_ContentType(ContentType) +{ + InitializeForPayload(ResponseCode, BlobList); +} + +HttpMessageResponseRequest::~HttpMessageResponseRequest() +{ +} + +void +HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_HttpDataChunks.reserve(ChunkCount); + m_DataBuffers.reserve(ChunkCount); + + for (IoBuffer& Buffer : BlobList) + { + m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + + // Initialize the full array up front + + uint64_t LocalDataSize = 0; + + for (IoBuffer& Buffer : m_DataBuffers) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + LocalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // Use direct file transfer + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromFileHandle; + Chunk.FromFileHandle.FileHandle = FileRef.FileHandle; + Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset; + Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize; + } + else + { + // Send from memory, need to make sure we chunk the buffer up since + // the underlying data structure only accepts 32-bit chunk sizes for + // memory chunks. When this happens the vector will be reallocated, + // which is fine since this will be a pretty rare case and sending + // the data is going to take a lot longer than a memory allocation :) + + const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data()); + + while (BufferDataSize) + { + const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize)); + + m_HttpDataChunks.push_back({}); + auto& Chunk = m_HttpDataChunks.back(); + + Chunk.DataChunkType = HttpDataChunkFromMemory; + Chunk.FromMemory.pBuffer = (void*)WriteCursor; + Chunk.FromMemory.BufferLength = ThisChunkSize; + + BufferDataSize -= ThisChunkSize; + WriteCursor += ThisChunkSize; + } + } + } + + m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size()); + m_TotalDataSize = LocalDataSize; + + if (m_TotalDataSize == 0 && ResponseCode == 200) + { + // Some HTTP clients really don't like empty responses unless a 204 response is sent + m_ResponseCode = uint16_t(HttpResponseCode::NoContent); + } + else + { + m_ResponseCode = ResponseCode; + } +} + +void +HttpMessageResponseRequest::SuppressResponseBody() +{ + m_RemainingChunkCount = 0; + m_HttpDataChunks.clear(); + m_DataBuffers.clear(); +} + +HttpSysRequestHandler* +HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + ZEN_UNUSED(NumberOfBytesTransferred); + + if (IoResult != NO_ERROR) + { + ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult)); + + // if one transmit failed there's really no need to go on + return nullptr; + } + + if (m_RemainingChunkCount == 0) + { + return nullptr; // All done + } + + return this; +} + +void +HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode) +{ + HttpSysTransaction& Tx = Transaction(); + HTTP_REQUEST* const HttpReq = Tx.HttpRequest(); + PTP_IO const Iocp = Tx.Iocp(); + + StartThreadpoolIo(Iocp); + + // Split payload into batches to play well with the underlying API + + const int MaxChunksPerCall = 9999; + + const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall); + const int ThisRequestChunkOffset = m_NextDataChunkOffset; + + m_RemainingChunkCount -= ThisRequestChunkCount; + m_NextDataChunkOffset += ThisRequestChunkCount; + + /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA? + + From the docs: + + This flag enables buffering of data in the kernel on a per-response basis. It should + be used by an application doing synchronous I/O, or by a an application doing + asynchronous I/O with no more than one send outstanding at a time. + + Applications using asynchronous I/O which may have more than one send outstanding at + a time should not use this flag. + + When this flag is set, it should be used consistently in calls to the + HttpSendHttpResponse function as well. + */ + + ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; + + if (m_RemainingChunkCount) + { + // We need to make more calls to send the full amount of data + SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + } + + ULONG SendResult = 0; + + if (m_IsInitialResponse) + { + // Populate response structure + + HTTP_RESPONSE HttpResponse = {}; + + HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount); + HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset; + + // Server header + // + // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header + // + // This is controlled via a registry key 'DisableServerHeader', at: + // + // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters + // + // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether + // (only the latter appears to do anything in my testing, on Windows 10). + // + // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header) + // + + PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer]; + ServerHeader->pRawValue = "Zen"; + ServerHeader->RawValueLength = (USHORT)3; + + // Content-length header + + char ContentLengthString[32]; + _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10); + + PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength]; + ContentLengthHeader->pRawValue = ContentLengthString; + ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString); + + // Content-type header + + PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType]; + + std::string_view ContentTypeString = MapContentTypeToString(m_ContentType); + + ContentTypeHeader->pRawValue = ContentTypeString.data(); + ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size(); + + std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode); + + HttpResponse.StatusCode = m_ResponseCode; + HttpResponse.pReason = ReasonString.data(); + HttpResponse.ReasonLength = (USHORT)ReasonString.size(); + + // Cache policy + + HTTP_CACHE_POLICY CachePolicy; + + CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates; + CachePolicy.SecondsToLive = 0; + + // Initial response API call + + SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + &HttpResponse, + &CachePolicy, + NULL, + NULL, + 0, + Tx.Overlapped(), + NULL); + + m_IsInitialResponse = false; + } + else + { + // Subsequent response API calls + + SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + SendFlags, + (USHORT)ThisRequestChunkCount, // EntityChunkCount + &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks + NULL, // BytesSent + NULL, // Reserved1 + 0, // Reserved2 + Tx.Overlapped(), // Overlapped + NULL // LogData + ); + } + + if (SendResult == NO_ERROR) + { + // Synchronous completion, but the completion event will still be posted to IOCP + + ErrorCode.clear(); + } + else if (SendResult == ERROR_IO_PENDING) + { + // Asynchronous completion, a completion notification will be posted to IOCP + + ErrorCode.clear(); + } + else + { + // An error occurred, no completion will be posted to IOCP + + CancelThreadpoolIo(Iocp); + + ZEN_WARN("failed to send HTTP response (error: '{}'), request URL: '{}', request id: {}", + GetSystemErrorAsString(SendResult), + HttpReq->pRawUrl, + HttpReq->RequestId); + + ErrorCode = MakeErrorCode(SendResult); + } +} + +/** HTTP completion handler for async work + + This is used to allow work to be taken off the request handler threads + and to support posting responses asynchronously. + */ + +class HttpAsyncWorkRequest : public HttpSysRequestHandler +{ +public: + HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response); + ~HttpAsyncWorkRequest(); + + virtual void IssueRequest(std::error_code& ErrorCode) override final; + virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override; + +private: + struct AsyncWorkItem : public IWork + { + virtual void Execute() override; + + AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler) + : Tx(InTx) + , Handler(std::move(InHandler)) + { + } + + HttpSysTransaction& Tx; + std::function<void(HttpServerRequest&)> Handler; + }; + + Ref<AsyncWorkItem> m_WorkItem; +}; + +HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response) +: HttpSysRequestHandler(Tx) +{ + m_WorkItem = new AsyncWorkItem(Tx, std::move(Response)); +} + +HttpAsyncWorkRequest::~HttpAsyncWorkRequest() +{ +} + +void +HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode) +{ + ErrorCode.clear(); + + Transaction().Server().WorkPool().ScheduleWork(m_WorkItem); +} + +HttpSysRequestHandler* +HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // This ought to not be called since there should be no outstanding I/O request + // when this completion handler is active + + ZEN_UNUSED(IoResult, NumberOfBytesTransferred); + + ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred); + + return this; +} + +void +HttpAsyncWorkRequest::AsyncWorkItem::Execute() +{ + try + { + HttpSysServerRequest& ThisRequest = Tx.ServerRequest(); + + ThisRequest.m_NextCompletionHandler = nullptr; + + Handler(ThisRequest); + + // TODO: should Handler be destroyed at this point to ensure there + // are no outstanding references into state which could be + // deleted asynchronously as a result of issuing the response? + + if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler) + { + return (void)Tx.IssueNextRequest(NextHandler); + } + else if (!ThisRequest.IsHandled()) + { + return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv)); + } + else + { + // "Handled" but no request handler? Shouldn't ever happen + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv)); + } + } + catch (std::exception& Ex) + { + return (void)Tx.IssueNextRequest( + new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what()))); + } +} + +/** + _________ + / _____/ ______________ __ ___________ + \_____ \_/ __ \_ __ \ \/ // __ \_ __ \ + / \ ___/| | \/\ /\ ___/| | \/ + /_______ /\___ >__| \_/ \___ >__| + \/ \/ \/ +*/ + +HttpSysServer::HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount) +: m_Log(logging::Get("http")) +, m_RequestLog(logging::Get("http_requests")) +, m_ThreadPool(ThreadCount) +, m_AsyncWorkPool(AsyncWorkThreadCount) +{ + ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr); + + if (Result != NO_ERROR) + { + return; + } + + m_IsHttpInitialized = true; + m_IsOk = true; + + ZEN_INFO("http.sys server started, using {} I/O threads and {} async worker threads", ThreadCount, AsyncWorkThreadCount); +} + +HttpSysServer::~HttpSysServer() +{ + if (m_IsHttpInitialized) + { + Cleanup(); + + HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr); + } +} + +int +HttpSysServer::InitializeServer(int BasePort) +{ + using namespace std::literals; + + WideStringBuilder<64> WildcardUrlPath; + WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv; + + m_IsOk = false; + + ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + int EffectivePort = BasePort; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + // Sharing violation implies the port is being used by another process + for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + WildcardUrlPath.Reset(); + WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv; + + Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + } + + m_BaseUris.clear(); + if (Result == NO_ERROR) + { + m_BaseUris.push_back(WildcardUrlPath.c_str()); + } + else if (Result == ERROR_ACCESS_DENIED) + { + // If we can't register the wildcard path, we fall back to local paths + // This local paths allow requests originating locally to function, but will not allow + // remote origin requests to function. This can be remedied by using netsh + // during an install process to grant permissions to route public access to the appropriate + // port for the current user. eg: + // netsh http add urlacl url=http://*:1337/ user=<some_user> + + ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath)); + + const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv}; + + ULONG InternalResult = ERROR_SHARING_VIOLATION; + for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + + for (const std::u8string_view Host : Hosts) + { + WideStringBuilder<64> LocalUrlPath; + LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv; + + InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0); + + if (InternalResult == NO_ERROR) + { + ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath)); + + m_BaseUris.push_back(LocalUrlPath.c_str()); + } + else + { + break; + } + } + } + } + + if (m_BaseUris.empty()) + { + ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result); + + return BasePort; + } + + HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0}; + + Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2, + /* Name */ nullptr, + /* SecurityAttributes */ nullptr, + /* Flags */ 0, + &m_RequestQueueHandle); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + HttpBindingInfo.Flags.Present = 1; + HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle; + + Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo)); + + if (Result != NO_ERROR) + { + ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result); + + return EffectivePort; + } + + // Create I/O completion port + + std::error_code ErrorCode; + m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode); + + if (ErrorCode) + { + ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message()); + } + else + { + m_IsOk = true; + + ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front())); + } + + return EffectivePort; +} + +void +HttpSysServer::Cleanup() +{ + ++m_IsShuttingDown; + + if (m_RequestQueueHandle) + { + HttpCloseRequestQueue(m_RequestQueueHandle); + m_RequestQueueHandle = nullptr; + } + + if (m_HttpUrlGroupId) + { + HttpCloseUrlGroup(m_HttpUrlGroupId); + m_HttpUrlGroupId = 0; + } + + if (m_HttpSessionId) + { + HttpCloseServerSession(m_HttpSessionId); + m_HttpSessionId = 0; + } +} + +void +HttpSysServer::StartServer() +{ + const int InitialRequestCount = 32; + + for (int i = 0; i < InitialRequestCount; ++i) + { + IssueNewRequestMaybe(); + } +} + +void +HttpSysServer::Run(bool IsInteractive) +{ + if (IsInteractive) + { + zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit"); + } + + do + { + // int WaitTimeout = -1; + int WaitTimeout = 100; + + if (IsInteractive) + { + WaitTimeout = 1000; + + if (_kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + UpdateLofreqTimerValue(); + } while (!IsApplicationExitRequested()); +} + +void +HttpSysServer::OnHandlingRequest() +{ + if (--m_PendingRequests > m_MinPendingRequests) + { + // We have more than the minimum number of requests pending, just let someone else + // enqueue new requests + return; + } + + IssueNewRequestMaybe(); +} + +void +HttpSysServer::IssueNewRequestMaybe() +{ + if (m_IsShuttingDown.load(std::memory_order::acquire)) + { + return; + } + + if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests) + { + return; + } + + std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this); + + std::error_code ErrorCode; + Request->IssueInitialRequest(ErrorCode); + + if (ErrorCode) + { + // No request was actually issued. What is the appropriate response? + + return; + } + + // This may end up exceeding the MaxPendingRequests limit, but it's not + // really a problem. I'm doing it this way mostly to avoid dealing with + // exceptions here + ++m_PendingRequests; + + Request.release(); +} + +void +HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service) +{ + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + + return; + } + } +} + +void +HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service) +{ + ZEN_UNUSED(Service); + + if (UrlPath[0] == '/') + { + ++UrlPath; + } + + const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath); + + // Convert to wide string + + for (const std::wstring& BaseUri : m_BaseUris) + { + std::wstring Url16 = BaseUri + PathUtf16; + + ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0); + + if (Result != NO_ERROR) + { + ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result)); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler) +{ +} + +HttpSysTransaction::~HttpSysTransaction() +{ +} + +PTP_IO +HttpSysTransaction::Iocp() +{ + return m_HttpServer.m_ThreadPool.Iocp(); +} + +HANDLE +HttpSysTransaction::RequestQueueHandle() +{ + return m_HttpServer.m_RequestQueueHandle; +} + +void +HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode) +{ + m_InitialHttpHandler.IssueRequest(ErrorCode); +} + +void +HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance, + PVOID pContext /* HttpSysServer */, + PVOID pOverlapped, + ULONG IoResult, + ULONG_PTR NumberOfBytesTransferred, + PTP_IO Io) +{ + UNREFERENCED_PARAMETER(Io); + UNREFERENCED_PARAMETER(Instance); + UNREFERENCED_PARAMETER(pContext); + + // Note that for a given transaction we may be in this completion function on more + // than one thread at any given moment. This means we need to be careful about what + // happens in here + + HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped); + + if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone) + { + delete Transaction; + } +} + +bool +HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler) +{ + HttpSysRequestHandler* CurrentHandler = m_CompletionHandler; + m_CompletionHandler = NewCompletionHandler; + + auto _ = MakeGuard([this, CurrentHandler] { + if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler)) + { + delete CurrentHandler; + } + }); + + if (NewCompletionHandler == nullptr) + { + return false; + } + + try + { + std::error_code ErrorCode; + m_CompletionHandler->IssueRequest(ErrorCode); + + if (!ErrorCode) + { + return true; + } + + ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message()); + } + catch (std::exception& Ex) + { + ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what()); + } + + // something went wrong, no request is pending + m_CompletionHandler = nullptr; + + return false; +} + +HttpSysTransaction::Status +HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + // We use this to ensure sequential execution of completion handlers + // for any given transaction. It also ensures all member variables are + // in a consistent state for the current thread + + RwLock::ExclusiveLockScope _(m_CompletionMutex); + + bool IsRequestPending = false; + + if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler) + { + if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest()) + { + // Ensure we have a sufficient number of pending requests outstanding + m_HttpServer.OnHandlingRequest(); + } + + auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred); + + IsRequestPending = IssueNextRequest(NewCompletionHandler); + } + + // Ensure new requests are enqueued as necessary + m_HttpServer.IssueNewRequestMaybe(); + + if (IsRequestPending) + { + // There is another request pending on this transaction, so it needs to remain valid + return Status::kRequestPending; + } + + if (m_HttpServer.m_IsRequestLoggingEnabled) + { + if (m_HandlerRequest.has_value()) + { + m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri()); + } + } + + // Transaction done, caller should clean up (delete) this instance + return Status::kDone; +} + +HttpSysServerRequest& +HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) +{ + HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload); + + // Default request handling + + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + + return ThisRequest; +} + +////////////////////////////////////////////////////////////////////////// + +HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer) +: m_HttpTx(Tx) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest(); + + const int PrefixLength = Service.UriPrefixLength(); + const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + if (AbsPathLength >= PrefixLength) + { + // We convert the URI immediately because most of the code involved prefers to deal + // with utf8. This is overhead which I'd prefer to avoid but for now we just have + // to live with it + + WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)}, + m_UriUtf8); + + std::string_view UriSuffix8{m_UriUtf8}; + + m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access + m_Uri = UriSuffix8; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(UriSuffix8.size() + 1); + } + } + } + else + { + m_UriUtf8.Reset(); + m_Uri = {}; + m_UriWithExtension = {}; + } + + if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength) + { + --QueryStringLength; // We skip the leading question mark + + WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8); + } + else + { + m_QueryStringUtf8.Reset(); + } + + m_QueryString = std::string_view(m_QueryStringUtf8); + m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb); + m_ContentLength = GetContentLength(HttpRequestPtr); + m_ContentType = GetContentType(HttpRequestPtr); + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = GetAcceptType(HttpRequestPtr); + } + + if (m_Verb == HttpVerb::kHead) + { + SetSuppressResponseBody(); + } +} + +Oid +HttpSysServerRequest::ParseSessionId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Session"sv) + { + if (Header.RawValueLength == Oid::StringLength) + { + return Oid::FromHexString({Header.pRawValue, Header.RawValueLength}); + } + } + } + + return {}; +} + +uint32_t +HttpSysServerRequest::ParseRequestId() const +{ + const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest(); + + for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i) + { + HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i]; + std::string_view HeaderName{Header.pName, Header.NameLength}; + + if (HeaderName == "UE-Request"sv) + { + std::string_view RequestValue{Header.pRawValue, Header.RawValueLength}; + uint32_t RequestId = 0; + std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId); + return RequestId; + } + } + + return 0; +} + +IoBuffer +HttpSysServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(IsHandled() == false); + + auto Response = + new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size()); + + if (SuppressBody()) + { + Response->SuppressResponseBody(); + } + + m_NextCompletionHandler = Response; + + SetIsHandled(); +} + +void +HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + if (m_HttpTx.Server().IsAsyncResponseEnabled()) + { + m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler)); + } + else + { + ContinuationHandler(m_HttpTx.ServerRequest()); + } +} + +bool +HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + HTTP_REQUEST* Req = m_HttpTx.HttpRequest(); + const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange]; + + return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest) +{ +} + +InitialRequestHandler::~InitialRequestHandler() +{ +} + +void +InitialRequestHandler::IssueRequest(std::error_code& ErrorCode) +{ + HttpSysTransaction& Tx = Transaction(); + PTP_IO Iocp = Tx.Iocp(); + HTTP_REQUEST* HttpReq = Tx.HttpRequest(); + + StartThreadpoolIo(Iocp); + + ULONG HttpApiResult; + + if (IsInitialRequest()) + { + HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(), + HTTP_NULL_ID, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + HttpReq, + RequestBufferSize(), + NULL, + Tx.Overlapped()); + } + else + { + // The http.sys team recommends limiting the size to 128KB + static const uint64_t kMaxBytesPerApiCall = 128 * 1024; + + uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset; + const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall); + void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset; + + HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(), + HttpReq->RequestId, + 0, /* Flags */ + BufferWriteCursor, + gsl::narrow<ULONG>(BytesToReadThisCall), + nullptr, // BytesReturned + Tx.Overlapped()); + } + + if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR) + { + CancelThreadpoolIo(Iocp); + + ErrorCode = MakeErrorCode(HttpApiResult); + + ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message()); + + return; + } + + ErrorCode.clear(); +} + +HttpSysRequestHandler* +InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) +{ + auto _ = MakeGuard([&] { m_IsInitialRequest = false; }); + + switch (IoResult) + { + default: + case ERROR_OPERATION_ABORTED: + return nullptr; + + case ERROR_MORE_DATA: // Insufficient buffer space + case NO_ERROR: + break; + } + + // Route request + + try + { + HTTP_REQUEST* HttpReq = HttpRequest(); + +# if 0 + for (int i = 0; i < HttpReq->RequestInfoCount; ++i) + { + auto& ReqInfo = HttpReq->pRequestInfo[i]; + + switch (ReqInfo.InfoType) + { + case HttpRequestInfoTypeRequestTiming: + { + const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeAuth: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeChannelBind: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslProtocol: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBindingDraft: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeSslTokenBinding: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV0: + { + const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeRequestSizing: + { + const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo); + ZEN_INFO(""); + } + break; + case HttpRequestInfoTypeQuicStats: + ZEN_INFO(""); + break; + case HttpRequestInfoTypeTcpInfoV1: + { + const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo); + + ZEN_INFO(""); + } + break; + } + } +# endif + + if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext)) + { + if (m_IsInitialRequest) + { + m_ContentLength = GetContentLength(HttpReq); + const HttpContentType ContentType = GetContentType(HttpReq); + + if (m_ContentLength) + { + // Handle initial chunk read by copying any payload which has already been copied + // into our embedded request buffer + + m_PayloadBuffer = IoBuffer(m_ContentLength); + m_PayloadBuffer.SetContentType(ContentType); + + uint64_t BytesToRead = m_ContentLength; + uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()); + uint8_t* BufferWriteCursor = BufferBase; + + const int EntityChunkCount = HttpReq->EntityChunkCount; + + for (int i = 0; i < EntityChunkCount; ++i) + { + HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i]; + + ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory); + + const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength; + + ZEN_ASSERT(BufferLength <= BytesToRead); + + memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength); + + BufferWriteCursor += BufferLength; + BytesToRead -= BufferLength; + } + + m_CurrentPayloadOffset = BufferWriteCursor - BufferBase; + } + } + else + { + m_CurrentPayloadOffset += NumberOfBytesTransferred; + } + + if (m_CurrentPayloadOffset != m_ContentLength) + { + // Body not complete, issue another read request to receive more body data + return this; + } + + // Request body received completely + + m_PayloadBuffer.MakeImmutable(); + + HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer); + + if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler) + { + return Response; + } + + if (!ThisRequest.IsHandled()) + { + return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv); + } + } + + // Unable to route + return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv); + } + catch (std::exception& ex) + { + ZEN_ERROR("Caught exception while handling request: '{}'", ex.what()); + + return new HttpMessageResponseRequest(Transaction(), 500, ex.what()); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// HttpServer interface implementation +// + +int +HttpSysServer::Initialize(int BasePort) +{ + int EffectivePort = InitializeServer(BasePort); + StartServer(); + return EffectivePort; +} + +void +HttpSysServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} +void +HttpSysServer::RegisterService(HttpService& Service) +{ + RegisterService(Service.BaseUri(), Service); +} + +Ref<HttpServer> +CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads) +{ + return Ref<HttpServer>(new HttpSysServer(Concurrency, BackgroundWorkerThreads)); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/httpsys.h b/src/zenhttp/httpsys.h new file mode 100644 index 000000000..d6bd34890 --- /dev/null +++ b/src/zenhttp/httpsys.h @@ -0,0 +1,90 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +#ifndef ZEN_WITH_HTTPSYS +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_HTTPSYS 1 +# else +# define ZEN_WITH_HTTPSYS 0 +# endif +#endif + +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include <zencore/windows.h> +# include <zencore/workthreadpool.h> +# include "iothreadpool.h" + +# include <http.h> + +namespace spdlog { +class logger; +} + +namespace zen { + +/** + * @brief Windows implementation of HTTP server based on http.sys + * + * This requires elevation to function + */ +class HttpSysServer : public HttpServer +{ + friend class HttpSysTransaction; + +public: + explicit HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount); + ~HttpSysServer(); + + // HttpServer interface implementation + + virtual int Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + virtual void RegisterService(HttpService& Service) override; + + WorkerThreadPool& WorkPool() { return m_AsyncWorkPool; } + + inline bool IsOk() const { return m_IsOk; } + inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + +private: + int InitializeServer(int BasePort); + void Cleanup(); + + void StartServer(); + void OnHandlingRequest(); + void IssueNewRequestMaybe(); + + void RegisterService(const char* Endpoint, HttpService& Service); + void UnregisterService(const char* Endpoint, HttpService& Service); + +private: + spdlog::logger& m_Log; + spdlog::logger& m_RequestLog; + spdlog::logger& Log() { return m_Log; } + + bool m_IsOk = false; + bool m_IsHttpInitialized = false; + bool m_IsRequestLoggingEnabled = false; + bool m_IsAsyncResponseEnabled = true; + + WinIoThreadPool m_ThreadPool; + WorkerThreadPool m_AsyncWorkPool; + + std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/ + HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; + HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; + HANDLE m_RequestQueueHandle = 0; + std::atomic_int32_t m_PendingRequests{0}; + std::atomic_int32_t m_IsShuttingDown{0}; + int32_t m_MinPendingRequests = 16; + int32_t m_MaxPendingRequests = 128; + Event m_ShutdownEvent; +}; + +} // namespace zen +#endif diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h new file mode 100644 index 000000000..8316a9b9f --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpclient.h @@ -0,0 +1,47 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zencore/iobuffer.h> +#include <zencore/uid.h> +#include <zenhttp/httpcommon.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CbPackage; + +/** HTTP client implementation for Zen use cases + + Currently simple and synchronous, should become lean and asynchronous + */ +class HttpClient +{ +public: + HttpClient(std::string_view BaseUri); + ~HttpClient(); + + struct Response + { + int StatusCode = 0; + IoBuffer ResponsePayload; // Note: this also includes the content type + }; + + [[nodiscard]] Response Put(std::string_view Url, IoBuffer Payload); + [[nodiscard]] Response Get(std::string_view Url); + [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package); + [[nodiscard]] Response Delete(std::string_view Url); + +private: + std::string m_BaseUri; + std::string m_SessionId; +}; + +} // namespace zen + +void httpclient_forcelink(); // internal diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h new file mode 100644 index 000000000..19fda8db4 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpcommon.h @@ -0,0 +1,181 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> + +#include <string_view> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +using HttpContentType = ZenContentType; + +class IoBuffer; +class CbObject; +class CbPackage; +class StringBuilderBase; + +struct HttpRange +{ + uint32_t Start = ~uint32_t(0); + uint32_t End = ~uint32_t(0); +}; + +using HttpRanges = std::vector<HttpRange>; + +std::string_view MapContentTypeToString(HttpContentType ContentType); +extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString); +std::string_view ReasonStringForHttpResultCode(int HttpCode); +bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges); + +[[nodiscard]] inline bool +IsHttpSuccessCode(int HttpCode) +{ + return (HttpCode >= 200) && (HttpCode < 300); +} + +enum class HttpVerb : uint8_t +{ + kGet = 1 << 0, + kPut = 1 << 1, + kPost = 1 << 2, + kDelete = 1 << 3, + kHead = 1 << 4, + kCopy = 1 << 5, + kOptions = 1 << 6 +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(HttpVerb); + +const std::string_view ToString(HttpVerb Verb); + +enum class HttpResponseCode +{ + // 1xx - Informational + + Continue = 100, //!< Indicates that the initial part of a request has been received and has not yet been rejected by the server. + SwitchingProtocols = 101, //!< Indicates that the server understands and is willing to comply with the client's request, via the + //!< Upgrade header field, for a change in the application protocol being used on this connection. + Processing = 102, //!< Is an interim response used to inform the client that the server has accepted the complete request, but has not + //!< yet completed it. + EarlyHints = 103, //!< Indicates to the client that the server is likely to send a final response with the header fields included in + //!< the informational response. + + // 2xx - Successful + + OK = 200, //!< Indicates that the request has succeeded. + Created = 201, //!< Indicates that the request has been fulfilled and has resulted in one or more new resources being created. + Accepted = 202, //!< Indicates that the request has been accepted for processing, but the processing has not been completed. + NonAuthoritativeInformation = 203, //!< Indicates that the request was successful but the enclosed payload has been modified from that + //!< of the origin server's 200 (OK) response by a transforming proxy. + NoContent = 204, //!< Indicates that the server has successfully fulfilled the request and that there is no additional content to send + //!< in the response payload body. + ResetContent = 205, //!< Indicates that the server has fulfilled the request and desires that the user agent reset the \"document + //!< view\", which caused the request to be sent, to its original state as received from the origin server. + PartialContent = 206, //!< Indicates that the server is successfully fulfilling a range request for the target resource by transferring + //!< one or more parts of the selected representation that correspond to the satisfiable ranges found in the + //!< requests's Range header field. + MultiStatus = 207, //!< Provides status for multiple independent operations. + AlreadyReported = 208, //!< Used inside a DAV:propstat response element to avoid enumerating the internal members of multiple bindings + //!< to the same collection repeatedly. [RFC 5842] + IMUsed = 226, //!< The server has fulfilled a GET request for the resource, and the response is a representation of the result of one + //!< or more instance-manipulations applied to the current instance. + + // 3xx - Redirection + + MultipleChoices = 300, //!< Indicates that the target resource has more than one representation, each with its own more specific + //!< identifier, and information about the alternatives is being provided so that the user (or user agent) can + //!< select a preferred representation by redirecting its request to one or more of those identifiers. + MovedPermanently = 301, //!< Indicates that the target resource has been assigned a new permanent URI and any future references to this + //!< resource ought to use one of the enclosed URIs. + Found = 302, //!< Indicates that the target resource resides temporarily under a different URI. + SeeOther = 303, //!< Indicates that the server is redirecting the user agent to a different resource, as indicated by a URI in the + //!< Location header field, that is intended to provide an indirect response to the original request. + NotModified = 304, //!< Indicates that a conditional GET request has been received and would have resulted in a 200 (OK) response if it + //!< were not for the fact that the condition has evaluated to false. + UseProxy = 305, //!< \deprecated \parblock Due to security concerns regarding in-band configuration of a proxy. \endparblock + //!< The requested resource MUST be accessed through the proxy given by the Location field. + TemporaryRedirect = 307, //!< Indicates that the target resource resides temporarily under a different URI and the user agent MUST NOT + //!< change the request method if it performs an automatic redirection to that URI. + PermanentRedirect = 308, //!< The target resource has been assigned a new permanent URI and any future references to this resource + //!< ought to use one of the enclosed URIs. [...] This status code is similar to 301 Moved Permanently + //!< (Section 7.3.2 of rfc7231), except that it does not allow rewriting the request method from POST to GET. + + // 4xx - Client Error + BadRequest = 400, //!< Indicates that the server cannot or will not process the request because the received syntax is invalid, + //!< nonsensical, or exceeds some limitation on what the server is willing to process. + Unauthorized = 401, //!< Indicates that the request has not been applied because it lacks valid authentication credentials for the + //!< target resource. + PaymentRequired = 402, //!< *Reserved* + Forbidden = 403, //!< Indicates that the server understood the request but refuses to authorize it. + NotFound = 404, //!< Indicates that the origin server did not find a current representation for the target resource or is not willing + //!< to disclose that one exists. + MethodNotAllowed = 405, //!< Indicates that the method specified in the request-line is known by the origin server but not supported by + //!< the target resource. + NotAcceptable = 406, //!< Indicates that the target resource does not have a current representation that would be acceptable to the + //!< user agent, according to the proactive negotiation header fields received in the request, and the server is + //!< unwilling to supply a default representation. + ProxyAuthenticationRequired = + 407, //!< Is similar to 401 (Unauthorized), but indicates that the client needs to authenticate itself in order to use a proxy. + RequestTimeout = + 408, //!< Indicates that the server did not receive a complete request message within the time that it was prepared to wait. + Conflict = 409, //!< Indicates that the request could not be completed due to a conflict with the current state of the resource. + Gone = 410, //!< Indicates that access to the target resource is no longer available at the origin server and that this condition is + //!< likely to be permanent. + LengthRequired = 411, //!< Indicates that the server refuses to accept the request without a defined Content-Length. + PreconditionFailed = + 412, //!< Indicates that one or more preconditions given in the request header fields evaluated to false when tested on the server. + PayloadTooLarge = 413, //!< Indicates that the server is refusing to process a request because the request payload is larger than the + //!< server is willing or able to process. + URITooLong = 414, //!< Indicates that the server is refusing to service the request because the request-target is longer than the + //!< server is willing to interpret. + UnsupportedMediaType = 415, //!< Indicates that the origin server is refusing to service the request because the payload is in a format + //!< not supported by the target resource for this method. + RangeNotSatisfiable = 416, //!< Indicates that none of the ranges in the request's Range header field overlap the current extent of the + //!< selected resource or that the set of ranges requested has been rejected due to invalid ranges or an + //!< excessive request of small or overlapping ranges. + ExpectationFailed = 417, //!< Indicates that the expectation given in the request's Expect header field could not be met by at least + //!< one of the inbound servers. + ImATeapot = 418, //!< Any attempt to brew coffee with a teapot should result in the error code 418 I'm a teapot. + UnprocessableEntity = 422, //!< Means the server understands the content type of the request entity (hence a 415(Unsupported Media + //!< Type) status code is inappropriate), and the syntax of the request entity is correct (thus a 400 (Bad + //!< Request) status code is inappropriate) but was unable to process the contained instructions. + Locked = 423, //!< Means the source or destination resource of a method is locked. + FailedDependency = 424, //!< Means that the method could not be performed on the resource because the requested action depended on + //!< another action and that action failed. + UpgradeRequired = 426, //!< Indicates that the server refuses to perform the request using the current protocol but might be willing to + //!< do so after the client upgrades to a different protocol. + PreconditionRequired = 428, //!< Indicates that the origin server requires the request to be conditional. + TooManyRequests = 429, //!< Indicates that the user has sent too many requests in a given amount of time (\"rate limiting\"). + RequestHeaderFieldsTooLarge = + 431, //!< Indicates that the server is unwilling to process the request because its header fields are too large. + UnavailableForLegalReasons = + 451, //!< This status code indicates that the server is denying access to the resource in response to a legal demand. + + // 5xx - Server Error + + InternalServerError = + 500, //!< Indicates that the server encountered an unexpected condition that prevented it from fulfilling the request. + NotImplemented = 501, //!< Indicates that the server does not support the functionality required to fulfill the request. + BadGateway = 502, //!< Indicates that the server, while acting as a gateway or proxy, received an invalid response from an inbound + //!< server it accessed while attempting to fulfill the request. + ServiceUnavailable = 503, //!< Indicates that the server is currently unable to handle the request due to a temporary overload or + //!< scheduled maintenance, which will likely be alleviated after some delay. + GatewayTimeout = 504, //!< Indicates that the server, while acting as a gateway or proxy, did not receive a timely response from an + //!< upstream server it needed to access in order to complete the request. + HTTPVersionNotSupported = 505, //!< Indicates that the server does not support, or refuses to support, the protocol version that was + //!< used in the request message. + VariantAlsoNegotiates = + 506, //!< Indicates that the server has an internal configuration error: the chosen variant resource is configured to engage in + //!< transparent content negotiation itself, and is therefore not a proper end point in the negotiation process. + InsufficientStorage = 507, //!< Means the method could not be performed on the resource because the server is unable to store the + //!< representation needed to successfully complete the request. + LoopDetected = 508, //!< Indicates that the server terminated an operation because it encountered an infinite loop while processing a + //!< request with "Depth: infinity". [RFC 5842] + NotExtended = 510, //!< The policy for accessing the resource has not been met in the request. [RFC 2774] + NetworkAuthenticationRequired = 511, //!< Indicates that the client needs to authenticate to gain network access. +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h new file mode 100644 index 000000000..3b9fa50b4 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -0,0 +1,315 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenhttp.h" + +#include <zencore/compactbinary.h> +#include <zencore/enumflags.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/refcount.h> +#include <zencore/string.h> +#include <zencore/uid.h> +#include <zenhttp/httpcommon.h> + +#include <functional> +#include <gsl/gsl-lite.hpp> +#include <list> +#include <map> +#include <regex> +#include <span> +#include <unordered_map> + +namespace zen { + +/** HTTP Server Request + */ +class HttpServerRequest +{ +public: + HttpServerRequest(); + ~HttpServerRequest(); + + // Synchronous operations + + [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix + [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } + [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } + + struct QueryParams + { + std::vector<std::pair<std::string_view, std::string_view>> KvPairs; + + std::string_view GetValue(std::string_view ParamName) const + { + for (const auto& Kv : KvPairs) + { + const std::string_view& Key = Kv.first; + + if (Key.size() == ParamName.size()) + { + if (0 == StrCaseCompare(Key.data(), ParamName.data(), Key.size())) + { + return Kv.second; + } + } + } + + return std::string_view(); + } + }; + + virtual bool TryGetRanges(HttpRanges&) { return false; } + + QueryParams GetQueryParams(); + + inline HttpVerb RequestVerb() const { return m_Verb; } + inline HttpContentType RequestContentType() { return m_ContentType; } + inline HttpContentType AcceptContentType() { return m_AcceptType; } + + inline uint64_t ContentLength() const { return m_ContentLength; } + Oid SessionId() const; + uint32_t RequestId() const; + + inline bool IsHandled() const { return !!(m_Flags & kIsHandled); } + inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); } + inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; } + + /** Read POST/PUT payload for request body, which is always available without delay + */ + virtual IoBuffer ReadPayload() = 0; + + ZENCORE_API CbObject ReadPayloadObject(); + ZENCORE_API CbPackage ReadPayloadPackage(); + + /** Respond with payload + + No data will have been sent when any of these functions return. Instead, the response will be transmitted + asynchronously, after returning from a request handler function. + + Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be + moved into our response handler array where they are kept alive, in order to reduce ref-counting storms + */ + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload); + + void WriteResponse(HttpResponseCode ResponseCode, CbObject Data); + void WriteResponse(HttpResponseCode ResponseCode, CbArray Array); + void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package); + void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString); + void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob); + + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0; + +protected: + enum + { + kIsHandled = 1 << 0, + kSuppressBody = 1 << 1, + kHaveRequestId = 1 << 2, + kHaveSessionId = 1 << 3, + }; + + mutable uint32_t m_Flags = 0; + HttpVerb m_Verb = HttpVerb::kGet; + HttpContentType m_ContentType = HttpContentType::kBinary; + HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; + uint64_t m_ContentLength = ~0ull; + std::string_view m_Uri; + std::string_view m_UriWithExtension; + std::string_view m_QueryString; + mutable uint32_t m_RequestId = ~uint32_t(0); + mutable Oid m_SessionId = Oid::Zero; + + inline void SetIsHandled() { m_Flags |= kIsHandled; } + + virtual Oid ParseSessionId() const = 0; + virtual uint32_t ParseRequestId() const = 0; +}; + +class IHttpPackageHandler : public RefCounted +{ +public: + virtual void FilterOffer(std::vector<IoHash>& OfferCids) = 0; + virtual void OnRequestBegin() = 0; + virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) = 0; + virtual void OnRequestComplete() = 0; +}; + +/** + * Base class for implementing an HTTP "service" + * + * A service exposes one or more endpoints with a certain URI prefix + * + */ + +class HttpService +{ +public: + HttpService() = default; + virtual ~HttpService() = default; + + virtual const char* BaseUri() const = 0; + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest); + + // Internals + + inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; } + inline int UriPrefixLength() const { return m_UriPrefixLength; } + +private: + int m_UriPrefixLength = 0; +}; + +/** HTTP server + * + * Implements the main event loop to service HTTP requests, and handles routing + * requests to the appropriate handler as registered via RegisterService + */ +class HttpServer : public RefCounted +{ +public: + virtual void RegisterService(HttpService& Service) = 0; + virtual int Initialize(int BasePort) = 0; + virtual void Run(bool IsInteractiveSession) = 0; + virtual void RequestExit() = 0; +}; + +Ref<HttpServer> CreateHttpServer(std::string_view ServerClass); + +////////////////////////////////////////////////////////////////////////// + +class HttpRouterRequest +{ +public: + HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + + ZENCORE_API std::string GetCapture(uint32_t Index) const; + inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } + +private: + using MatchResults_t = std::match_results<std::string_view::const_iterator>; + + HttpServerRequest& m_HttpRequest; + MatchResults_t m_Match; + + friend class HttpRequestRouter; +}; + +inline std::string +HttpRouterRequest::GetCapture(uint32_t Index) const +{ + ZEN_ASSERT(Index < m_Match.size()); + + return m_Match[Index]; +} + +/** HTTP request router helper + * + * This helper class allows a service implementer to register one or more + * endpoints using pattern matching (currently using regex matching) + * + * This is intended to be initialized once only, there is no thread + * safety so you can absolutely not add or remove endpoints once the handler + * goes live + */ + +class HttpRequestRouter +{ +public: + typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; + + /** + * @brief Add pattern which can be referenced by name, commonly used for URL components + * @param Id String used to identify patterns for replacement + * @param Regex String which will replace the Id string in any registered URL paths + */ + void AddPattern(const char* Id, const char* Regex); + + /** + * @brief Register a an endpoint handler for the given route + * @param Regex Regular expression used to match the handler to a request. This may + * contain pattern aliases registered via AddPattern + * @param HandlerFunc Handler function to call for any matching request + * @param SupportedVerbs Supported HTTP verbs for this handler + */ + void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); + + void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); + + /** + * @brief HTTP request handling function - this should be called to route the + * request to a registered handler + * @param Request Request to route to a handler + * @return Function returns true if the request was routed successfully + */ + bool HandleRequest(zen::HttpServerRequest& Request); + +private: + struct HandlerEntry + { + HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) + : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) + , Verbs(SupportedVerbs) + , Handler(std::move(Handler)) + , Pattern(Pattern) + { + } + + ~HandlerEntry() = default; + + std::regex RegEx; + HttpVerb Verbs; + HandlerFunc_t Handler; + const char* Pattern; + + private: + HandlerEntry& operator=(const HandlerEntry&) = delete; + HandlerEntry(const HandlerEntry&) = delete; + }; + + std::list<HandlerEntry> m_Handlers; + std::unordered_map<std::string, std::string> m_PatternMap; +}; + +/** HTTP RPC request helper + */ + +class RpcResult +{ + RpcResult(CbObject Result) : m_Result(std::move(Result)) {} + +private: + CbObject m_Result; +}; + +class HttpRpcHandler +{ +public: + HttpRpcHandler(); + ~HttpRpcHandler(); + + HttpRpcHandler(const HttpRpcHandler&) = delete; + HttpRpcHandler operator=(const HttpRpcHandler&) = delete; + + void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction); + +private: + struct RpcFunction + { + std::function<void(CbObject& RpcArgs)> Function; + std::string Identifier; + }; + + std::map<std::string, RpcFunction> m_Functions; +}; + +bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef); + +void http_forcelink(); // internal + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/httpshared.h b/src/zenhttp/include/zenhttp/httpshared.h new file mode 100644 index 000000000..d335572c5 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpshared.h @@ -0,0 +1,163 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinarypackage.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> + +#include <functional> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +class IoBuffer; +class CbPackage; +class CompositeBuffer; + +/** _____ _ _____ _ + / ____| | | __ \ | | + | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___ + | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \ + | |____| |_) | | | (_| | (__| < (_| | (_| | __/ + \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___| + __/ | + |___/ + + Structures and code related to handling CbPackage transactions + + CbPackage instances are marshaled across the wire using a distinct message + format. We don't use the CbPackage serialization format provided by the + CbPackage implementation itself since that does not provide much flexibility + in how the attachment payloads are transmitted. The scheme below separates + metadata cleanly from payloads and this enables us to more efficiently + transmit them either via sendfile/TransmitFile like mechanisms, or by + reference/memory mapping in the local case. + */ + +struct CbPackageHeader +{ + uint32_t HeaderMagic; + uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document? + uint32_t Reserved1; + uint32_t Reserved2; +}; + +static_assert(sizeof(CbPackageHeader) == 16); + +enum : uint32_t +{ + kCbPkgMagic = 0xaa77aacc +}; + +struct CbAttachmentEntry +{ + uint64_t PayloadSize; // Size of the associated payload data in the message + uint32_t Flags; // See flags below + IoHash AttachmentHash; // Content Id for the attachment + + enum + { + kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format + kIsObject = (1u << 1), // Is compact binary object + kIsError = (1u << 2), // Is error (compact binary formatted) object + kIsLocalRef = (1u << 3), // Is "local reference" + }; +}; + +struct CbAttachmentReferenceHeader +{ + uint64_t PayloadByteOffset = 0; + uint64_t PayloadByteSize = ~0u; + uint16_t AbsolutePathLength = 0; + + // This header will be followed by UTF8 encoded absolute path to backing file +}; + +static_assert(sizeof(CbAttachmentEntry) == 32); + +enum class FormatFlags +{ + kDefault = 0, + kAllowLocalReferences = (1u << 0), + kDenyPartialLocalReferences = (1u << 1) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags); + +enum class RpcAcceptOptions : uint16_t +{ + kNone = 0, + kAllowLocalReferences = (1u << 0), + kAllowPartialLocalReferences = (1u << 1) +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions); + +std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0); +CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0); +CbPackage ParsePackageMessage( + IoBuffer Payload, + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer { + return IoBuffer{Size}; + }); +bool IsPackageMessage(IoBuffer Payload); + +bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage); + +std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, int TargetProcessPid = 0); +CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid = 0); + +/** Streaming reader for compact binary packages + + The goal is to ultimately support zero-copy I/O, but for now there'll be some + copying involved on some platforms at least. + + This approach to deserializing CbPackage data is more efficient than + `ParsePackageMessage` since it does not require the entire message to + be resident in a memory buffer + + */ +class CbPackageReader +{ +public: + CbPackageReader(); + ~CbPackageReader(); + + void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer); + + /** Process compact binary package data stream + + The data stream must be in the serialization format produced by FormatPackageMessage + + \return How many bytes must be fed to this function in the next call + */ + uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes); + + void Finalize(); + const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; } + CbObject GetRootObject() { return m_RootObject; } + std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; } + +private: + enum class State + { + kInitialState, + kReadingHeader, + kReadingAttachmentEntries, + kReadingBuffers + } m_CurrentState = State::kInitialState; + + std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer; + std::vector<IoBuffer> m_PayloadBuffers; + std::vector<CbAttachmentEntry> m_AttachmentEntries; + std::vector<CbAttachment> m_Attachments; + CbObject m_RootObject; + CbPackageHeader m_PackageHeader; + + IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer); +}; + +void forcelink_httpshared(); + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h new file mode 100644 index 000000000..adca7e988 --- /dev/null +++ b/src/zenhttp/include/zenhttp/websocket.h @@ -0,0 +1,256 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/compactbinarypackage.h> +#include <zencore/memory.h> + +#include <compare> +#include <functional> +#include <future> +#include <memory> +#include <optional> + +#pragma once + +namespace asio { +class io_context; +} + +namespace zen { + +class BinaryWriter; + +/** + * A unique socket ID. + */ +class WebSocketId +{ + static std::atomic_uint32_t NextId; + +public: + WebSocketId() = default; + + uint32_t Value() const { return m_Value; } + + auto operator<=>(const WebSocketId&) const = default; + + static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); } + +private: + WebSocketId(uint32_t Value) : m_Value(Value) {} + + uint32_t m_Value{}; +}; + +/** + * Type of web socket message. + */ +enum class WebSocketMessageType : uint8_t +{ + kInvalid, + kNotification, + kRequest, + kStreamRequest, + kResponse, + kStreamResponse, + kStreamCompleteResponse, + kCount +}; + +inline std::string_view +ToString(WebSocketMessageType Type) +{ + switch (Type) + { + case WebSocketMessageType::kInvalid: + return std::string_view("Invalid"); + case WebSocketMessageType::kNotification: + return std::string_view("Notification"); + case WebSocketMessageType::kRequest: + return std::string_view("Request"); + case WebSocketMessageType::kStreamRequest: + return std::string_view("StreamRequest"); + case WebSocketMessageType::kResponse: + return std::string_view("Response"); + case WebSocketMessageType::kStreamResponse: + return std::string_view("StreamResponse"); + case WebSocketMessageType::kStreamCompleteResponse: + return std::string_view("StreamCompleteResponse"); + default: + return std::string_view("Unknown"); + }; +} + +/** + * Web socket message. + */ +class WebSocketMessage +{ + struct Header + { + static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh + + uint64_t MessageSize{}; + uint32_t Magic{ExpectedMagic}; + uint32_t CorrelationId{}; + uint32_t StatusCode{200u}; + WebSocketMessageType MessageType{}; + uint8_t Reserved[3] = {0}; + + bool IsValid() const; + }; + + static_assert(sizeof(Header) == 24); + + static std::atomic_uint32_t NextCorrelationId; + +public: + static constexpr size_t HeaderSize = sizeof(Header); + + WebSocketMessage() = default; + + WebSocketId SocketId() const { return m_SocketId; } + void SetSocketId(WebSocketId Id) { m_SocketId = Id; } + uint64_t MessageSize() const { return m_Header.MessageSize; } + void SetMessageType(WebSocketMessageType MessageType); + void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; } + uint32_t CorrelationId() const { return m_Header.CorrelationId; } + uint32_t StatusCode() const { return m_Header.StatusCode; } + void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; } + WebSocketMessageType MessageType() const { return m_Header.MessageType; } + + const CbPackage& Body() const { return m_Body.value(); } + void SetBody(CbPackage&& Body); + void SetBody(CbObject&& Body); + bool HasBody() const { return m_Body.has_value(); } + + void Save(BinaryWriter& Writer); + bool TryLoadHeader(MemoryView Memory); + + bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; } + +private: + Header m_Header{}; + WebSocketId m_SocketId{}; + std::optional<CbPackage> m_Body; +}; + +class WebSocketServer; + +/** + * Base class for handling web socket requests and notifications from connected client(s). + */ +class WebSocketService +{ +public: + virtual ~WebSocketService() = default; + + void Configure(WebSocketServer& Server); + + virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); } + virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); } + +protected: + WebSocketService() = default; + + virtual void RegisterHandlers(WebSocketServer& Server) = 0; + void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete); + void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete); + + WebSocketServer& SocketServer() + { + ZEN_ASSERT(m_SocketServer); + return *m_SocketServer; + } + +private: + WebSocketServer* m_SocketServer{}; +}; + +/** + * Server options. + */ +struct WebSocketServerOptions +{ + uint16_t Port = 2337; + uint32_t ThreadCount = 1; +}; + +/** + * The web socket server manages client connections and routing of requests and notifications. + */ +class WebSocketServer +{ +public: + virtual ~WebSocketServer() = default; + + virtual bool Run() = 0; + virtual void Shutdown() = 0; + + virtual void RegisterService(WebSocketService& Service) = 0; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0; + + virtual void SendNotification(WebSocketMessage&& Notification) = 0; + virtual void SendResponse(WebSocketMessage&& Response) = 0; + + static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options); +}; + +/** + * The state of the web socket. + */ +enum class WebSocketState : uint32_t +{ + kNone, + kHandshaking, + kConnected, + kDisconnected, + kError +}; + +/** + * Type of web socket client event. + */ +enum class WebSocketEvent : uint32_t +{ + kConnected, + kDisconnected, + kError +}; + +/** + * Web socket client connection info. + */ +struct WebSocketConnectInfo +{ + std::string Host; + int16_t Port{8848}; + std::string Endpoint; + std::vector<std::string> Protocols; + uint16_t Version{13}; +}; + +/** + * A connection to a web socket server for sending requests and listening for notifications. + */ +class WebSocketClient +{ +public: + using EventCallback = std::function<void()>; + using NotificationCallback = std::function<void(WebSocketMessage&&)>; + + virtual ~WebSocketClient() = default; + + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0; + virtual void Disconnect() = 0; + virtual bool IsConnected() const = 0; + virtual WebSocketState State() const = 0; + + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0; + virtual void OnNotification(NotificationCallback&& Cb) = 0; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0; + + static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx); +}; + +} // namespace zen diff --git a/src/zenhttp/include/zenhttp/zenhttp.h b/src/zenhttp/include/zenhttp/zenhttp.h new file mode 100644 index 000000000..59c64b31f --- /dev/null +++ b/src/zenhttp/include/zenhttp/zenhttp.h @@ -0,0 +1,21 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#ifndef ZEN_WITH_HTTPSYS +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_HTTPSYS 1 +# else +# define ZEN_WITH_HTTPSYS 0 +# endif +#endif + +#define ZENHTTP_API // Placeholder to allow DLL configs in the future + +namespace zen { + +ZENHTTP_API void zenhttp_forcelinktests(); + +} diff --git a/src/zenhttp/iothreadpool.cpp b/src/zenhttp/iothreadpool.cpp new file mode 100644 index 000000000..6087e69ec --- /dev/null +++ b/src/zenhttp/iothreadpool.cpp @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "iothreadpool.h" + +#include <zencore/except.h> + +#if ZEN_PLATFORM_WINDOWS + +namespace zen { + +WinIoThreadPool::WinIoThreadPool(int InThreadCount) +{ + // Thread pool setup + + m_ThreadPool = CreateThreadpool(NULL); + + SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount); + SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2); + + InitializeThreadpoolEnvironment(&m_CallbackEnvironment); + + m_CleanupGroup = CreateThreadpoolCleanupGroup(); + + SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool); + + SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL); +} + +WinIoThreadPool::~WinIoThreadPool() +{ + CloseThreadpool(m_ThreadPool); +} + +void +WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode) +{ + ZEN_ASSERT(!m_ThreadPoolIo); + + m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment); + + if (!m_ThreadPoolIo) + { + ErrorCode = MakeErrorCodeFromLastError(); + } +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/iothreadpool.h b/src/zenhttp/iothreadpool.h new file mode 100644 index 000000000..8333964c3 --- /dev/null +++ b/src/zenhttp/iothreadpool.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> + +# include <system_error> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +// +// Thread pool. Implemented in terms of Windows thread pool right now, will +// need a cross-platform implementation eventually +// + +class WinIoThreadPool +{ +public: + WinIoThreadPool(int InThreadCount); + ~WinIoThreadPool(); + + void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode); + inline PTP_IO Iocp() const { return m_ThreadPoolIo; } + +private: + PTP_POOL m_ThreadPool = nullptr; + PTP_CLEANUP_GROUP m_CleanupGroup = nullptr; + PTP_IO m_ThreadPoolIo = nullptr; + TP_CALLBACK_ENVIRON m_CallbackEnvironment; +}; + +} // namespace zen +#endif diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp new file mode 100644 index 000000000..bbe7e1ad8 --- /dev/null +++ b/src/zenhttp/websocketasio.cpp @@ -0,0 +1,1613 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/websocket.h> + +#include <zencore/base64.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/intmath.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/sha1.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/trace.h> + +#include <chrono> +#include <optional> +#include <shared_mutex> +#include <span> +#include <system_error> +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <http_parser.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# include <mstcpip.h> +#endif + +namespace zen::websocket { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv); + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv); + +using Clock = std::chrono::steady_clock; +using TimePoint = Clock::time_point; + +/////////////////////////////////////////////////////////////////////////////// +namespace http_header { + static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv; + static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv; + static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv; + static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv; + static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv; + static constexpr std::string_view Upgrade = "Upgrade"sv; +} // namespace http_header + +/////////////////////////////////////////////////////////////////////////////// +enum class ParseMessageStatus : uint32_t +{ + kError, + kContinue, + kDone, +}; + +struct ParseMessageResult +{ + ParseMessageStatus Status{}; + size_t ByteCount{}; + std::optional<std::string> Reason; +}; + +class MessageParser +{ +public: + virtual ~MessageParser() = default; + + ParseMessageResult ParseMessage(MemoryView Msg); + void Reset(); + +protected: + MessageParser() = default; + + virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0; + virtual void OnReset() = 0; + + BinaryWriter m_Stream; +}; + +ParseMessageResult +MessageParser::ParseMessage(MemoryView Msg) +{ + return OnParseMessage(Msg); +} + +void +MessageParser::Reset() +{ + OnReset(); + + m_Stream.Reset(); +} + +/////////////////////////////////////////////////////////////////////////////// +enum class HttpMessageParserType +{ + kRequest, + kResponse, + kBoth +}; + +class HttpMessageParser final : public MessageParser +{ +public: + using HttpHeaders = std::unordered_map<std::string_view, std::string_view>; + + HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); } + + virtual ~HttpMessageParser() = default; + + int32_t StatusCode() const { return m_Parser.status_code; } + bool IsUpgrade() const { return m_Parser.upgrade != 0; } + HttpHeaders& Headers() { return m_Headers; } + MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); } + + std::string_view StatusText() const + { + return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size); + } + + bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason); + +private: + void Initialize(); + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + int OnMessageBegin(); + int OnUrl(MemoryView Url); + int OnStatus(MemoryView Status); + int OnHeaderField(MemoryView HeaderField); + int OnHeaderValue(MemoryView HeaderValue); + int OnHeadersComplete(); + int OnBody(MemoryView Body); + int OnMessageComplete(); + + struct StreamEntry + { + uint64_t Offset{}; + uint64_t Size{}; + }; + + struct HeaderStreamEntry + { + StreamEntry Field{}; + StreamEntry Value{}; + }; + + HttpMessageParserType m_Type; + http_parser m_Parser; + StreamEntry m_UrlEntry; + StreamEntry m_StatusEntry; + StreamEntry m_BodyEntry; + HeaderStreamEntry m_CurrentHeader; + std::vector<HeaderStreamEntry> m_HeaderEntries; + HttpHeaders m_Headers; + bool m_IsMsgComplete{false}; + + static http_parser_settings ParserSettings; +}; + +http_parser_settings HttpMessageParser::ParserSettings = { + .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); }, + + .on_url = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); }, + + .on_status = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); }, + + .on_header_field = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); }, + + .on_header_value = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); }, + + .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); }, + + .on_body = [](http_parser* P, + const char* Data, + size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); }, + + .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }}; + +void +HttpMessageParser::Initialize() +{ + http_parser_init(&m_Parser, + m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST + : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE + : HTTP_BOTH); + m_Parser.data = this; + + m_UrlEntry = {}; + m_StatusEntry = {}; + m_CurrentHeader = {}; + m_BodyEntry = {}; + + m_IsMsgComplete = false; + + m_HeaderEntries.clear(); +} + +ParseMessageResult +HttpMessageParser::OnParseMessage(MemoryView Msg) +{ + const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize()); + + auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue; + + if (m_Parser.http_errno != 0) + { + Status = ParseMessageStatus::kError; + } + + return {.Status = Status, .ByteCount = uint64_t(ByteCount)}; +} + +void +HttpMessageParser::OnReset() +{ + Initialize(); +} + +int +HttpMessageParser::OnMessageBegin() +{ + ZEN_ASSERT(m_IsMsgComplete == false); + ZEN_ASSERT(m_HeaderEntries.empty()); + ZEN_ASSERT(m_Headers.empty()); + + return 0; +} + +int +HttpMessageParser::OnStatus(MemoryView Status) +{ + m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()}; + + m_Stream.Write(Status); + + return 0; +} + +int +HttpMessageParser::OnUrl(MemoryView Url) +{ + m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()}; + + m_Stream.Write(Url); + + return 0; +} + +int +HttpMessageParser::OnHeaderField(MemoryView HeaderField) +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + if (m_CurrentHeader.Field.Size == 0) + { + m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Field.Size += HeaderField.GetSize(); + + m_Stream.Write(HeaderField); + + return 0; +} + +int +HttpMessageParser::OnHeaderValue(MemoryView HeaderValue) +{ + if (m_CurrentHeader.Value.Size == 0) + { + m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset(); + } + + m_CurrentHeader.Value.Size += HeaderValue.GetSize(); + + m_Stream.Write(HeaderValue); + + return 0; +} + +int +HttpMessageParser::OnHeadersComplete() +{ + if (m_CurrentHeader.Value.Size > 0) + { + m_HeaderEntries.push_back(m_CurrentHeader); + m_CurrentHeader = {}; + } + + m_Headers.clear(); + m_Headers.reserve(m_HeaderEntries.size()); + + const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data()); + + for (const auto& Entry : m_HeaderEntries) + { + auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size); + auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size); + + m_Headers.try_emplace(std::move(Field), std::move(Value)); + } + + return 0; +} + +int +HttpMessageParser::OnBody(MemoryView Body) +{ + m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()}; + + m_Stream.Write(Body); + + return 0; +} + +int +HttpMessageParser::OnMessageComplete() +{ + m_IsMsgComplete = true; + + return 0; +} + +bool +HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason) +{ + static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv; + + OutAcceptHash = std::string(); + + if (m_Headers.contains(http_header::SecWebSocketKey) == false) + { + OutReason = "Missing header Sec-WebSocket-Key"; + return false; + } + + if (m_Headers.contains(http_header::Upgrade) == false) + { + OutReason = "Missing header Upgrade"; + return false; + } + + ExtendableStringBuilder<128> Sb; + Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid; + + SHA1Stream HashStream; + HashStream.Append(Sb.Data(), Sb.Size()); + + SHA1 Hash = HashStream.GetHash(); + + OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash))); + Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data()); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////// +class WebSocketMessageParser final : public MessageParser +{ +public: + WebSocketMessageParser() : MessageParser() {} + + WebSocketMessage ConsumeMessage(); + +private: + virtual ParseMessageResult OnParseMessage(MemoryView Msg) override; + virtual void OnReset() override; + + WebSocketMessage m_Message; +}; + +ParseMessageResult +WebSocketMessageParser::OnParseMessage(MemoryView Msg) +{ + ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage"); + + const uint64_t PrevOffset = m_Stream.CurrentOffset(); + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset(); + + m_Stream.Write(Msg.Left(RemaingHeaderSize)); + Msg += RemaingHeaderSize; + + if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize) + { + return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + + const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView()); + + if (IsValidHeader == false) + { + OnReset(); + + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message header")}; + } + + if (m_Message.MessageSize() == 0) + { + return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; + } + } + + ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize); + + if (Msg.IsEmpty() == false) + { + const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset(); + m_Stream.Write(Msg.Left(RemaingMessageSize)); + } + + auto Status = ParseMessageStatus::kContinue; + + if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize()) + { + Status = ParseMessageStatus::kDone; + + BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize)); + + CbPackage Pkg; + if (Pkg.TryLoad(Reader) == false) + { + return {.Status = ParseMessageStatus::kError, + .ByteCount = m_Stream.CurrentOffset() - PrevOffset, + .Reason = std::string("Invalid websocket message")}; + } + + m_Message.SetBody(std::move(Pkg)); + } + + return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset}; +} + +void +WebSocketMessageParser::OnReset() +{ + m_Message = WebSocketMessage(); +} + +WebSocketMessage +WebSocketMessageParser::ConsumeMessage() +{ + WebSocketMessage Msg = std::move(m_Message); + m_Message = WebSocketMessage(); + + return Msg; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsConnection : public std::enable_shared_from_this<WsConnection> +{ +public: + WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket) + : m_Id(Id) + , m_Socket(std::move(Socket)) + , m_StartTime(Clock::now()) + , m_State() + { + } + + ~WsConnection() = default; + + std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); } + + WebSocketId Id() const { return m_Id; } + asio::ip::tcp::socket& Socket() { return *m_Socket; } + TimePoint StartTime() const { return m_StartTime; } + WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); } + std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + WebSocketState Close(); + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + std::mutex& WriteMutex() { return m_WriteMutex; } + +private: + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + TimePoint m_StartTime; + std::atomic_uint32_t m_State; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + std::mutex m_WriteMutex; +}; + +WebSocketState +WsConnection::Close() +{ + const auto PrevState = SetState(WebSocketState::kDisconnected); + + if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open()) + { + m_Socket->close(); + } + + return PrevState; +} + +/////////////////////////////////////////////////////////////////////////////// +class WsThreadPool +{ +public: + WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {} + void Start(uint32_t ThreadCount); + void Stop(); + +private: + asio::io_service& m_IoSvc; + std::vector<std::thread> m_Threads; + std::atomic_bool m_Running{false}; +}; + +void +WsThreadPool::Start(uint32_t ThreadCount) +{ + ZEN_ASSERT(m_Threads.empty()); + + ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount); + + m_Running = true; + + for (uint32_t Idx = 0; Idx < ThreadCount; Idx++) + { + m_Threads.emplace_back([this, ThreadId = Idx + 1] { + for (;;) + { + if (m_Running == false) + { + break; + } + + try + { + m_IoSvc.run(); + } + catch (std::exception& Err) + { + ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what()); + } + } + + ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId); + }); + } +} + +void +WsThreadPool::Stop() +{ + if (m_Running) + { + m_Running = false; + + for (std::thread& Thread : m_Threads) + { + if (Thread.joinable()) + { + Thread.join(); + } + } + + m_Threads.clear(); + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsServer final : public WebSocketServer +{ +public: + WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {} + virtual ~WsServer() { Shutdown(); } + + virtual bool Run() override; + virtual void Shutdown() override; + + virtual void RegisterService(WebSocketService& Service) override; + virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override; + virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override; + + virtual void SendNotification(WebSocketMessage&& Notification) override; + virtual void SendResponse(WebSocketMessage&& Response) override; + +private: + friend class WsConnection; + + void AcceptConnection(); + void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec); + + void ReadMessage(std::shared_ptr<WsConnection> Connection); + void RouteMessage(WebSocketMessage&& Msg); + void SendMessage(WebSocketMessage&& Msg); + + struct IdHasher + { + size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); } + }; + + using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>; + using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>; + using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>; + + WebSocketServerOptions m_Options; + asio::io_service m_IoSvc; + std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor; + std::unique_ptr<WsThreadPool> m_ThreadPool; + ConnectionMap m_Connections; + std::shared_mutex m_ConnMutex; + std::vector<WebSocketService*> m_Services; + RequestHandlerMap m_RequestHandlers; + NotificationHandlerMap m_NotificationHandlers; + std::atomic_bool m_Running{}; +}; + +void +WsServer::RegisterService(WebSocketService& Service) +{ + m_Services.push_back(&Service); + + Service.Configure(*this); +} + +bool +WsServer::Run() +{ + static constexpr size_t ReceiveBufferSize = 256 << 10; + static constexpr size_t SendBufferSize = 256 << 10; + + m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6()); + + m_Acceptor->set_option(asio::ip::v6_only(false)); + m_Acceptor->set_option(asio::socket_base::reuse_address(true)); + m_Acceptor->set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize)); + m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize)); + +#if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor->native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +#endif + + asio::error_code Ec; + m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec); + + if (Ec) + { + ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value()); + + return false; + } + + m_Acceptor->listen(); + m_Running = true; + + ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port); + + AcceptConnection(); + + m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc); + m_ThreadPool->Start(m_Options.ThreadCount); + + return true; +} + +void +WsServer::Shutdown() +{ + if (m_Running) + { + ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down"); + + m_Running = false; + + m_Acceptor->close(); + m_Acceptor.reset(); + m_IoSvc.stop(); + + m_ThreadPool->Stop(); + } +} + +void +WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) +{ + auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>()); + Result.first->second.push_back(&Service); +} + +void +WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service) +{ + m_RequestHandlers[Key] = &Service; +} + +void +WsServer::SendNotification(WebSocketMessage&& Notification) +{ + ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification); + + SendMessage(std::move(Notification)); +} +void +WsServer::SendResponse(WebSocketMessage&& Response) +{ + ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse || + Response.MessageType() == WebSocketMessageType::kStreamResponse || + Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse); + + ZEN_ASSERT(Response.CorrelationId() != 0); + + SendMessage(std::move(Response)); +} + +void +WsServer::AcceptConnection() +{ + auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc); + asio::ip::tcp::socket& SocketRef = *Socket.get(); + + m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable { + if (m_Running) + { + if (Ec) + { + ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message()); + } + else + { + auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket)); + + ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr()); + + { + std::unique_lock _(m_ConnMutex); + m_Connections[Connection->Id()] = Connection; + } + + Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest)); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + } + + AcceptConnection(); + } + }); +} + +void +WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec) +{ + if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected) + { + if (Ec) + { + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value()); + } + else + { + ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value()); + } + } + + const WebSocketId Id = Connection->Id(); + + { + std::unique_lock _(m_ConnMutex); + if (m_Connections.contains(Id)) + { + m_Connections.erase(Id); + } + } +} + +void +WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection) +{ + Connection->ReadBuffer().prepare(64 << 10); + + asio::async_read( + Connection->Socket(), + Connection->ReadBuffer(), + asio::transfer_at_least(1), + [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable { + if (ReadEc) + { + return CloseConnection(Connection, ReadEc); + } + + switch (Connection->State()) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser()); + asio::const_buffer Buffer = Connection->ReadBuffer().data(); + + ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size())); + + Connection->ReadBuffer().consume(Result.ByteCount); + + if (Result.Status == ParseMessageStatus::kContinue) + { + return ReadMessage(Connection); + } + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Parser.IsUpgrade() == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + std::string AcceptHash; + std::string Reason; + const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason); + + if (ValidHandshake == false) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + Reason); + + constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv; + + return async_write(Connection->Socket(), + asio::buffer(UpgradeRequiredResponse), + [this, &Connection](const asio::error_code& WriteEc, std::size_t) { + if (WriteEc) + { + return CloseConnection(Connection, WriteEc); + } + + Connection->Parser()->Reset(); + Connection->SetState(WebSocketState::kHandshaking); + + ReadMessage(Connection); + }); + } + + ExtendableStringBuilder<128> Sb; + + Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: Upgrade\r\n"sv; + + // TODO: Verify protocol + if (Parser.Headers().contains(http_header::SecWebSocketProtocol)) + { + Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol] + << "\r\n"; + } + + Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n"; + Sb << "\r\n"sv; + + ZEN_LOG_DEBUG(LogWebSocket, + "accepting handshake from connection '#{} {}'", + Connection->Id().Value(), + Connection->RemoteAddr()); + + std::string Response = Sb.ToString(); + Buffer = asio::buffer(Response); + + async_write(Connection->Socket(), + Buffer, + [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) { + if (WriteEc) + { + ZEN_LOG_DEBUG(LogWebSocket, + "handshake with connection '{}' FAILED, reason '{}'", + Connection->Id().Value(), + WriteEc.message()); + + return CloseConnection(Connection, WriteEc); + } + + ZEN_LOG_DEBUG(LogWebSocket, + "handshake ({}B) with connection '#{} {}' OK", + ByteCount, + Connection->Id().Value(), + Connection->RemoteAddr()); + + Connection->SetParser(std::make_unique<WebSocketMessageParser>()); + Connection->SetState(WebSocketState::kConnected); + + ReadMessage(Connection); + }); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser()); + + uint64_t RemainingBytes = Connection->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Connection->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Connection->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + return CloseConnection(Connection, std::error_code()); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Message.SetSocketId(Connection->Id()); + + RouteMessage(std::move(Message)); + } + + ReadMessage(Connection); + } + break; + + default: + break; + }; + }); +} + +void +WsServer::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kRequest: + case WebSocketMessageType::kStreamRequest: + { + CbObjectView Request = RoutedMessage.Body().GetObject(); + std::string_view Method = Request["Method"].AsString(); + bool Handled = false; + bool Error = false; + std::exception Exception; + + if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end()) + { + WebSocketService* Service = It->second; + ZEN_ASSERT(Service); + + try + { + Handled = Service->HandleRequest(std::move(RoutedMessage)); + } + catch (std::exception& Err) + { + Exception = std::move(Err); + Error = true; + } + } + + if (Error || Handled == false) + { + std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method); + + ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText); + + CbObjectWriter Response; + Response << "Error"sv << ErrorText; + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId()); + ResponseMsg.SetSocketId(RoutedMessage.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SendResponse(std::move(ResponseMsg)); + } + } + break; + + case WebSocketMessageType::kNotification: + { + CbObjectView Notification = RoutedMessage.Body().GetObject(); + std::string_view Message = Notification["Message"].AsString(); + + if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end()) + { + std::vector<WebSocketService*>& Handlers = It->second; + + for (WebSocketService* Handler : Handlers) + { + Handler->HandleNotification(RoutedMessage); + } + } + else + { + ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message); + } + } + break; + + default: + break; + }; +} + +void +WsServer::SendMessage(WebSocketMessage&& Msg) +{ + std::shared_ptr<WsConnection> Connection; + + { + std::unique_lock _(m_ConnMutex); + + if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end()) + { + Connection = It->second; + } + } + + if (Connection.get() == nullptr) + { + ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value()); + return; + } + + if (Connection.get() != nullptr) + { + BinaryWriter Writer; + Msg.Save(Writer); + + ZEN_LOG_TRACE(LogWebSocket, + "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}", + ToString(Msg.MessageType()), + Connection->Id().Value(), + Msg.MessageSize(), + Msg.CorrelationId(), + NiceBytes(Writer.Size())); + + { + ZEN_TRACE_CPU("WS::SendMessage"); + std::unique_lock _(Connection->WriteMutex()); + ZEN_TRACE_CPU("WS::WriteSocketData"); + asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size())); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient> +{ +public: + WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {} + + virtual ~WsClient() { Disconnect(); } + + std::shared_ptr<WsClient> AsShared() { return shared_from_this(); } + + virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override; + virtual void Disconnect() override; + virtual bool IsConnected() const override { return false; } + virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); } + + virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override; + virtual void OnNotification(NotificationCallback&& Cb) override; + virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override; + +private: + WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); } + MessageParser* Parser() { return m_MsgParser.get(); } + void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); } + asio::streambuf& ReadBuffer() { return m_ReadBuffer; } + void TriggerEvent(WebSocketEvent Evt); + void ReadMessage(); + void RouteMessage(WebSocketMessage&& RoutedMessage); + + using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>; + + asio::io_context& m_IoCtx; + WebSocketId m_Id; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + std::unique_ptr<MessageParser> m_MsgParser; + asio::streambuf m_ReadBuffer; + EventCallback m_EventCallbacks[3]; + NotificationCallback m_NotificationCallback; + PendingRequestMap m_PendingRequests; + std::mutex m_RequestMutex; + std::promise<bool> m_ConnectPromise; + std::atomic_uint32_t m_State; + std::string m_Host; + int16_t m_Port{}; +}; + +std::future<bool> +WsClient::Connect(const WebSocketConnectInfo& Info) +{ + if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected) + { + return m_ConnectPromise.get_future(); + } + + SetState(WebSocketState::kHandshaking); + + try + { + asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port); + m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol()); + + m_Socket->connect(Endpoint); + + m_Host = m_Socket->remote_endpoint().address().to_string(); + m_Port = Info.Port; + + ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port); + } + catch (std::exception& Err) + { + ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what()); + + SetState(WebSocketState::kError); + m_Socket.reset(); + + TriggerEvent(WebSocketEvent::kDisconnected); + + m_ConnectPromise.set_value(false); + + return m_ConnectPromise.get_future(); + } + + ExtendableStringBuilder<128> Sb; + Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv; + Sb << "Host: " << Info.Host << "\r\n"sv; + Sb << "Upgrade: websocket\r\n"sv; + Sb << "Connection: upgrade\r\n"sv; + Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv; + + if (Info.Protocols.empty() == false) + { + Sb << "Sec-WebSocket-Protocol: "sv; + for (size_t Idx = 0; const auto& Protocol : Info.Protocols) + { + if (Idx++) + { + Sb << ", "; + } + Sb << Protocol; + } + } + + Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv; + Sb << "\r\n"; + + std::string HandshakeRequest = Sb.ToString(); + asio::const_buffer Buffer = asio::buffer(HandshakeRequest); + + ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port); + + m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse); + m_MsgParser->Reset(); + + async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) { + if (Ec) + { + ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + else + { + Self->ReadMessage(); + } + }); + + return m_ConnectPromise.get_future(); +} + +void +WsClient::Disconnect() +{ + if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected) + { + ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port); + + if (m_Socket && m_Socket->is_open()) + { + m_Socket->close(); + m_Socket.reset(); + } + + TriggerEvent(WebSocketEvent::kDisconnected); + + { + std::unique_lock _(m_RequestMutex); + + for (auto& Kv : m_PendingRequests) + { + Kv.second.set_value(WebSocketMessage()); + } + + m_PendingRequests.clear(); + } + } +} + +std::future<WebSocketMessage> +WsClient::SendRequest(WebSocketMessage&& Request) +{ + ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest); + + BinaryWriter Writer; + Request.Save(Writer); + + std::future<WebSocketMessage> FutureResponse; + + { + std::unique_lock _(m_RequestMutex); + + auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>()); + ZEN_ASSERT(Result.second); + + auto It = Result.first; + FutureResponse = It->second.get_future(); + } + + IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size()); + + async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) { + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message()); + + Self->Disconnect(); + } + }); + + return FutureResponse; +} + +void +WsClient::OnNotification(NotificationCallback&& Cb) +{ + m_NotificationCallback = std::move(Cb); +} + +void +WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb) +{ + m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb); +} + +void +WsClient::TriggerEvent(WebSocketEvent Evt) +{ + const uint32_t Index = static_cast<uint32_t>(Evt); + + if (m_EventCallbacks[Index]) + { + m_EventCallbacks[Index](); + } +} + +void +WsClient::ReadMessage() +{ + m_ReadBuffer.prepare(64 << 10); + + async_read(*m_Socket, + m_ReadBuffer, + asio::transfer_at_least(1), + [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable { + const WebSocketState State = Self->State(); + + if (State == WebSocketState::kDisconnected) + { + return; + } + + if (Ec) + { + ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message()); + + return Self->Disconnect(); + } + + switch (State) + { + case WebSocketState::kHandshaking: + { + HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser()); + + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount); + + ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(size_t(Result.ByteCount)); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + return Self->ReadMessage(); + } + + ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone); + + if (Parser.StatusCode() != 101) + { + ZEN_LOG_WARN(LogWsClient, + "handshake FAILED, status '{}', status code '{}'", + Parser.StatusText(), + Parser.StatusCode()); + + Self->m_ConnectPromise.set_value(false); + + return Self->Disconnect(); + } + + ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText()); + + Self->SetParser(std::make_unique<WebSocketMessageParser>()); + Self->SetState(WebSocketState::kConnected); + Self->ReadMessage(); + Self->TriggerEvent(WebSocketEvent::kConnected); + + Self->m_ConnectPromise.set_value(true); + } + break; + + case WebSocketState::kConnected: + { + WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser()); + + uint64_t RemainingBytes = Self->ReadBuffer().size(); + + while (RemainingBytes > 0) + { + MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes); + const ParseMessageResult Result = Parser.ParseMessage(MessageData); + + Self->ReadBuffer().consume(Result.ByteCount); + RemainingBytes = Self->ReadBuffer().size(); + + if (Result.Status == ParseMessageStatus::kError) + { + ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value()); + + Parser.Reset(); + continue; + } + + if (Result.Status == ParseMessageStatus::kContinue) + { + ZEN_ASSERT(RemainingBytes == 0); + continue; + } + + WebSocketMessage Message = Parser.ConsumeMessage(); + Parser.Reset(); + + Self->RouteMessage(std::move(Message)); + } + + Self->ReadMessage(); + } + break; + + default: + break; + } + }); +} + +void +WsClient::RouteMessage(WebSocketMessage&& RoutedMessage) +{ + switch (RoutedMessage.MessageType()) + { + case WebSocketMessageType::kResponse: + { + std::unique_lock _(m_RequestMutex); + + if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end()) + { + It->second.set_value(std::move(RoutedMessage)); + m_PendingRequests.erase(It); + } + else + { + ZEN_LOG_WARN(LogWsClient, + "route request message FAILED, reason 'unknown correlation ID ({})'", + RoutedMessage.CorrelationId()); + } + } + break; + + case WebSocketMessageType::kNotification: + { + std::unique_lock _(m_RequestMutex); + + if (m_NotificationCallback) + { + m_NotificationCallback(std::move(RoutedMessage)); + } + } + break; + + default: + ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType())); + break; + }; +} + +} // namespace zen::websocket + +namespace zen { + +std::atomic_uint32_t WebSocketId::NextId{1}; + +bool +WebSocketMessage::Header::IsValid() const +{ + return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) && + uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount); +} + +std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1}; + +void +WebSocketMessage::SetMessageType(WebSocketMessageType MessageType) +{ + m_Header.MessageType = MessageType; +} + +void +WebSocketMessage::SetBody(CbPackage&& Body) +{ + m_Body = std::move(Body); +} +void +WebSocketMessage::SetBody(CbObject&& Body) +{ + CbPackage Pkg; + Pkg.SetObject(Body); + + SetBody(std::move(Pkg)); +} + +void +WebSocketMessage::Save(BinaryWriter& Writer) +{ + Writer.Write(&m_Header, HeaderSize); + + if (m_Body.has_value()) + { + const CbObject& Obj = m_Body.value().GetObject(); + MemoryView View = Obj.GetBuffer().GetView(); + + const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All); + ZEN_ASSERT(ValidationResult == CbValidateError::None); + + m_Body.value().Save(Writer); + } + + if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest) + { + m_Header.CorrelationId = NextCorrelationId.fetch_add(1); + } + + m_Header.MessageSize = Writer.Size() - HeaderSize; + + Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize)); +} + +bool +WebSocketMessage::TryLoadHeader(MemoryView Memory) +{ + if (Memory.GetSize() < HeaderSize) + { + return false; + } + + MutableMemoryView HeaderView(&m_Header, HeaderSize); + + HeaderView.CopyFrom(Memory); + + return m_Header.IsValid(); +} + +void +WebSocketService::Configure(WebSocketServer& Server) +{ + ZEN_ASSERT(m_SocketServer == nullptr); + + m_SocketServer = &Server; + + RegisterHandlers(Server); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete) +{ + WebSocketMessage Message; + + Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse); + Message.SetCorrelationId(CorrelationId); + Message.SetSocketId(SocketId); + Message.SetBody(std::move(StreamResponse)); + + SocketServer().SendResponse(std::move(Message)); +} + +void +WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete) +{ + CbPackage Response; + Response.SetObject(std::move(StreamResponse)); + + SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete); +} + +std::unique_ptr<WebSocketServer> +WebSocketServer::Create(const WebSocketServerOptions& Options) +{ + return std::make_unique<websocket::WsServer>(Options); +} + +std::shared_ptr<WebSocketClient> +WebSocketClient::Create(asio::io_context& IoCtx) +{ + return std::make_shared<websocket::WsClient>(IoCtx); +} + +} // namespace zen diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua new file mode 100644 index 000000000..b0dbdbc79 --- /dev/null +++ b/src/zenhttp/xmake.lua @@ -0,0 +1,14 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenhttp') + set_kind("static") + add_headerfiles("**.h") + add_files("**.cpp") + add_files("httpsys.cpp", {unity_ignored=true}) + add_includedirs("include", {public=true}) + add_deps("zencore") + add_packages( + "vcpkg::gsl-lite", + "vcpkg::http-parser" + ) + add_options("httpsys") diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp new file mode 100644 index 000000000..4bd6a5697 --- /dev/null +++ b/src/zenhttp/zenhttp.cpp @@ -0,0 +1,22 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/zenhttp.h> + +#if ZEN_WITH_TESTS + +# include <zenhttp/httpclient.h> +# include <zenhttp/httpserver.h> +# include <zenhttp/httpshared.h> + +namespace zen { + +void +zenhttp_forcelinktests() +{ + http_forcelink(); + forcelink_httpshared(); +} + +} // namespace zen + +#endif diff --git a/src/zenserver-test/cachepolicy-tests.cpp b/src/zenserver-test/cachepolicy-tests.cpp new file mode 100644 index 000000000..79d78e522 --- /dev/null +++ b/src/zenserver-test/cachepolicy-tests.cpp @@ -0,0 +1,153 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/zencore.h> + +#if ZEN_WITH_TESTS + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/string.h> +# include <zencore/testing.h> +# include <zencore/uid.h> +# include <zenutil/cache/cachepolicy.h> + +namespace zen::tests { + +using namespace std::literals; + +TEST_CASE("cachepolicy") +{ + SUBCASE("atomics serialization") + { + CachePolicy SomeAtomics[] = {CachePolicy::None, + CachePolicy::QueryLocal, + CachePolicy::StoreRemote, + CachePolicy::SkipData, + CachePolicy::KeepAlive}; + for (CachePolicy Atomic : SomeAtomics) + { + CHECK(ParseCachePolicy(WriteToString<128>(Atomic)) == Atomic); + } + // Also verify that we ignore unrecognized bits + for (CachePolicy Atomic : SomeAtomics) + { + CHECK(ParseCachePolicy(WriteToString<128>(Atomic | (CachePolicy)0x10000000)) == Atomic); + } + } + SUBCASE("aliases serialization") + { + CachePolicy SomeAliases[] = {CachePolicy::Query, CachePolicy::Local}; + for (CachePolicy Alias : SomeAliases) + { + CHECK(ParseCachePolicy(WriteToString<128>(Alias)) == Alias); + } + // Also verify that we ignore unrecognized bits + for (CachePolicy Alias : SomeAliases) + { + CHECK(ParseCachePolicy(WriteToString<128>(Alias | (CachePolicy)0x10000000)) == Alias); + } + } + SUBCASE("aliases take priority over atomics") + { + CHECK(WriteToString<128>(CachePolicy::Default).ToView() == "Default"sv); + CHECK(WriteToString<128>(CachePolicy::Query).ToView() == "Query"sv); + CHECK(WriteToString<128>(CachePolicy::Local).ToView() == "Local"sv); + } + SUBCASE("policies requiring multiple strings work") + { + char Delimiter = ','; + CachePolicy Combination = CachePolicy::SkipData | CachePolicy::QueryLocal; + CHECK(WriteToString<128>(Combination).ToView().find(Delimiter) != std::string_view::npos); + CHECK(ParseCachePolicy(WriteToString<128>(Combination)) == Combination); + } + SUBCASE("parsing invalid text") + { + CHECK(ParseCachePolicy(",,,") == CachePolicy::None); + CHECK(ParseCachePolicy("fee,fie,foo,fum") == CachePolicy::None); + CHECK(ParseCachePolicy("fee,KeepAlive,foo,fum") == CachePolicy::KeepAlive); + } +} + +TEST_CASE("cacherecordpolicy") +{ + SUBCASE("policy with no values") + { + CachePolicy Policy = CachePolicy::SkipData | CachePolicy::QueryLocal | CachePolicy::PartialRecord; + CachePolicy ValuePolicy = Policy & CacheValuePolicy::PolicyMask; + CacheRecordPolicy RecordPolicy; + CacheRecordPolicyBuilder Builder(Policy); + RecordPolicy = Builder.Build(); + SUBCASE("construct") + { + CHECK(RecordPolicy.IsUniform()); + CHECK(RecordPolicy.GetRecordPolicy() == Policy); + CHECK(RecordPolicy.GetBasePolicy() == Policy); + CHECK(RecordPolicy.GetValuePolicy(Oid::NewOid()) == ValuePolicy); + CHECK(RecordPolicy.GetValuePolicies().size() == 0); + } + SUBCASE("saveload") + { + CbWriter Writer; + RecordPolicy.Save(Writer); + CbObject Saved = Writer.Save()->AsObject(); + CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get(); + CHECK(Loaded.IsUniform()); + CHECK(Loaded.GetRecordPolicy() == Policy); + CHECK(Loaded.GetBasePolicy() == Policy); + CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == ValuePolicy); + CHECK(Loaded.GetValuePolicies().size() == 0); + } + } + + SUBCASE("policy with values") + { + CachePolicy DefaultPolicy = CachePolicy::StoreRemote | CachePolicy::QueryLocal | CachePolicy::PartialRecord; + CachePolicy DefaultValuePolicy = DefaultPolicy & CacheValuePolicy::PolicyMask; + CachePolicy PartialOverlap = CachePolicy::StoreRemote; + CachePolicy NoOverlap = CachePolicy::QueryRemote; + CachePolicy UnionPolicy = DefaultPolicy | PartialOverlap | NoOverlap | CachePolicy::PartialRecord; + + CacheRecordPolicy RecordPolicy; + CacheRecordPolicyBuilder Builder(DefaultPolicy); + Oid PartialOid = Oid::NewOid(); + Oid NoOverlapOid = Oid::NewOid(); + Oid OtherOid = Oid::NewOid(); + Builder.AddValuePolicy(PartialOid, PartialOverlap); + Builder.AddValuePolicy(NoOverlapOid, NoOverlap); + RecordPolicy = Builder.Build(); + SUBCASE("construct") + { + CHECK(!RecordPolicy.IsUniform()); + CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); + CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); + CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); + CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); + CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); + CHECK(RecordPolicy.GetValuePolicies().size() == 2); + } + SUBCASE("saveload") + { + CbWriter Writer; + RecordPolicy.Save(Writer); + CbObject Saved = Writer.Save()->AsObject(); + CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get(); + CHECK(!RecordPolicy.IsUniform()); + CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy); + CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy); + CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap); + CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap); + CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy); + CHECK(RecordPolicy.GetValuePolicies().size() == 2); + } + } + + SUBCASE("parsing invalid text") + { + OptionalCacheRecordPolicy Loaded = CacheRecordPolicy::Load(CbObject()); + CHECK(Loaded.IsNull()); + } +} + +} // namespace zen::tests + +#endif diff --git a/src/zenserver-test/projectclient.cpp b/src/zenserver-test/projectclient.cpp new file mode 100644 index 000000000..597838e0d --- /dev/null +++ b/src/zenserver-test/projectclient.cpp @@ -0,0 +1,164 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "projectclient.h" + +#if 0 + +# include <zencore/compactbinary.h> +# include <zencore/logging.h> +# include <zencore/sharedbuffer.h> +# include <zencore/string.h> +# include <zencore/zencore.h> + +# include <asio.hpp> +# include <gsl/gsl-lite.hpp> + +# if ZEN_PLATFORM_WINDOWS +# include <atlbase.h> +# endif + +namespace zen { + +struct ProjectClientConnection +{ + ProjectClientConnection(int BasePort) { Connect(BasePort); } + + void Connect(int BasePort) + { + ZEN_UNUSED(BasePort); + + WideStringBuilder<64> PipeName; + PipeName << "\\\\.\\pipe\\zenprj"; // TODO: this should use an instance-specific identifier! + + HANDLE hPipe = CreateFileW(PipeName.c_str(), + GENERIC_READ | GENERIC_WRITE, + 0, // Sharing doesn't make any sense + nullptr, // No security attributes + OPEN_EXISTING, // Open existing pipe + 0, // Attributes + nullptr // Template file + ); + + if (hPipe == INVALID_HANDLE_VALUE) + { + ZEN_WARN("failed while creating named pipe {}", WideToUtf8(PipeName)); + + throw std::system_error(GetLastError(), std::system_category(), fmt::format("Failed to open named pipe '{}'", WideToUtf8(PipeName))); + } + + // Change to message mode + DWORD dwMode = PIPE_READMODE_MESSAGE; + BOOL Success = SetNamedPipeHandleState(hPipe, &dwMode, nullptr, nullptr); + + if (!Success) + { + throw std::system_error(GetLastError(), + std::system_category(), + fmt::format("Failed to change named pipe '{}' to message mode", WideToUtf8(PipeName))); + } + + m_hPipe.Attach(hPipe); // This now owns the handle and will close it + } + + ~ProjectClientConnection() {} + + CbObject MessageTransaction(CbObject Request) + { + DWORD dwWrittenBytes = 0; + + MemoryView View = Request.GetView(); + + BOOL Success = ::WriteFile(m_hPipe, View.GetData(), gsl::narrow_cast<DWORD>(View.GetSize()), &dwWrittenBytes, nullptr); + + if (!Success) + { + throw std::system_error(GetLastError(), std::system_category(), "Failed to write pipe message"); + } + + ZEN_ASSERT(dwWrittenBytes == View.GetSize()); + + DWORD dwReadBytes = 0; + + Success = ReadFile(m_hPipe, m_Buffer, sizeof m_Buffer, &dwReadBytes, nullptr); + + if (!Success) + { + DWORD ErrorCode = GetLastError(); + + if (ERROR_MORE_DATA == ErrorCode) + { + // Response message is larger than our buffer - handle it by allocating a larger + // buffer on the heap and read the remainder into that buffer + + DWORD dwBytesAvail = 0, dwLeftThisMessage = 0; + + Success = PeekNamedPipe(m_hPipe, nullptr, 0, nullptr, &dwBytesAvail, &dwLeftThisMessage); + + if (Success) + { + UniqueBuffer MessageBuffer = UniqueBuffer::Alloc(dwReadBytes + dwLeftThisMessage); + + memcpy(MessageBuffer.GetData(), m_Buffer, dwReadBytes); + + Success = ReadFile(m_hPipe, + reinterpret_cast<uint8_t*>(MessageBuffer.GetData()) + dwReadBytes, + dwLeftThisMessage, + &dwReadBytes, + nullptr); + + if (Success) + { + return CbObject(SharedBuffer(std::move(MessageBuffer))); + } + } + } + + throw std::system_error(GetLastError(), std::system_category(), "Failed to read pipe message"); + } + + return CbObject(SharedBuffer::MakeView(MakeMemoryView(m_Buffer))); + } + +private: + static const int kEmbeddedBufferSize = 512 - 16; + + CHandle m_hPipe; + uint8_t m_Buffer[kEmbeddedBufferSize]; +}; + +struct LocalProjectClient::ClientImpl +{ + ClientImpl(int BasePort) : m_BasePort(BasePort) {} + ~ClientImpl() {} + + void Start() {} + void Stop() {} + + inline int BasePort() const { return m_BasePort; } + +private: + int m_BasePort = 0; +}; + +LocalProjectClient::LocalProjectClient(int BasePort) +{ + m_Impl = std::make_unique<ClientImpl>(BasePort); + m_Impl->Start(); +} + +LocalProjectClient::~LocalProjectClient() +{ + m_Impl->Stop(); +} + +CbObject +LocalProjectClient::MessageTransaction(CbObject Request) +{ + ProjectClientConnection Cx(m_Impl->BasePort()); + + return Cx.MessageTransaction(Request); +} + +} // namespace zen + +#endif // 0 diff --git a/src/zenserver-test/projectclient.h b/src/zenserver-test/projectclient.h new file mode 100644 index 000000000..1865dd67d --- /dev/null +++ b/src/zenserver-test/projectclient.h @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <memory> + +#include <zencore/compactbinary.h> +#include <zencore/refcount.h> + +namespace zen { + +/** + * Client for communication with local project service + * + * This is WIP and not yet functional! + */ + +class LocalProjectClient : public RefCounted +{ +public: + LocalProjectClient(int BasePort = 0); + ~LocalProjectClient(); + + CbObject MessageTransaction(CbObject Request); + +private: + struct ClientImpl; + + std::unique_ptr<ClientImpl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua new file mode 100644 index 000000000..f0b34f6ca --- /dev/null +++ b/src/zenserver-test/xmake.lua @@ -0,0 +1,16 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zenserver-test") + set_kind("binary") + add_headerfiles("**.h") + add_files("*.cpp") + add_files("zenserver-test.cpp", {unity_ignored = true }) + add_deps("zencore", "zenutil", "zenhttp") + add_deps("zenserver", {inherit=false}) + add_packages("vcpkg::http-parser", "vcpkg::mimalloc") + + if is_plat("macosx") then + add_ldflags("-framework CoreFoundation") + add_ldflags("-framework Security") + add_ldflags("-framework SystemConfiguration") + end diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp new file mode 100644 index 000000000..3195181d1 --- /dev/null +++ b/src/zenserver-test/zenserver-test.cpp @@ -0,0 +1,3323 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#define _SILENCE_CXX17_C_HEADER_DEPRECATION_WARNING + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/refcount.h> +#include <zencore/stream.h> +#include <zencore/string.h> +#include <zencore/testutils.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/xxhash.h> +#include <zenhttp/httpclient.h> +#include <zenhttp/httpshared.h> +#include <zenhttp/websocket.h> +#include <zenhttp/zenhttp.h> +#include <zenutil/cache/cache.h> +#include <zenutil/cache/cacherequests.h> +#include <zenutil/zenserverprocess.h> + +#if ZEN_USE_MIMALLOC +ZEN_THIRD_PARTY_INCLUDES_START +# include <mimalloc.h> +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +#include <http_parser.h> + +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "Crypt32.lib") +# pragma comment(lib, "Wldap32.lib") +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#undef GetObject +ZEN_THIRD_PARTY_INCLUDES_END + +#include <atomic> +#include <filesystem> +#include <map> +#include <random> +#include <span> +#include <thread> +#include <typeindex> +#include <unordered_map> + +#if ZEN_PLATFORM_WINDOWS +# include <ppl.h> +# include <atlbase.h> +# include <process.h> +#endif + +#include <asio.hpp> + +////////////////////////////////////////////////////////////////////////// + +#include "projectclient.h" + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS +# define ZEN_TEST_WITH_RUNNER 1 +# include <zencore/testing.h> +# include <zencore/workthreadpool.h> +#endif + +using namespace std::literals; + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +struct Concurrency +{ + template<typename... T> + static void parallel_invoke(T&&... t) + { + constexpr size_t NumTs = sizeof...(t); + std::thread Threads[NumTs] = { + std::thread(std::forward<T>(t))..., + }; + + for (std::thread& Thread : Threads) + { + Thread.join(); + } + } +}; +#endif + +////////////////////////////////////////////////////////////////////////// +// +// Custom logging -- test code, this should be tweaked +// + +namespace logging { +using namespace spdlog; +using namespace spdlog::details; +using namespace std::literals; + +class full_test_formatter final : public spdlog::formatter +{ +public: + full_test_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId) + { + } + + virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<full_test_formatter>(m_LogId, m_Epoch); } + + static constexpr bool UseDate = false; + + virtual void format(const details::log_msg& msg, memory_buf_t& dest) override + { + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::seconds; + + if constexpr (UseDate) + { + auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch()); + if (secs != m_LastLogSecs) + { + m_CachedTm = os::localtime(log_clock::to_time_t(msg.time)); + m_LastLogSecs = secs; + } + } + + const auto& tm_time = m_CachedTm; + + // cache the date/time part for the next second. + auto duration = msg.time - m_Epoch; + auto secs = duration_cast<seconds>(duration); + + if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0) + { + m_CachedDatetime.clear(); + m_CachedDatetime.push_back('['); + + if constexpr (UseDate) + { + fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); + m_CachedDatetime.push_back(' '); + + fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); + } + else + { + int Count = int(secs.count()); + + const int LogSecs = Count % 60; + Count /= 60; + + const int LogMins = Count % 60; + Count /= 60; + + const int LogHours = Count; + + fmt_helper::pad2(LogHours, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + fmt_helper::pad2(LogMins, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + fmt_helper::pad2(LogSecs, m_CachedDatetime); + } + + m_CachedDatetime.push_back('.'); + + m_CacheTimestamp = secs; + } + + dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); + + auto millis = fmt_helper::time_fraction<milliseconds>(msg.time); + fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest); + dest.push_back(']'); + dest.push_back(' '); + + if (!m_LogId.empty()) + { + dest.push_back('['); + fmt_helper::append_string_view(m_LogId, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + // append logger name if exists + if (msg.logger_name.size() > 0) + { + dest.push_back('['); + fmt_helper::append_string_view(msg.logger_name, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + dest.push_back('['); + // wrap the level name with color + msg.color_range_start = dest.size(); + fmt_helper::append_string_view(level::to_string_view(msg.level), dest); + msg.color_range_end = dest.size(); + dest.push_back(']'); + dest.push_back(' '); + + // add source location if present + if (!msg.source.empty()) + { + dest.push_back('['); + const char* filename = details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename); + fmt_helper::append_string_view(filename, dest); + dest.push_back(':'); + fmt_helper::append_int(msg.source.line, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + fmt_helper::append_string_view(msg.payload, dest); + fmt_helper::append_string_view("\n"sv, dest); + } + +private: + std::chrono::time_point<std::chrono::system_clock> m_Epoch; + std::tm m_CachedTm; + std::chrono::seconds m_LastLogSecs; + std::chrono::seconds m_CacheTimestamp{0}; + memory_buf_t m_CachedDatetime; + std::string m_LogId; +}; +} // namespace logging + +////////////////////////////////////////////////////////////////////////// + +#if 0 + +int +main() +{ + mi_version(); + + zen::Sleep(1000); + + zen::Stopwatch timer; + + const int RequestCount = 100000; + + cpr::Session Sessions[10]; + + for (auto& Session : Sessions) + { + Session.SetUrl(cpr::Url{"http://localhost:1337/test/hello"}); + //Session.SetUrl(cpr::Url{ "http://arn-wd-l0182:1337/test/hello" }); + } + + auto Run = [](cpr::Session& Session) { + for (int i = 0; i < 10000; ++i) + { + cpr::Response Result = Session.Get(); + + if (Result.status_code != 200) + { + ZEN_WARN("request response: {}", Result.status_code); + } + } + }; + + Concurrency::parallel_invoke([&] { Run(Sessions[0]); }, + [&] { Run(Sessions[1]); }, + [&] { Run(Sessions[2]); }, + [&] { Run(Sessions[3]); }, + [&] { Run(Sessions[4]); }, + [&] { Run(Sessions[5]); }, + [&] { Run(Sessions[6]); }, + [&] { Run(Sessions[7]); }, + [&] { Run(Sessions[8]); }, + [&] { Run(Sessions[9]); }); + + // cpr::Response r = cpr::Get(cpr::Url{ "http://localhost:1337/test/hello" }); + + ZEN_INFO("{} requests in {} ({})", + RequestCount, + zen::NiceTimeSpanMs(timer.GetElapsedTimeMs()), + zen::NiceRate(RequestCount, (uint32_t)timer.GetElapsedTimeMs(), "req")); + + return 0; +} +#elif 0 +// #include <restinio/all.hpp> + +int +main() +{ + mi_version(); + restinio::run(restinio::on_thread_pool(32).port(8080).request_handler( + [](auto req) { return req->create_response().set_body("Hello, World!").done(); })); + return 0; +} +#elif ZEN_WITH_TESTS + +zen::ZenServerEnvironment TestEnv; + +int +main(int argc, char** argv) +{ + using namespace std::literals; + +# if ZEN_USE_MIMALLOC + mi_version(); +# endif + + zen::zencore_forcelinktests(); + zen::zenhttp_forcelinktests(); + zen::cacherequests_forcelink(); + + zen::logging::InitializeLogging(); + + spdlog::set_level(spdlog::level::debug); + spdlog::set_formatter(std::make_unique< ::logging::full_test_formatter>("test", std::chrono::system_clock::now())); + + std::filesystem::path ProgramBaseDir = std::filesystem::path(argv[0]).parent_path(); + std::filesystem::path TestBaseDir = ProgramBaseDir.parent_path().parent_path() / ".test"; + + // This is pretty janky because we're passing most of the options through to the test + // framework, so we can't just use cxxopts (I think). This should ideally be cleaned up + // somehow in the future + + std::string ServerClass; + + for (int i = 1; i < argc; ++i) + { + if (argv[i] == "--http"sv) + { + if ((i + 1) < argc) + { + ServerClass = argv[++i]; + } + } + } + + TestEnv.InitializeForTest(ProgramBaseDir, TestBaseDir, ServerClass); + + ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir); + + zen::testing::TestRunner Runner; + Runner.ApplyCommandLine(argc, argv); + + return Runner.Run(); +} + +namespace zen::tests { + +TEST_CASE("default.single") +{ + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + ZenServerInstance Instance(TestEnv); + Instance.SetTestDir(TestDir); + Instance.SpawnServer(13337); + + ZEN_INFO("Waiting..."); + + Instance.WaitUntilReady(); + + std::atomic<uint64_t> RequestCount{0}; + std::atomic<uint64_t> BatchCounter{0}; + + ZEN_INFO("Running single server test..."); + + auto IssueTestRequests = [&] { + const uint64_t BatchNo = BatchCounter.fetch_add(1); + const int ThreadId = zen::GetCurrentThreadId(); + + ZEN_INFO("query batch {} started (thread {})", BatchNo, ThreadId); + cpr::Session cli; + cli.SetUrl(cpr::Url{"http://localhost:13337/test/hello"}); + + for (int i = 0; i < 10000; ++i) + { + auto res = cli.Get(); + ++RequestCount; + } + ZEN_INFO("query batch {} ended (thread {})", BatchNo, ThreadId); + }; + + zen::Stopwatch timer; + + Concurrency::parallel_invoke(IssueTestRequests, + IssueTestRequests, + IssueTestRequests, + IssueTestRequests, + IssueTestRequests, + IssueTestRequests, + IssueTestRequests, + IssueTestRequests, + IssueTestRequests, + IssueTestRequests); + + uint64_t Elapsed = timer.GetElapsedTimeMs(); + + ZEN_INFO("{} requests in {} ({})", RequestCount, zen::NiceTimeSpanMs(Elapsed), zen::NiceRate(RequestCount, (uint32_t)Elapsed, "req")); +} + +TEST_CASE("multi.basic") +{ + ZenServerInstance Instance1(TestEnv); + std::filesystem::path TestDir1 = TestEnv.CreateNewTestDir(); + Instance1.SetTestDir(TestDir1); + Instance1.SpawnServer(13337); + + ZenServerInstance Instance2(TestEnv); + std::filesystem::path TestDir2 = TestEnv.CreateNewTestDir(); + Instance2.SetTestDir(TestDir2); + Instance2.SpawnServer(13338); + + ZEN_INFO("Waiting..."); + + Instance1.WaitUntilReady(); + Instance2.WaitUntilReady(); + + std::atomic<uint64_t> RequestCount{0}; + std::atomic<uint64_t> BatchCounter{0}; + + auto IssueTestRequests = [&](int PortNumber) { + const uint64_t BatchNo = BatchCounter.fetch_add(1); + const int ThreadId = zen::GetCurrentThreadId(); + + ZEN_INFO("query batch {} started (thread {}) for port {}", BatchNo, ThreadId, PortNumber); + + cpr::Session cli; + cli.SetUrl(cpr::Url{fmt::format("http://localhost:{}/test/hello", PortNumber)}); + + for (int i = 0; i < 10000; ++i) + { + auto res = cli.Get(); + ++RequestCount; + } + ZEN_INFO("query batch {} ended (thread {})", BatchNo, ThreadId); + }; + + zen::Stopwatch timer; + + ZEN_INFO("Running multi-server test..."); + + Concurrency::parallel_invoke([&] { IssueTestRequests(13337); }, + [&] { IssueTestRequests(13338); }, + [&] { IssueTestRequests(13337); }, + [&] { IssueTestRequests(13338); }); + + uint64_t Elapsed = timer.GetElapsedTimeMs(); + + ZEN_INFO("{} requests in {} ({})", RequestCount, zen::NiceTimeSpanMs(Elapsed), zen::NiceRate(RequestCount, (uint32_t)Elapsed, "req")); +} + +TEST_CASE("project.basic") +{ + using namespace std::literals; + + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + const uint16_t PortNumber = 13337; + + ZenServerInstance Instance1(TestEnv); + Instance1.SetTestDir(TestDir); + Instance1.SpawnServer(PortNumber); + Instance1.WaitUntilReady(); + + std::atomic<uint64_t> RequestCount{0}; + + zen::Stopwatch timer; + + std::mt19937_64 mt; + + zen::StringBuilder<64> BaseUri; + BaseUri << fmt::format("http://localhost:{}/prj/test", PortNumber); + + std::filesystem::path BinPath = zen::GetRunningExecutablePath(); + std::filesystem::path RootPath = BinPath.parent_path().parent_path(); + BinPath = BinPath.lexically_relative(RootPath); + + SUBCASE("build store init") + { + { + { + zen::CbObjectWriter Body; + Body << "id" + << "test"; + Body << "root" << RootPath.c_str(); + Body << "project" + << "/zooom"; + Body << "engine" + << "/zooom"; + + zen::BinaryWriter MemOut; + Body.Save(MemOut); + + auto Response = cpr::Post(cpr::Url{BaseUri.c_str()}, cpr::Body{(const char*)MemOut.Data(), MemOut.Size()}); + CHECK(Response.status_code == 201); + } + + { + auto Response = cpr::Get(cpr::Url{BaseUri.c_str()}); + CHECK(Response.status_code == 200); + + zen::CbObjectView ResponseObject = zen::CbFieldView(Response.text.data()).AsObjectView(); + + CHECK(ResponseObject["id"].AsString() == "test"sv); + CHECK(ResponseObject["root"].AsString() == PathToUtf8(RootPath.c_str())); + } + } + + BaseUri << "/oplog/foobar"; + + { + { + zen::StringBuilder<64> PostUri; + PostUri << BaseUri; + auto Response = cpr::Post(cpr::Url{PostUri.c_str()}); + CHECK(Response.status_code == 201); + } + + { + auto Response = cpr::Get(cpr::Url{BaseUri.c_str()}); + CHECK(Response.status_code == 200); + + zen::CbObjectView ResponseObject = zen::CbFieldView(Response.text.data()).AsObjectView(); + + CHECK(ResponseObject["id"].AsString() == "foobar"sv); + CHECK(ResponseObject["project"].AsString() == "test"sv); + } + } + + SUBCASE("build store persistence") + { + uint8_t AttachData[] = {1, 2, 3}; + + zen::CompressedBuffer Attachment = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone(zen::MemoryView{AttachData, 3})); + zen::CbAttachment Attach{Attachment, Attachment.DecodeRawHash()}; + + zen::CbObjectWriter OpWriter; + OpWriter << "key" + << "foo" + << "attachment" << Attach; + + const std::string_view ChunkId{ + "00000000" + "00000000" + "00010000"}; + auto FileOid = zen::Oid::FromHexString(ChunkId); + + OpWriter.BeginArray("files"); + OpWriter.BeginObject(); + OpWriter << "id" << FileOid; + OpWriter << "clientpath" + << "/{engine}/client/side/path"; + OpWriter << "serverpath" << BinPath.c_str(); + OpWriter.EndObject(); + OpWriter.EndArray(); + + zen::CbObject Op = OpWriter.Save(); + + zen::CbPackage OpPackage(Op); + OpPackage.AddAttachment(Attach); + + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); + + { + zen::StringBuilder<64> PostUri; + PostUri << BaseUri << "/new"; + auto Response = cpr::Post(cpr::Url{PostUri.c_str()}, cpr::Body{(const char*)MemOut.Data(), MemOut.Size()}); + + REQUIRE(!Response.error); + CHECK(Response.status_code == 201); + } + + // Read file data + + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << BaseUri << "/" << ChunkId; + auto Response = cpr::Get(cpr::Url{ChunkGetUri.c_str()}); + + REQUIRE(!Response.error); + CHECK(Response.status_code == 200); + } + + { + zen::StringBuilder<128> ChunkGetUri; + ChunkGetUri << BaseUri << "/" << ChunkId << "?offset=1&size=10"; + auto Response = cpr::Get(cpr::Url{ChunkGetUri.c_str()}); + + REQUIRE(!Response.error); + CHECK(Response.status_code == 200); + CHECK(Response.text.size() == 10); + } + + ZEN_INFO("+++++++"); + } + SUBCASE("build store op commit") { ZEN_INFO("-------"); } + SUBCASE("test chunk not found error") + { + for (size_t I = 0; I < 65; I++) + { + zen::StringBuilder<128> PostUri; + PostUri << BaseUri << "/f77c781846caead318084604/info"; + auto Response = cpr::Get(cpr::Url{PostUri.c_str()}); + + REQUIRE(!Response.error); + CHECK(Response.status_code == 404); + } + } + } + + const uint64_t Elapsed = timer.GetElapsedTimeMs(); + + ZEN_INFO("{} requests in {} ({})", RequestCount, zen::NiceTimeSpanMs(Elapsed), zen::NiceRate(RequestCount, (uint32_t)Elapsed, "req")); +} + +# if 0 // this is extremely WIP +TEST_CASE("project.pipe") +{ + using namespace std::literals; + + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + const uint16_t PortNumber = 13337; + + ZenServerInstance Instance1(TestEnv); + Instance1.SetTestDir(TestDir); + Instance1.SpawnServer(PortNumber); + Instance1.WaitUntilReady(); + + zen::LocalProjectClient LocalClient(PortNumber); + + zen::CbObjectWriter Cbow; + Cbow << "hey" << 42; + + zen::CbObject Response = LocalClient.MessageTransaction(Cbow.Save()); +} +# endif + +namespace utils { + + struct ZenConfig + { + std::filesystem::path DataDir; + uint16_t Port; + std::string BaseUri; + std::string Args; + + static ZenConfig New(uint16_t Port = 13337, std::string Args = "") + { + return ZenConfig{.DataDir = TestEnv.CreateNewTestDir(), + .Port = Port, + .BaseUri = fmt::format("http://localhost:{}/z$", Port), + .Args = std::move(Args)}; + } + + static ZenConfig NewWithUpstream(uint16_t UpstreamPort) + { + return New(13337, fmt::format("--debug --upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", UpstreamPort)); + } + + static ZenConfig NewWithThreadedUpstreams(std::span<uint16_t> UpstreamPorts, bool Debug) + { + std::string Args = Debug ? "--debug" : ""; + for (uint16_t Port : UpstreamPorts) + { + Args = fmt::format("{}{}--upstream-zen-url=http://localhost:{}", Args, Args.length() > 0 ? " " : "", Port); + } + return New(13337, Args); + } + + void Spawn(ZenServerInstance& Inst) + { + Inst.SetTestDir(DataDir); + Inst.SpawnServer(Port, Args); + Inst.WaitUntilReady(); + } + }; + + void SpawnServer(ZenServerInstance& Server, ZenConfig& Cfg) + { + Server.SetTestDir(Cfg.DataDir); + Server.SpawnServer(Cfg.Port, Cfg.Args); + Server.WaitUntilReady(); + } + +} // namespace utils + +TEST_CASE("zcache.basic") +{ + using namespace std::literals; + + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + + const uint16_t PortNumber = 13337; + + const int kIterationCount = 100; + const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber); + + auto HashKey = [](int i) -> zen::IoHash { return zen::IoHash::HashBuffer(&i, sizeof i); }; + + { + ZenServerInstance Instance1(TestEnv); + Instance1.SetTestDir(TestDir); + Instance1.SpawnServer(PortNumber); + Instance1.WaitUntilReady(); + + // Populate with some simple data + + for (int i = 0; i < kIterationCount; ++i) + { + zen::CbObjectWriter Cbo; + Cbo << "index" << i; + + zen::BinaryWriter MemOut; + Cbo.Save(MemOut); + + zen::IoHash Key = HashKey(i); + + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "test", Key)}, + cpr::Body{(const char*)MemOut.Data(), MemOut.Size()}, + cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + + CHECK(Result.status_code == 201); + } + + // Retrieve data + + for (int i = 0; i < kIterationCount; ++i) + { + zen::IoHash Key = zen::IoHash::HashBuffer(&i, sizeof i); + + cpr::Response Result = + cpr::Get(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "test", Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + + CHECK(Result.status_code == 200); + } + + // Ensure bad bucket identifiers are rejected + + { + zen::CbObjectWriter Cbo; + Cbo << "index" << 42; + + zen::BinaryWriter MemOut; + Cbo.Save(MemOut); + + zen::IoHash Key = HashKey(442); + + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "te!st", Key)}, + cpr::Body{(const char*)MemOut.Data(), MemOut.Size()}, + cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + + CHECK(Result.status_code == 400); + } + } + + // Verify that the data persists between process runs (the previous server has exited at this point) + + { + ZenServerInstance Instance1(TestEnv); + Instance1.SetTestDir(TestDir); + Instance1.SpawnServer(PortNumber); + Instance1.WaitUntilReady(); + + // Retrieve data again + + for (int i = 0; i < kIterationCount; ++i) + { + zen::IoHash Key = HashKey(i); + + cpr::Response Result = + cpr::Get(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "test", Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + + CHECK(Result.status_code == 200); + } + } +} + +TEST_CASE("zcache.cbpackage") +{ + using namespace std::literals; + + auto CreateTestPackage = [](zen::IoHash& OutAttachmentKey) -> zen::CbPackage { + auto Data = zen::SharedBuffer::Clone(zen::MakeMemoryView<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9})); + auto CompressedData = zen::CompressedBuffer::Compress(Data); + + OutAttachmentKey = CompressedData.DecodeRawHash(); + + zen::CbWriter Obj; + Obj.BeginObject("obj"sv); + Obj.AddBinaryAttachment("data", OutAttachmentKey); + Obj.EndObject(); + + zen::CbPackage Package; + Package.SetObject(Obj.Save().AsObject()); + Package.AddAttachment(zen::CbAttachment(CompressedData, OutAttachmentKey)); + + return Package; + }; + + auto SerializeToBuffer = [](zen::CbPackage Package) -> zen::IoBuffer { + zen::BinaryWriter MemStream; + + Package.Save(MemStream); + + return zen::IoBuffer(zen::IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + }; + + auto IsEqual = [](zen::CbPackage Lhs, zen::CbPackage Rhs) -> bool { + std::span<const zen::CbAttachment> LhsAttachments = Lhs.GetAttachments(); + std::span<const zen::CbAttachment> RhsAttachments = Rhs.GetAttachments(); + + if (LhsAttachments.size() != RhsAttachments.size()) + { + return false; + } + + for (const zen::CbAttachment& LhsAttachment : LhsAttachments) + { + const zen::CbAttachment* RhsAttachment = Rhs.FindAttachment(LhsAttachment.GetHash()); + CHECK(RhsAttachment); + + zen::SharedBuffer LhsBuffer = LhsAttachment.AsCompressedBinary().Decompress(); + CHECK(!LhsBuffer.IsNull()); + + zen::SharedBuffer RhsBuffer = RhsAttachment->AsCompressedBinary().Decompress(); + CHECK(!RhsBuffer.IsNull()); + + if (!LhsBuffer.GetView().EqualBytes(RhsBuffer.GetView())) + { + return false; + } + } + + return true; + }; + + SUBCASE("PUT/GET returns correct package") + { + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber); + + ZenServerInstance Instance1(TestEnv); + Instance1.SetTestDir(TestDir); + Instance1.SpawnServer(PortNumber); + Instance1.WaitUntilReady(); + + const std::string_view Bucket = "mosdef"sv; + zen::IoHash Key; + zen::CbPackage ExpectedPackage = CreateTestPackage(Key); + + // PUT + { + zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", BaseUri, Bucket, Key)}, + cpr::Body{(const char*)Body.Data(), Body.Size()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 201); + } + + // GET + { + cpr::Response Result = + cpr::Get(cpr::Url{fmt::format("{}/{}/{}", BaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + + zen::IoBuffer Response(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()); + + zen::CbPackage Package; + const bool Ok = Package.TryLoad(Response); + CHECK(Ok); + CHECK(IsEqual(Package, ExpectedPackage)); + } + } + + SUBCASE("PUT propagates upstream") + { + // Setup local and remote server + std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir(); + std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir(); + const uint16_t LocalPortNumber = 13337; + const uint16_t RemotePortNumber = 13338; + + const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber); + const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber); + + ZenServerInstance RemoteInstance(TestEnv); + RemoteInstance.SetTestDir(RemoteDataDir); + RemoteInstance.SpawnServer(RemotePortNumber); + RemoteInstance.WaitUntilReady(); + + ZenServerInstance LocalInstance(TestEnv); + LocalInstance.SetTestDir(LocalDataDir); + LocalInstance.SpawnServer(LocalPortNumber, + fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber)); + LocalInstance.WaitUntilReady(); + + const std::string_view Bucket = "mosdef"sv; + zen::IoHash Key; + zen::CbPackage ExpectedPackage = CreateTestPackage(Key); + + // Store the cache record package in the local instance + { + zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", LocalBaseUri, Bucket, Key)}, + cpr::Body{(const char*)Body.Data(), Body.Size()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + + CHECK(Result.status_code == 201); + } + + // The cache record can be retrieved as a package from the local instance + { + cpr::Response Result = + cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalBaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + + zen::IoBuffer Body(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()); + zen::CbPackage Package; + const bool Ok = Package.TryLoad(Body); + CHECK(Ok); + CHECK(IsEqual(Package, ExpectedPackage)); + } + + // The cache record can be retrieved as a package from the remote instance + { + cpr::Response Result = + cpr::Get(cpr::Url{fmt::format("{}/{}/{}", RemoteBaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + + zen::IoBuffer Body(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()); + zen::CbPackage Package; + const bool Ok = Package.TryLoad(Body); + CHECK(Ok); + CHECK(IsEqual(Package, ExpectedPackage)); + } + } + + SUBCASE("GET finds upstream when missing in local") + { + // Setup local and remote server + std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir(); + std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir(); + const uint16_t LocalPortNumber = 13337; + const uint16_t RemotePortNumber = 13338; + + const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber); + const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber); + + ZenServerInstance RemoteInstance(TestEnv); + RemoteInstance.SetTestDir(RemoteDataDir); + RemoteInstance.SpawnServer(RemotePortNumber); + RemoteInstance.WaitUntilReady(); + + ZenServerInstance LocalInstance(TestEnv); + LocalInstance.SetTestDir(LocalDataDir); + LocalInstance.SpawnServer(LocalPortNumber, + fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber)); + LocalInstance.WaitUntilReady(); + + const std::string_view Bucket = "mosdef"sv; + zen::IoHash Key; + zen::CbPackage ExpectedPackage = CreateTestPackage(Key); + + // Store the cache record package in upstream cache + { + zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage); + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", RemoteBaseUri, Bucket, Key)}, + cpr::Body{(const char*)Body.Data(), Body.Size()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + + CHECK(Result.status_code == 201); + } + + // The cache record can be retrieved as a package from the local cache + { + cpr::Response Result = + cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalBaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + + zen::IoBuffer Body(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()); + zen::CbPackage Package; + const bool Ok = Package.TryLoad(Body); + CHECK(Ok); + CHECK(IsEqual(Package, ExpectedPackage)); + } + } +} + +TEST_CASE("zcache.policy") +{ + using namespace std::literals; + using namespace utils; + + auto GenerateData = [](uint64_t Size, zen::IoHash& OutHash) -> zen::UniqueBuffer { + auto Buf = zen::UniqueBuffer::Alloc(Size); + uint8_t* Data = reinterpret_cast<uint8_t*>(Buf.GetData()); + for (uint64_t Idx = 0; Idx < Size; Idx++) + { + Data[Idx] = Idx % 256; + } + OutHash = zen::IoHash::HashBuffer(Data, Size); + return Buf; + }; + + auto GeneratePackage = [](zen::IoHash& OutRecordKey, zen::IoHash& OutAttachmentKey) -> zen::CbPackage { + auto Data = zen::SharedBuffer::Clone(zen::MakeMemoryView<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9})); + auto CompressedData = zen::CompressedBuffer::Compress(Data); + OutAttachmentKey = CompressedData.DecodeRawHash(); + + zen::CbWriter Writer; + Writer.BeginObject("obj"sv); + Writer.AddBinaryAttachment("data", OutAttachmentKey); + Writer.EndObject(); + CbObject CacheRecord = Writer.Save().AsObject(); + + OutRecordKey = IoHash::HashBuffer(CacheRecord.GetBuffer().GetView()); + + zen::CbPackage Package; + Package.SetObject(CacheRecord); + Package.AddAttachment(zen::CbAttachment(CompressedData, OutAttachmentKey)); + + return Package; + }; + + auto ToBuffer = [](zen::CbPackage Package) -> zen::IoBuffer { + zen::BinaryWriter MemStream; + Package.Save(MemStream); + + return zen::IoBuffer(zen::IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + }; + + SUBCASE("query - 'local' does not query upstream (binary)") + { + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamInst(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalInst(TestEnv); + const auto Bucket = "legacy"sv; + + UpstreamCfg.Spawn(UpstreamInst); + LocalCfg.Spawn(LocalInst); + + zen::IoHash Key; + auto BinaryValue = GenerateData(1024, Key); + + // Store binary cache value upstream + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()}, + cpr::Header{{"Content-Type", "application/octet-stream"}}); + CHECK(Result.status_code == 201); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=QueryLocal,Store", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/octet-stream"}}); + CHECK(Result.status_code == 404); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/octet-stream"}}); + CHECK(Result.status_code == 200); + } + } + + SUBCASE("store - 'local' does not store upstream (binary)") + { + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamInst(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalInst(TestEnv); + const auto Bucket = "legacy"sv; + + UpstreamCfg.Spawn(UpstreamInst); + LocalCfg.Spawn(LocalInst); + + zen::IoHash Key; + auto BinaryValue = GenerateData(1024, Key); + + // Store binary cache value locally + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,StoreLocal", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()}, + cpr::Header{{"Content-Type", "application/octet-stream"}}); + CHECK(Result.status_code == 201); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/octet-stream"}}); + CHECK(Result.status_code == 404); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/octet-stream"}}); + CHECK(Result.status_code == 200); + } + } + + SUBCASE("store - 'local/remote' stores local and upstream (binary)") + { + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamInst(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalInst(TestEnv); + const auto Bucket = "legacy"sv; + + UpstreamCfg.Spawn(UpstreamInst); + LocalCfg.Spawn(LocalInst); + + zen::IoHash Key; + auto BinaryValue = GenerateData(1024, Key); + + // Store binary cache value locally and upstream + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()}, + cpr::Header{{"Content-Type", "application/octet-stream"}}); + CHECK(Result.status_code == 201); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/octet-stream"}}); + CHECK(Result.status_code == 200); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/octet-stream"}}); + CHECK(Result.status_code == 200); + } + } + + SUBCASE("query - 'local' does not query upstream (cppackage)") + { + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamInst(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalInst(TestEnv); + const auto Bucket = "legacy"sv; + + UpstreamCfg.Spawn(UpstreamInst); + LocalCfg.Spawn(LocalInst); + + zen::IoHash Key; + zen::IoHash PayloadId; + zen::CbPackage Package = GeneratePackage(Key, PayloadId); + auto Buf = ToBuffer(Package); + + // Store package upstream + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 201); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=QueryLocal,Store", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 404); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + } + } + + SUBCASE("store - 'local' does not store upstream (cbpackge)") + { + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamInst(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalInst(TestEnv); + const auto Bucket = "legacy"sv; + + UpstreamCfg.Spawn(UpstreamInst); + LocalCfg.Spawn(LocalInst); + + zen::IoHash Key; + zen::IoHash PayloadId; + zen::CbPackage Package = GeneratePackage(Key, PayloadId); + auto Buf = ToBuffer(Package); + + // Store packge locally + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,StoreLocal", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 201); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 404); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + } + } + + SUBCASE("store - 'local/remote' stores local and upstream (cbpackage)") + { + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamInst(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalInst(TestEnv); + const auto Bucket = "legacy"sv; + + UpstreamCfg.Spawn(UpstreamInst); + LocalCfg.Spawn(LocalInst); + + zen::IoHash Key; + zen::IoHash PayloadId; + zen::CbPackage Package = GeneratePackage(Key, PayloadId); + auto Buf = ToBuffer(Package); + + // Store package locally and upstream + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 201); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + } + + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 200); + } + } + + SUBCASE("skip - 'data' returns cache record without attachments/empty payload") + { + ZenConfig Cfg = ZenConfig::New(); + ZenServerInstance Instance(TestEnv); + const auto Bucket = "test"sv; + + Cfg.Spawn(Instance); + + zen::IoHash Key; + zen::IoHash PayloadId; + zen::CbPackage Package = GeneratePackage(Key, PayloadId); + auto Buf = ToBuffer(Package); + + // Store package + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", Cfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}}); + CHECK(Result.status_code == 201); + } + + // Get package + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cbpkg"}}); + CHECK(IsHttpSuccessCode(Result.status_code)); + IoBuffer Buffer(IoBuffer::Wrap, Result.text.c_str(), Result.text.size()); + CbPackage ResponsePackage; + CHECK(ResponsePackage.TryLoad(Buffer)); + CHECK(ResponsePackage.GetAttachments().size() == 0); + } + + // Get record + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/x-ue-cb"}}); + CHECK(IsHttpSuccessCode(Result.status_code)); + IoBuffer Buffer(IoBuffer::Wrap, Result.text.c_str(), Result.text.size()); + CbObject ResponseObject = zen::LoadCompactBinaryObject(Buffer); + CHECK((bool)ResponseObject); + } + + // Get payload + { + cpr::Response Result = + cpr::Get(cpr::Url{fmt::format("{}/{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key, PayloadId)}, + cpr::Header{{"Accept", "application/x-ue-comp"}}); + CHECK(IsHttpSuccessCode(Result.status_code)); + CHECK(Result.text.size() == 0); + } + } + + SUBCASE("skip - 'data' returns empty binary value") + { + ZenConfig Cfg = ZenConfig::New(); + ZenServerInstance Instance(TestEnv); + const auto Bucket = "test"sv; + + Cfg.Spawn(Instance); + + zen::IoHash Key; + auto BinaryValue = GenerateData(1024, Key); + + // Store binary cache value + { + cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", Cfg.BaseUri, Bucket, Key)}, + cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()}, + cpr::Header{{"Content-Type", "application/octet-stream"}}); + CHECK(Result.status_code == 201); + } + + // Get package + { + cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key)}, + cpr::Header{{"Accept", "application/octet-stream"}}); + CHECK(IsHttpSuccessCode(Result.status_code)); + CHECK(Result.text.size() == 0); + } + } +} + +TEST_CASE("zcache.rpc") +{ + using namespace std::literals; + + auto AppendCacheRecord = [](cacherequests::PutCacheRecordsRequest& Request, + const zen::CacheKey& CacheKey, + size_t PayloadSize, + CachePolicy RecordPolicy) { + std::vector<uint8_t> Data; + Data.resize(PayloadSize); + uint32_t DataSeed = *reinterpret_cast<const uint32_t*>(&CacheKey.Hash.Hash[0]); + uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data()); + for (size_t Idx = 0; Idx < PayloadSize / 2; ++Idx) + { + DataPtr[Idx] = static_cast<uint16_t>((Idx + DataSeed) % 0xffffu); + } + if (PayloadSize & 1) + { + Data[PayloadSize - 1] = static_cast<uint8_t>((PayloadSize - 1) & 0xff); + } + CompressedBuffer Value = zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size())); + Request.Requests.push_back({.Key = CacheKey, .Values = {{.Id = Oid::NewOid(), .Body = std::move(Value)}}, .Policy = RecordPolicy}); + }; + + auto PutCacheRecords = [&AppendCacheRecord](std::string_view BaseUri, + std::string_view Namespace, + std::string_view Bucket, + size_t Num, + size_t PayloadSize = 1024, + size_t KeyOffset = 1) -> std::vector<CacheKey> { + std::vector<zen::CacheKey> OutKeys; + + for (uint32_t Key = 1; Key <= Num; ++Key) + { + zen::IoHash KeyHash; + ((uint32_t*)(KeyHash.Hash))[0] = gsl::narrow<uint32_t>(KeyOffset + Key); + const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, KeyHash); + + cacherequests::PutCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, .Namespace = std::string(Namespace)}; + AppendCacheRecord(Request, CacheKey, PayloadSize, CachePolicy::Default); + OutKeys.push_back(CacheKey); + + CbPackage Package; + CHECK(Request.Format(Package)); + + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + + CHECK(Result.status_code == 200); + } + + return OutKeys; + }; + + struct GetCacheRecordResult + { + zen::CbPackage Response; + cacherequests::GetCacheRecordsResult Result; + bool Success; + }; + + auto GetCacheRecords = [](std::string_view BaseUri, + std::string_view Namespace, + std::span<zen::CacheKey> Keys, + zen::CachePolicy Policy, + zen::RpcAcceptOptions AcceptOptions = zen::RpcAcceptOptions::kNone, + int Pid = 0) -> GetCacheRecordResult { + cacherequests::GetCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, + .AcceptOptions = static_cast<uint16_t>(AcceptOptions), + .ProcessPid = Pid, + .DefaultPolicy = Policy, + .Namespace = std::string(Namespace)}; + for (const CacheKey& Key : Keys) + { + Request.Requests.push_back({.Key = Key}); + } + + CbObjectWriter RequestWriter; + CHECK(Request.Format(RequestWriter)); + + BinaryWriter Body; + RequestWriter.Save(Body); + + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + + GetCacheRecordResult OutResult; + + if (Result.status_code == 200) + { + CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size())); + if (!Response.IsNull()) + { + OutResult.Response = std::move(Response); + CHECK(OutResult.Result.Parse(OutResult.Response)); + OutResult.Success = true; + } + } + + return OutResult; + }; + + SUBCASE("get cache records") + { + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber); + Inst.WaitUntilReady(); + + CachePolicy Policy = CachePolicy::Default; + std::vector<zen::CacheKey> Keys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 128); + GetCacheRecordResult Result = GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, Policy); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + const CacheKey& ExpectedKey = Keys[Index++]; + CHECK(Record); + CHECK(Record->Key == ExpectedKey); + CHECK(Record->Values.size() == 1); + + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + CHECK(Value.Body); + } + } + } + + SUBCASE("get missing cache records") + { + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber); + Inst.WaitUntilReady(); + + CachePolicy Policy = CachePolicy::Default; + std::vector<zen::CacheKey> ExistingKeys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 128); + std::vector<zen::CacheKey> Keys; + + for (const zen::CacheKey& Key : ExistingKeys) + { + Keys.push_back(Key); + Keys.push_back(CacheKey::Create("missing"sv, IoHash::Zero)); + } + + GetCacheRecordResult Result = GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, Policy); + + CHECK(Result.Result.Results.size() == Keys.size()); + + size_t KeyIndex = 0; + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + const bool Missing = Index++ % 2 != 0; + + if (Missing) + { + CHECK(!Record); + } + else + { + const CacheKey& ExpectedKey = ExistingKeys[KeyIndex++]; + CHECK(Record->Key == ExpectedKey); + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + CHECK(Value.Body); + } + } + } + } + + SUBCASE("policy - 'QueryLocal' does not query upstream") + { + using namespace utils; + + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamServer(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalServer(TestEnv); + + SpawnServer(UpstreamServer, UpstreamCfg); + SpawnServer(LocalServer, LocalCfg); + + std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "ue4.ddc"sv, "mastodon"sv, 4); + + CachePolicy Policy = CachePolicy::QueryLocal; + GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, "ue4.ddc"sv, Keys, Policy); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + CHECK(!Record); + } + } + + SUBCASE("policy - 'QueryRemote' does query upstream") + { + using namespace utils; + + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamServer(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalServer(TestEnv); + + SpawnServer(UpstreamServer, UpstreamCfg); + SpawnServer(LocalServer, LocalCfg); + + std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "ue4.ddc"sv, "mastodon"sv, 4); + + CachePolicy Policy = (CachePolicy::QueryLocal | CachePolicy::QueryRemote); + GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, "ue4.ddc"sv, Keys, Policy); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + CHECK(Record); + const CacheKey& ExpectedKey = Keys[Index++]; + CHECK(Record->Key == ExpectedKey); + } + } + + SUBCASE("RpcAcceptOptions") + { + using namespace utils; + + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber); + Inst.WaitUntilReady(); + + std::vector<zen::CacheKey> SmallKeys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 4, 1024); + std::vector<zen::CacheKey> LargeKeys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 4, 1024 * 1024 * 16, SmallKeys.size()); + + std::vector<zen::CacheKey> Keys(SmallKeys.begin(), SmallKeys.end()); + Keys.insert(Keys.end(), LargeKeys.begin(), LargeKeys.end()); + + { + GetCacheRecordResult Result = GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, CachePolicy::Default); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + CHECK(Record); + const CacheKey& ExpectedKey = Keys[Index++]; + CHECK(Record->Key == ExpectedKey); + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer(); + IoBufferFileReference Ref; + bool IsFileRef = Body.GetFileReference(Ref); + CHECK(!IsFileRef); + } + } + } + + // File path, but only for large files + { + GetCacheRecordResult Result = + GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, CachePolicy::Default, RpcAcceptOptions::kAllowLocalReferences); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + CHECK(Record); + const CacheKey& ExpectedKey = Keys[Index++]; + CHECK(Record->Key == ExpectedKey); + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer(); + IoBufferFileReference Ref; + bool IsFileRef = Body.GetFileReference(Ref); + CHECK(IsFileRef == (Body.Size() > 1024)); + } + } + } + + // File path, for all files + { + GetCacheRecordResult Result = + GetCacheRecords(BaseUri, + "ue4.ddc"sv, + Keys, + CachePolicy::Default, + RpcAcceptOptions::kAllowLocalReferences | RpcAcceptOptions::kAllowPartialLocalReferences); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + CHECK(Record); + const CacheKey& ExpectedKey = Keys[Index++]; + CHECK(Record->Key == ExpectedKey); + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer(); + IoBufferFileReference Ref; + bool IsFileRef = Body.GetFileReference(Ref); + CHECK(IsFileRef); + } + } + } + + // File handle, but only for large files + { + GetCacheRecordResult Result = GetCacheRecords(BaseUri, + "ue4.ddc"sv, + Keys, + CachePolicy::Default, + RpcAcceptOptions::kAllowLocalReferences, + GetCurrentProcessId()); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + CHECK(Record); + const CacheKey& ExpectedKey = Keys[Index++]; + CHECK(Record->Key == ExpectedKey); + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer(); + IoBufferFileReference Ref; + bool IsFileRef = Body.GetFileReference(Ref); + CHECK(IsFileRef == (Body.Size() > 1024)); + } + } + } + + // File handle, for all files + { + GetCacheRecordResult Result = + GetCacheRecords(BaseUri, + "ue4.ddc"sv, + Keys, + CachePolicy::Default, + RpcAcceptOptions::kAllowLocalReferences | RpcAcceptOptions::kAllowPartialLocalReferences, + GetCurrentProcessId()); + + CHECK(Result.Result.Results.size() == Keys.size()); + + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + CHECK(Record); + const CacheKey& ExpectedKey = Keys[Index++]; + CHECK(Record->Key == ExpectedKey); + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer(); + IoBufferFileReference Ref; + bool IsFileRef = Body.GetFileReference(Ref); + CHECK(IsFileRef); + } + } + } + } +} + +TEST_CASE("zcache.failing.upstream") +{ + // This is an exploratory test that takes a long time to run, so lets skip it by default + if (true) + { + return; + } + + using namespace std::literals; + using namespace utils; + + const uint16_t Upstream1PortNumber = 13338; + ZenConfig Upstream1Cfg = ZenConfig::New(Upstream1PortNumber); + ZenServerInstance Upstream1Server(TestEnv); + + const uint16_t Upstream2PortNumber = 13339; + ZenConfig Upstream2Cfg = ZenConfig::New(Upstream2PortNumber); + ZenServerInstance Upstream2Server(TestEnv); + + std::vector<std::uint16_t> UpstreamPorts = {Upstream1PortNumber, Upstream2PortNumber}; + ZenConfig LocalCfg = ZenConfig::NewWithThreadedUpstreams(UpstreamPorts, false); + LocalCfg.Args += (" --upstream-thread-count 2"); + ZenServerInstance LocalServer(TestEnv); + const uint16_t LocalPortNumber = 13337; + const auto LocalUri = fmt::format("http://localhost:{}/z$", LocalPortNumber); + const auto Upstream1Uri = fmt::format("http://localhost:{}/z$", Upstream1PortNumber); + const auto Upstream2Uri = fmt::format("http://localhost:{}/z$", Upstream2PortNumber); + + SpawnServer(Upstream1Server, Upstream1Cfg); + SpawnServer(Upstream2Server, Upstream2Cfg); + SpawnServer(LocalServer, LocalCfg); + bool Upstream1Running = true; + bool Upstream2Running = true; + + using namespace std::literals; + + auto AppendCacheRecord = [](cacherequests::PutCacheRecordsRequest& Request, + const zen::CacheKey& CacheKey, + size_t PayloadSize, + CachePolicy RecordPolicy) { + std::vector<uint32_t> Data; + Data.resize(PayloadSize / 4); + for (uint32_t Idx = 0; Idx < PayloadSize / 4; ++Idx) + { + Data[Idx] = (*reinterpret_cast<const uint32_t*>(&CacheKey.Hash.Hash[0])) + Idx; + } + + CompressedBuffer Value = zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size() * 4)); + Request.Requests.push_back({.Key = CacheKey, .Values = {{.Id = Oid::NewOid(), .Body = std::move(Value)}}, .Policy = RecordPolicy}); + }; + + auto PutCacheRecords = [&AppendCacheRecord](std::string_view BaseUri, + std::string_view Namespace, + std::string_view Bucket, + size_t Num, + size_t KeyOffset, + size_t PayloadSize = 8192) -> std::vector<CacheKey> { + std::vector<zen::CacheKey> OutKeys; + + cacherequests::PutCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, .Namespace = std::string(Namespace)}; + for (size_t Key = 1; Key <= Num; ++Key) + { + zen::IoHash KeyHash; + ((size_t*)(KeyHash.Hash))[0] = KeyOffset + Key; + const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, KeyHash); + + AppendCacheRecord(Request, CacheKey, PayloadSize, CachePolicy::Default); + OutKeys.push_back(CacheKey); + } + + CbPackage Package; + CHECK(Request.Format(Package)); + + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + + if (Result.status_code != 200) + { + ZEN_DEBUG("PutCacheRecords failed with {}, reason '{}'", Result.status_code, Result.reason); + OutKeys.clear(); + } + + return OutKeys; + }; + + struct GetCacheRecordResult + { + zen::CbPackage Response; + cacherequests::GetCacheRecordsResult Result; + bool Success = false; + }; + + auto GetCacheRecords = [](std::string_view BaseUri, + std::string_view Namespace, + std::span<zen::CacheKey> Keys, + zen::CachePolicy Policy) -> GetCacheRecordResult { + cacherequests::GetCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, + .DefaultPolicy = Policy, + .Namespace = std::string(Namespace)}; + for (const CacheKey& Key : Keys) + { + Request.Requests.push_back({.Key = Key}); + } + + CbObjectWriter RequestWriter; + CHECK(Request.Format(RequestWriter)); + + BinaryWriter Body; + RequestWriter.Save(Body); + + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + + GetCacheRecordResult OutResult; + + if (Result.status_code == 200) + { + CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size())); + if (!Response.IsNull()) + { + OutResult.Response = std::move(Response); + CHECK(OutResult.Result.Parse(OutResult.Response)); + OutResult.Success = true; + } + } + else + { + ZEN_DEBUG("GetCacheRecords with {}, reason '{}'", Result.reason, Result.status_code); + } + + return OutResult; + }; + + // Populate with some simple data + + CachePolicy Policy = CachePolicy::Default; + + const size_t ThreadCount = 128; + const size_t KeyMultiplier = 16384; + const size_t RecordsPerRequest = 64; + WorkerThreadPool Pool(ThreadCount); + + std::atomic_size_t Completed = 0; + + auto Keys = new std::vector<CacheKey>[ThreadCount * KeyMultiplier]; + RwLock KeysLock; + + for (size_t I = 0; I < ThreadCount * KeyMultiplier; I++) + { + size_t Iteration = I; + Pool.ScheduleWork([&] { + std::vector<CacheKey> NewKeys = PutCacheRecords(LocalUri, "ue4.ddc"sv, "mastodon"sv, RecordsPerRequest, I * RecordsPerRequest); + if (NewKeys.size() != RecordsPerRequest) + { + ZEN_DEBUG("PutCacheRecords iteration {} failed", Iteration); + Completed.fetch_add(1); + return; + } + { + RwLock::ExclusiveLockScope _(KeysLock); + Keys[Iteration].swap(NewKeys); + } + Completed.fetch_add(1); + }); + } + bool UseUpstream1 = false; + while (Completed < ThreadCount * KeyMultiplier) + { + Sleep(8000); + + if (UseUpstream1) + { + if (Upstream2Running) + { + Upstream2Server.EnableTermination(); + Upstream2Server.Shutdown(); + Sleep(100); + Upstream2Running = false; + } + if (!Upstream1Running) + { + SpawnServer(Upstream1Server, Upstream1Cfg); + Upstream1Running = true; + } + UseUpstream1 = !UseUpstream1; + } + else + { + if (Upstream1Running) + { + Upstream1Server.EnableTermination(); + Upstream1Server.Shutdown(); + Sleep(100); + Upstream1Running = false; + } + if (!Upstream2Running) + { + SpawnServer(Upstream2Server, Upstream2Cfg); + Upstream2Running = true; + } + UseUpstream1 = !UseUpstream1; + } + } + + Completed = 0; + for (size_t I = 0; I < ThreadCount * KeyMultiplier; I++) + { + size_t Iteration = I; + std::vector<CacheKey>& LocalKeys = Keys[Iteration]; + if (LocalKeys.empty()) + { + Completed.fetch_add(1); + continue; + } + Pool.ScheduleWork([&] { + GetCacheRecordResult Result = GetCacheRecords(LocalUri, "ue4.ddc"sv, LocalKeys, Policy); + + if (!Result.Success) + { + ZEN_DEBUG("GetCacheRecords iteration {} failed", Iteration); + Completed.fetch_add(1); + return; + } + + if (Result.Result.Results.size() != LocalKeys.size()) + { + ZEN_DEBUG("GetCacheRecords iteration {} empty records", Iteration); + Completed.fetch_add(1); + return; + } + for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results) + { + const CacheKey& ExpectedKey = LocalKeys[Index++]; + if (!Record) + { + continue; + } + if (Record->Key != ExpectedKey) + { + continue; + } + if (Record->Values.size() != 1) + { + continue; + } + + for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values) + { + if (!Value.Body) + { + continue; + } + } + } + Completed.fetch_add(1); + }); + } + while (Completed < ThreadCount * KeyMultiplier) + { + Sleep(10); + } +} + +TEST_CASE("zcache.rpc.allpolicies") +{ + using namespace std::literals; + using namespace utils; + + ZenConfig UpstreamCfg = ZenConfig::New(13338); + ZenServerInstance UpstreamServer(TestEnv); + ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338); + ZenServerInstance LocalServer(TestEnv); + const uint16_t LocalPortNumber = 13337; + const auto BaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber); + + SpawnServer(UpstreamServer, UpstreamCfg); + SpawnServer(LocalServer, LocalCfg); + + std::string_view TestVersion = "F72150A02AE34B57A9EC91D36BA1CE08"sv; + std::string_view TestBucket = "allpoliciestest"sv; + std::string_view TestNamespace = "ue4.ddc"sv; + + // NumKeys = (2 Value vs Record)*(2 SkipData vs Default)*(2 ForceMiss vs Not)*(2 use local) + // *(2 use remote)*(2 UseValue Policy vs not)*(4 cases per type) + constexpr int NumKeys = 256; + constexpr int NumValues = 4; + Oid ValueIds[NumValues]; + IoHash Hash; + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + ExtendableStringBuilder<16> ValueName; + ValueName << "ValueId_"sv << ValueIndex; + static_assert(sizeof(IoHash) >= sizeof(Oid)); + ValueIds[ValueIndex] = Oid::FromMemory(IoHash::HashBuffer(ValueName.Data(), ValueName.Size() * sizeof(ValueName.Data()[0])).Hash); + } + + struct KeyData; + struct UserData + { + UserData& Set(KeyData* InKeyData, int InValueIndex) + { + Data = InKeyData; + ValueIndex = InValueIndex; + return *this; + } + KeyData* Data = nullptr; + int ValueIndex = 0; + }; + struct KeyData + { + CompressedBuffer BufferValues[NumValues]; + uint64_t IntValues[NumValues]; + UserData ValueUserData[NumValues]; + bool ReceivedChunk[NumValues]; + CacheKey Key; + UserData KeyUserData; + uint32_t KeyIndex = 0; + bool GetRequestsData = true; + bool UseValueAPI = false; + bool UseValuePolicy = false; + bool ForceMiss = false; + bool UseLocal = true; + bool UseRemote = true; + bool ShouldBeHit = true; + bool ReceivedPut = false; + bool ReceivedGet = false; + bool ReceivedPutValue = false; + bool ReceivedGetValue = false; + }; + struct CachePutRequest + { + CacheKey Key; + CbObject Record; + CacheRecordPolicy Policy; + KeyData* Values; + UserData* Data; + }; + struct CachePutValueRequest + { + CacheKey Key; + CompressedBuffer Value; + CachePolicy Policy; + UserData* Data; + }; + struct CacheGetRequest + { + CacheKey Key; + CacheRecordPolicy Policy; + UserData* Data; + }; + struct CacheGetValueRequest + { + CacheKey Key; + CachePolicy Policy; + UserData* Data; + }; + struct CacheGetChunkRequest + { + CacheKey Key; + Oid ValueId; + uint64_t RawOffset; + uint64_t RawSize; + IoHash RawHash; + CachePolicy Policy; + UserData* Data; + }; + + KeyData KeyDatas[NumKeys]; + std::vector<CachePutRequest> PutRequests; + std::vector<CachePutValueRequest> PutValueRequests; + std::vector<CacheGetRequest> GetRequests; + std::vector<CacheGetValueRequest> GetValueRequests; + std::vector<CacheGetChunkRequest> ChunkRequests; + + for (uint32_t KeyIndex = 0; KeyIndex < NumKeys; ++KeyIndex) + { + IoHashStream KeyWriter; + KeyWriter.Append(TestVersion.data(), TestVersion.length() * sizeof(TestVersion.data()[0])); + KeyWriter.Append(&KeyIndex, sizeof(KeyIndex)); + IoHash KeyHash = KeyWriter.GetHash(); + KeyData& KeyData = KeyDatas[KeyIndex]; + + KeyData.Key = CacheKey::Create(TestBucket, KeyHash); + KeyData.KeyIndex = KeyIndex; + KeyData.GetRequestsData = (KeyIndex & (1 << 1)) == 0; + KeyData.UseValueAPI = (KeyIndex & (1 << 2)) != 0; + KeyData.UseValuePolicy = (KeyIndex & (1 << 3)) != 0; + KeyData.ForceMiss = (KeyIndex & (1 << 4)) == 0; + KeyData.UseLocal = (KeyIndex & (1 << 5)) == 0; + KeyData.UseRemote = (KeyIndex & (1 << 6)) == 0; + KeyData.ShouldBeHit = !KeyData.ForceMiss && (KeyData.UseLocal || KeyData.UseRemote); + CachePolicy SharedPolicy = KeyData.UseLocal ? CachePolicy::Local : CachePolicy::None; + SharedPolicy |= KeyData.UseRemote ? CachePolicy::Remote : CachePolicy::None; + CachePolicy PutPolicy = SharedPolicy; + CachePolicy GetPolicy = SharedPolicy; + GetPolicy |= !KeyData.GetRequestsData ? CachePolicy::SkipData : CachePolicy::None; + CacheKey& Key = KeyData.Key; + + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + KeyData.IntValues[ValueIndex] = static_cast<uint64_t>(KeyIndex) | (static_cast<uint64_t>(ValueIndex) << 32); + KeyData.BufferValues[ValueIndex] = + CompressedBuffer::Compress(SharedBuffer::MakeView(&KeyData.IntValues[ValueIndex], sizeof(KeyData.IntValues[ValueIndex]))); + KeyData.ReceivedChunk[ValueIndex] = false; + } + + UserData& KeyUserData = KeyData.KeyUserData.Set(&KeyData, -1); + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + KeyData.ValueUserData[ValueIndex].Set(&KeyData, ValueIndex); + } + if (!KeyData.UseValueAPI) + { + CbObjectWriter Builder; + Builder.BeginObject("key"sv); + Builder << "Bucket"sv << Key.Bucket << "Hash"sv << Key.Hash; + Builder.EndObject(); + Builder.BeginArray("Values"sv); + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + Builder.BeginObject(); + Builder.AddObjectId("Id"sv, ValueIds[ValueIndex]); + Builder.AddBinaryAttachment("RawHash"sv, KeyData.BufferValues[ValueIndex].DecodeRawHash()); + Builder.AddInteger("RawSize"sv, KeyData.BufferValues[ValueIndex].DecodeRawSize()); + Builder.EndObject(); + } + Builder.EndArray(); + + CacheRecordPolicy PutRecordPolicy; + CacheRecordPolicy GetRecordPolicy; + if (!KeyData.UseValuePolicy) + { + PutRecordPolicy = CacheRecordPolicy(PutPolicy); + GetRecordPolicy = CacheRecordPolicy(GetPolicy); + } + else + { + // Switch the SkipData field in the Record policy so that if the CacheStore ignores the ValuePolicies + // it will use the wrong value for SkipData and fail our tests. + CacheRecordPolicyBuilder PutBuilder(PutPolicy ^ CachePolicy::SkipData); + CacheRecordPolicyBuilder GetBuilder(GetPolicy ^ CachePolicy::SkipData); + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + PutBuilder.AddValuePolicy(ValueIds[ValueIndex], PutPolicy); + GetBuilder.AddValuePolicy(ValueIds[ValueIndex], GetPolicy); + } + PutRecordPolicy = PutBuilder.Build(); + GetRecordPolicy = GetBuilder.Build(); + } + if (!KeyData.ForceMiss) + { + PutRequests.push_back({Key, Builder.Save(), PutRecordPolicy, &KeyData, &KeyUserData}); + } + GetRequests.push_back({Key, GetRecordPolicy, &KeyUserData}); + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + UserData& ValueUserData = KeyData.ValueUserData[ValueIndex]; + ChunkRequests.push_back({Key, ValueIds[ValueIndex], 0, UINT64_MAX, IoHash(), GetPolicy, &ValueUserData}); + } + } + else + { + if (!KeyData.ForceMiss) + { + PutValueRequests.push_back({Key, KeyData.BufferValues[0], PutPolicy, &KeyUserData}); + } + GetValueRequests.push_back({Key, GetPolicy, &KeyUserData}); + ChunkRequests.push_back({Key, Oid::Zero, 0, UINT64_MAX, IoHash(), GetPolicy, &KeyUserData}); + } + } + + // PutCacheRecords + { + CachePolicy BatchDefaultPolicy = CachePolicy::Default; + cacherequests::PutCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, + .DefaultPolicy = BatchDefaultPolicy, + .Namespace = std::string(TestNamespace)}; + Request.Requests.reserve(PutRequests.size()); + for (CachePutRequest& PutRequest : PutRequests) + { + cacherequests::PutCacheRecordRequest& RecordRequest = Request.Requests.emplace_back(); + RecordRequest.Key = PutRequest.Key; + RecordRequest.Policy = PutRequest.Policy; + RecordRequest.Values.reserve(NumValues); + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + RecordRequest.Values.push_back({.Id = ValueIds[ValueIndex], .Body = PutRequest.Values->BufferValues[ValueIndex]}); + } + PutRequest.Data->Data->ReceivedPut = true; + } + + CbPackage Package; + CHECK(Request.Format(Package)); + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + CHECK_MESSAGE(Result.status_code == 200, "PutCacheRecords unexpectedly failed."); + } + + // PutCacheValues + { + CachePolicy BatchDefaultPolicy = CachePolicy::Default; + + cacherequests::PutCacheValuesRequest Request = {.AcceptMagic = kCbPkgMagic, + .DefaultPolicy = BatchDefaultPolicy, + .Namespace = std::string(TestNamespace)}; + Request.Requests.reserve(PutValueRequests.size()); + for (CachePutValueRequest& PutRequest : PutValueRequests) + { + Request.Requests.push_back({.Key = PutRequest.Key, .Body = PutRequest.Value, .Policy = PutRequest.Policy}); + PutRequest.Data->Data->ReceivedPutValue = true; + } + + CbPackage Package; + CHECK(Request.Format(Package)); + + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + CHECK_MESSAGE(Result.status_code == 200, "PutCacheValues unexpectedly failed."); + } + + for (KeyData& KeyData : KeyDatas) + { + if (!KeyData.ForceMiss) + { + if (!KeyData.UseValueAPI) + { + CHECK_MESSAGE(KeyData.ReceivedPut, WriteToString<32>("Key ", KeyData.KeyIndex, " was unexpectedly not put.").c_str()); + } + else + { + CHECK_MESSAGE(KeyData.ReceivedPutValue, + WriteToString<32>("Key ", KeyData.KeyIndex, " was unexpectedly not put to ValueAPI.").c_str()); + } + } + } + + // GetCacheRecords + { + CachePolicy BatchDefaultPolicy = CachePolicy::Default; + cacherequests::GetCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, + .DefaultPolicy = BatchDefaultPolicy, + .Namespace = std::string(TestNamespace)}; + Request.Requests.reserve(GetRequests.size()); + for (CacheGetRequest& GetRequest : GetRequests) + { + Request.Requests.push_back({.Key = GetRequest.Key, .Policy = GetRequest.Policy}); + } + + CbPackage Package; + CHECK(Request.Format(Package)); + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + CHECK_MESSAGE(Result.status_code == 200, "GetCacheRecords unexpectedly failed."); + CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size())); + bool Loaded = !Response.IsNull(); + CHECK_MESSAGE(Loaded, "GetCacheRecords response failed to load."); + cacherequests::GetCacheRecordsResult RequestResult; + CHECK(RequestResult.Parse(Response)); + CHECK_MESSAGE(RequestResult.Results.size() == GetRequests.size(), "GetCacheRecords response count did not match request count."); + for (int Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& RecordResult : RequestResult.Results) + { + bool Succeeded = RecordResult.has_value(); + CacheGetRequest& GetRequest = GetRequests[Index++]; + KeyData* KeyData = GetRequest.Data->Data; + KeyData->ReceivedGet = true; + WriteToString<32> Name("Get(", KeyData->KeyIndex, ")"); + if (KeyData->ShouldBeHit) + { + CHECK_MESSAGE(Succeeded, WriteToString<32>(Name, " unexpectedly failed.").c_str()); + } + else if (KeyData->ForceMiss) + { + CHECK_MESSAGE(!Succeeded, WriteToString<32>(Name, " unexpectedly succeeded.").c_str()); + } + if (!KeyData->ForceMiss && Succeeded) + { + CHECK_MESSAGE(RecordResult->Values.size() == NumValues, + WriteToString<32>(Name, " number of values did not match.").c_str()); + for (const cacherequests::GetCacheRecordResultValue& Value : RecordResult->Values) + { + int ExpectedValueIndex = 0; + for (; ExpectedValueIndex < NumValues; ++ExpectedValueIndex) + { + if (ValueIds[ExpectedValueIndex] == Value.Id) + { + break; + } + } + CHECK_MESSAGE(ExpectedValueIndex < NumValues, WriteToString<32>(Name, " could not find matching ValueId.").c_str()); + + WriteToString<32> ValueName("Get(", KeyData->KeyIndex, ",", ExpectedValueIndex, ")"); + + CompressedBuffer ExpectedValue = KeyData->BufferValues[ExpectedValueIndex]; + CHECK_MESSAGE(Value.RawHash == ExpectedValue.DecodeRawHash(), + WriteToString<32>(ValueName, " RawHash did not match.").c_str()); + CHECK_MESSAGE(Value.RawSize == ExpectedValue.DecodeRawSize(), + WriteToString<32>(ValueName, " RawSize did not match.").c_str()); + + if (KeyData->GetRequestsData) + { + SharedBuffer Buffer = Value.Body.Decompress(); + CHECK_MESSAGE(Buffer.GetSize() == Value.RawSize, + WriteToString<32>(ValueName, " BufferSize did not match RawSize.").c_str()); + uint64_t ActualIntValue = ((const uint64_t*)Buffer.GetData())[0]; + uint64_t ExpectedIntValue = KeyData->IntValues[ExpectedValueIndex]; + CHECK_MESSAGE(ActualIntValue == ExpectedIntValue, WriteToString<32>(ValueName, " had unexpected data.").c_str()); + } + } + } + } + } + + // GetCacheValues + { + CachePolicy BatchDefaultPolicy = CachePolicy::Default; + + cacherequests::GetCacheValuesRequest GetCacheValuesRequest = {.AcceptMagic = kCbPkgMagic, + .DefaultPolicy = BatchDefaultPolicy, + .Namespace = std::string(TestNamespace)}; + GetCacheValuesRequest.Requests.reserve(GetValueRequests.size()); + for (CacheGetValueRequest& GetRequest : GetValueRequests) + { + GetCacheValuesRequest.Requests.push_back({.Key = GetRequest.Key, .Policy = GetRequest.Policy}); + } + + CbPackage Package; + CHECK(GetCacheValuesRequest.Format(Package)); + + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + CHECK_MESSAGE(Result.status_code == 200, "GetCacheValues unexpectedly failed."); + IoBuffer MessageBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()); + CbPackage Response = ParsePackageMessage(MessageBuffer); + bool Loaded = !Response.IsNull(); + CHECK_MESSAGE(Loaded, "GetCacheValues response failed to load."); + cacherequests::GetCacheValuesResult GetCacheValuesResult; + CHECK(GetCacheValuesResult.Parse(Response)); + for (int Index = 0; const cacherequests::CacheValueResult& ValueResult : GetCacheValuesResult.Results) + { + bool Succeeded = ValueResult.RawHash != IoHash::Zero; + CacheGetValueRequest& Request = GetValueRequests[Index++]; + KeyData* KeyData = Request.Data->Data; + KeyData->ReceivedGetValue = true; + WriteToString<32> Name("GetValue("sv, KeyData->KeyIndex, ")"sv); + + if (KeyData->ShouldBeHit) + { + CHECK_MESSAGE(Succeeded, WriteToString<32>(Name, " unexpectedly failed.").c_str()); + } + else if (KeyData->ForceMiss) + { + CHECK_MESSAGE(!Succeeded, WriteToString<32>(Name, "unexpectedly succeeded.").c_str()); + } + if (!KeyData->ForceMiss && Succeeded) + { + CompressedBuffer ExpectedValue = KeyData->BufferValues[0]; + CHECK_MESSAGE(ValueResult.RawHash == ExpectedValue.DecodeRawHash(), + WriteToString<32>(Name, " RawHash did not match.").c_str()); + CHECK_MESSAGE(ValueResult.RawSize == ExpectedValue.DecodeRawSize(), + WriteToString<32>(Name, " RawSize did not match.").c_str()); + + if (KeyData->GetRequestsData) + { + SharedBuffer Buffer = ValueResult.Body.Decompress(); + CHECK_MESSAGE(Buffer.GetSize() == ValueResult.RawSize, + WriteToString<32>(Name, " BufferSize did not match RawSize.").c_str()); + uint64_t ActualIntValue = ((const uint64_t*)Buffer.GetData())[0]; + uint64_t ExpectedIntValue = KeyData->IntValues[0]; + CHECK_MESSAGE(ActualIntValue == ExpectedIntValue, WriteToString<32>(Name, " had unexpected data.").c_str()); + } + } + } + } + + // GetCacheChunks + { + std::sort(ChunkRequests.begin(), ChunkRequests.end(), [](CacheGetChunkRequest& A, CacheGetChunkRequest& B) { + return A.Key.Hash < B.Key.Hash; + }); + CachePolicy BatchDefaultPolicy = CachePolicy::Default; + cacherequests::GetCacheChunksRequest GetCacheChunksRequest = {.AcceptMagic = kCbPkgMagic, + .DefaultPolicy = BatchDefaultPolicy, + .Namespace = std::string(TestNamespace)}; + GetCacheChunksRequest.Requests.reserve(ChunkRequests.size()); + for (CacheGetChunkRequest& ChunkRequest : ChunkRequests) + { + GetCacheChunksRequest.Requests.push_back({.Key = ChunkRequest.Key, + .ValueId = ChunkRequest.ValueId, + .ChunkId = IoHash(), + .RawOffset = ChunkRequest.RawOffset, + .RawSize = ChunkRequest.RawSize, + .Policy = ChunkRequest.Policy}); + } + CbPackage Package; + CHECK(GetCacheChunksRequest.Format(Package)); + + IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer(); + cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)}, + cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}, + cpr::Body{(const char*)Body.GetData(), Body.GetSize()}); + CHECK_MESSAGE(Result.status_code == 200, "GetCacheChunks unexpectedly failed."); + CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size())); + bool Loaded = !Response.IsNull(); + CHECK_MESSAGE(Loaded, "GetCacheChunks response failed to load."); + cacherequests::GetCacheChunksResult GetCacheChunksResult; + CHECK(GetCacheChunksResult.Parse(Response)); + CHECK_MESSAGE(GetCacheChunksResult.Results.size() == ChunkRequests.size(), + "GetCacheChunks response count did not match request count."); + + for (int Index = 0; const cacherequests::CacheValueResult& ValueResult : GetCacheChunksResult.Results) + { + bool Succeeded = ValueResult.RawHash != IoHash::Zero; + + CacheGetChunkRequest& Request = ChunkRequests[Index++]; + KeyData* KeyData = Request.Data->Data; + int ValueIndex = Request.Data->ValueIndex >= 0 ? Request.Data->ValueIndex : 0; + KeyData->ReceivedChunk[ValueIndex] = true; + WriteToString<32> Name("GetChunks("sv, KeyData->KeyIndex, ","sv, ValueIndex, ")"sv); + + if (KeyData->ShouldBeHit) + { + CHECK_MESSAGE(Succeeded, WriteToString<256>(Name, " unexpectedly failed."sv).c_str()); + } + else if (KeyData->ForceMiss) + { + CHECK_MESSAGE(!Succeeded, WriteToString<256>(Name, " unexpectedly succeeded."sv).c_str()); + } + if (KeyData->ShouldBeHit && Succeeded) + { + CompressedBuffer ExpectedValue = KeyData->BufferValues[ValueIndex]; + CHECK_MESSAGE(ValueResult.RawHash == ExpectedValue.DecodeRawHash(), + WriteToString<32>(Name, " had unexpected RawHash.").c_str()); + CHECK_MESSAGE(ValueResult.RawSize == ExpectedValue.DecodeRawSize(), + WriteToString<32>(Name, " had unexpected RawSize.").c_str()); + + if (KeyData->GetRequestsData) + { + SharedBuffer Buffer = ValueResult.Body.Decompress(); + CHECK_MESSAGE(Buffer.GetSize() == ValueResult.RawSize, + WriteToString<32>(Name, " BufferSize did not match RawSize.").c_str()); + uint64_t ActualIntValue = ((const uint64_t*)Buffer.GetData())[0]; + uint64_t ExpectedIntValue = KeyData->IntValues[ValueIndex]; + CHECK_MESSAGE(ActualIntValue == ExpectedIntValue, WriteToString<32>(Name, " had unexpected data.").c_str()); + } + } + } + } + + for (KeyData& KeyData : KeyDatas) + { + if (!KeyData.UseValueAPI) + { + CHECK_MESSAGE(KeyData.ReceivedGet, WriteToString<32>("Get(", KeyData.KeyIndex, ") was unexpectedly not received.").c_str()); + for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex) + { + CHECK_MESSAGE( + KeyData.ReceivedChunk[ValueIndex], + WriteToString<32>("GetChunks(", KeyData.KeyIndex, ",", ValueIndex, ") was unexpectedly not received.").c_str()); + } + } + else + { + CHECK_MESSAGE(KeyData.ReceivedGetValue, + WriteToString<32>("GetValue(", KeyData.KeyIndex, ") was unexpectedly not received.").c_str()); + CHECK_MESSAGE(KeyData.ReceivedChunk[0], + WriteToString<32>("GetChunks(", KeyData.KeyIndex, ") was unexpectedly not received.").c_str()); + } + } +} + +class ZenServerTestHelper +{ +public: + ZenServerTestHelper(std::string_view HelperId, int ServerCount) : m_HelperId{HelperId}, m_ServerCount{ServerCount} {} + ~ZenServerTestHelper() {} + + void SpawnServers(std::string_view AdditionalServerArgs = std::string_view()) + { + SpawnServers([](ZenServerInstance&) {}, AdditionalServerArgs); + } + + void SpawnServers(auto&& Callback, std::string_view AdditionalServerArgs) + { + ZEN_INFO("{}: spawning {} server instances", m_HelperId, m_ServerCount); + + m_Instances.resize(m_ServerCount); + + for (int i = 0; i < m_ServerCount; ++i) + { + auto& Instance = m_Instances[i]; + + Instance = std::make_unique<ZenServerInstance>(TestEnv); + Instance->SetTestDir(TestEnv.CreateNewTestDir()); + + Callback(*Instance); + + Instance->SpawnServer(13337 + i, AdditionalServerArgs); + } + + for (int i = 0; i < m_ServerCount; ++i) + { + auto& Instance = m_Instances[i]; + + Instance->WaitUntilReady(); + } + } + + ZenServerInstance& GetInstance(int Index) { return *m_Instances[Index]; } + +private: + std::string m_HelperId; + int m_ServerCount = 0; + std::vector<std::unique_ptr<ZenServerInstance> > m_Instances; +}; + +class IoDispatcher +{ +public: + IoDispatcher(asio::io_context& IoCtx) : m_IoCtx(IoCtx) {} + ~IoDispatcher() { Stop(); } + + void Run() + { + Stop(); + + m_Running = true; + + m_IoThread = std::thread([this]() { + try + { + m_IoCtx.run(); + } + catch (std::exception& Error) + { + m_Error = Error; + } + }); + } + + void Stop() + { + if (m_Running) + { + m_Running = false; + + if (m_IoThread.joinable()) + { + m_IoThread.join(); + } + } + } + + bool IsRunning() const { return m_Running; } + + const std::exception& Error() { return m_Error; } + +private: + asio::io_context& m_IoCtx; + std::thread m_IoThread; + std::exception m_Error; + std::atomic_bool m_Running{false}; +}; + +TEST_CASE("http.basics") +{ + using namespace std::literals; + + ZenServerTestHelper Servers{"http.basics"sv, 1}; + Servers.SpawnServers(); + + ZenServerInstance& Instance = Servers.GetInstance(0); + const std::string BaseUri = Instance.GetBaseUri(); + + { + cpr::Response r = cpr::Get(cpr::Url{fmt::format("{}/testing/hello", BaseUri)}); + CHECK(IsHttpSuccessCode(r.status_code)); + } + + { + cpr::Response r = cpr::Post(cpr::Url{fmt::format("{}/testing/hello", BaseUri)}); + CHECK_EQ(r.status_code, 404); + } + + { + cpr::Response r = cpr::Post(cpr::Url{fmt::format("{}/testing/echo", BaseUri)}, cpr::Body{"yoyoyoyo"}); + CHECK_EQ(r.status_code, 200); + CHECK_EQ(r.text, "yoyoyoyo"); + } +} + +TEST_CASE("http.package") +{ + using namespace std::literals; + + ZenServerTestHelper Servers{"http.package"sv, 1}; + Servers.SpawnServers(); + + ZenServerInstance& Instance = Servers.GetInstance(0); + const std::string BaseUri = Instance.GetBaseUri(); + + static const uint8_t Data1[] = {0, 1, 2, 3}; + static const uint8_t Data2[] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + + zen::CompressedBuffer AttachmentData1 = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone({Data1, 4}), + zen::OodleCompressor::NotSet, + zen::OodleCompressionLevel::None); + zen::CbAttachment Attach1{AttachmentData1, AttachmentData1.DecodeRawHash()}; + zen::CompressedBuffer AttachmentData2 = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone({Data2, 8}), + zen::OodleCompressor::NotSet, + zen::OodleCompressionLevel::None); + zen::CbAttachment Attach2{AttachmentData2, AttachmentData2.DecodeRawHash()}; + + zen::CbObjectWriter Writer; + + Writer.AddAttachment("attach1", Attach1); + Writer.AddAttachment("attach2", Attach2); + + zen::CbObject CoreObject = Writer.Save(); + + zen::CbPackage TestPackage; + TestPackage.SetObject(CoreObject); + TestPackage.AddAttachment(Attach1); + TestPackage.AddAttachment(Attach2); + + zen::HttpClient TestClient(BaseUri); + zen::HttpClient::Response Response = TestClient.TransactPackage("/testing/package"sv, TestPackage); + + zen::CbPackage ResponsePackage = ParsePackageMessage(Response.ResponsePayload); + + CHECK_EQ(ResponsePackage, TestPackage); +} + +TEST_CASE("websocket.basic") +{ + if (true) + { + return; + } + + std::filesystem::path TestDir = TestEnv.CreateNewTestDir(); + const uint16_t PortNumber = 13337; + const auto MaxWaitTime = std::chrono::seconds(5); + + ZenServerInstance Inst(TestEnv); + Inst.SetTestDir(TestDir); + Inst.SpawnServer(PortNumber, "--websocket-port=8848"sv); + Inst.WaitUntilReady(); + + asio::io_context IoCtx; + IoDispatcher IoDispatcher(IoCtx); + auto WebSocket = WebSocketClient::Create(IoCtx); + + auto ConnectFuture = WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"}); + IoDispatcher.Run(); + + ConnectFuture.wait_for(MaxWaitTime); + CHECK(ConnectFuture.get()); + + for (size_t Idx = 0; Idx < 10; Idx++) + { + CbObjectWriter Request; + Request << "Method"sv + << "SayHello"sv; + + WebSocketMessage RequestMsg; + RequestMsg.SetMessageType(WebSocketMessageType::kRequest); + RequestMsg.SetBody(Request.Save()); + + auto ResponseFuture = WebSocket->SendRequest(std::move(RequestMsg)); + ResponseFuture.wait_for(MaxWaitTime); + + CbObject Response = ResponseFuture.get().Body().GetObject(); + std::string_view Message = Response["Result"].AsString(); + + CHECK(Message == "Hello Friend!!"sv); + } + + WebSocket->Disconnect(); + + IoCtx.stop(); + IoDispatcher.Stop(); +} + +std::string +OidAsString(const Oid& Id) +{ + StringBuilder<25> OidStringBuilder; + Id.ToString(OidStringBuilder); + return OidStringBuilder.ToString(); +} + +CbPackage +CreateOplogPackage(const Oid& Id, const std::span<const std::pair<Oid, CompressedBuffer> >& Attachments) +{ + CbPackage Package; + CbObjectWriter Object; + Object << "key"sv << OidAsString(Id); + if (!Attachments.empty()) + { + Object.BeginArray("bulkdata"); + for (const auto& Attachment : Attachments) + { + CbAttachment Attach(Attachment.second, Attachment.second.DecodeRawHash()); + Object.BeginObject(); + Object << "id"sv << Attachment.first; + Object << "type"sv + << "Standard"sv; + Object << "data"sv << Attach; + Object.EndObject(); + + Package.AddAttachment(Attach); + ZEN_DEBUG("Added attachment {}", Attach.GetHash()); + } + Object.EndArray(); + } + Package.SetObject(Object.Save()); + return Package; +}; + +std::vector<std::pair<Oid, CompressedBuffer> > +CreateAttachments(const std::span<const size_t>& Sizes) +{ + std::vector<std::pair<Oid, CompressedBuffer> > Result; + Result.reserve(Sizes.size()); + for (size_t Size : Sizes) + { + std::vector<uint8_t> Data; + Data.resize(Size); + uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data()); + for (size_t Idx = 0; Idx < Size / 2; ++Idx) + { + DataPtr[Idx] = static_cast<uint16_t>(Idx % 0xffffu); + } + if (Size & 1) + { + Data[Size - 1] = static_cast<uint8_t>((Size - 1) & 0xff); + } + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size())); + Result.emplace_back(std::pair<Oid, CompressedBuffer>(Oid::NewOid(), Compressed)); + } + return Result; +} + +cpr::Body +AsBody(const IoBuffer& Payload) +{ + return cpr::Body{(const char*)Payload.GetData(), Payload.Size()}; +}; + +enum CbWriterMeta +{ + BeginObject, + EndObject, + BeginArray, + EndArray +}; + +inline CbWriter& +operator<<(CbWriter& Writer, CbWriterMeta Meta) +{ + switch (Meta) + { + case BeginObject: + Writer.BeginObject(); + break; + case EndObject: + Writer.EndObject(); + break; + case BeginArray: + Writer.BeginArray(); + break; + case EndArray: + Writer.EndArray(); + break; + default: + ZEN_ASSERT(false); + } + return Writer; +} + +TEST_CASE("project.remote") +{ + using namespace std::literals; + + ZenServerTestHelper Servers("remote", 3); + Servers.SpawnServers("--debug"); + + std::vector<Oid> OpIds; + OpIds.reserve(24); + for (size_t I = 0; I < 24; ++I) + { + OpIds.emplace_back(Oid::NewOid()); + } + + std::unordered_map<Oid, std::vector<std::pair<Oid, CompressedBuffer> >, Oid::Hasher> Attachments; + { + std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, + 2257, 3685, 3489, 7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, + 1916, 8210, 2235, 4024, 1582, 5251, 491, 5464, 4607, 8135, 3767, 4045, + 4415, 5007, 8876, 6761, 3359, 8526, 4097, 4855, 8225}); + auto It = AttachmentSizes.begin(); + Attachments[OpIds[0]] = {}; + Attachments[OpIds[1]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + Attachments[OpIds[2]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++}); + Attachments[OpIds[3]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + Attachments[OpIds[4]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++}); + Attachments[OpIds[5]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++}); + Attachments[OpIds[6]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + Attachments[OpIds[7]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++}); + Attachments[OpIds[8]] = CreateAttachments(std::initializer_list<size_t>{}); + Attachments[OpIds[9]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++}); + Attachments[OpIds[10]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + Attachments[OpIds[11]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++}); + Attachments[OpIds[12]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++}); + Attachments[OpIds[13]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + Attachments[OpIds[14]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++}); + Attachments[OpIds[15]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++}); + Attachments[OpIds[16]] = CreateAttachments(std::initializer_list<size_t>{}); + Attachments[OpIds[17]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++}); + Attachments[OpIds[18]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++}); + Attachments[OpIds[19]] = CreateAttachments(std::initializer_list<size_t>{}); + Attachments[OpIds[20]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + Attachments[OpIds[21]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + Attachments[OpIds[22]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++}); + Attachments[OpIds[23]] = CreateAttachments(std::initializer_list<size_t>{*It++}); + ZEN_ASSERT(It == AttachmentSizes.end()); + } + + auto AddOp = [](const CbObject& Op, std::unordered_map<Oid, uint32_t, Oid::Hasher>& Ops) { + XXH3_128Stream KeyHasher; + Op["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); }); + XXH3_128 KeyHash = KeyHasher.GetHash(); + Oid Id; + memcpy(Id.OidBits, &KeyHash, sizeof Id.OidBits); + IoBuffer Buffer = Op.GetBuffer().AsIoBuffer(); + const uint32_t OpCoreHash = uint32_t(XXH3_64bits(Buffer.GetData(), Buffer.GetSize()) & 0xffffFFFF); + Ops.insert({Id, OpCoreHash}); + }; + + auto MakeProject = [](cpr::Session& Session, std::string_view UrlBase, std::string_view ProjectName) { + CbObjectWriter Project; + Project.AddString("id"sv, ProjectName); + Project.AddString("root"sv, ""sv); + Project.AddString("engine"sv, ""sv); + Project.AddString("project"sv, ""sv); + Project.AddString("projectfile"sv, ""sv); + IoBuffer ProjectPayload = Project.Save().GetBuffer().AsIoBuffer(); + std::string ProjectRequest = fmt::format("{}/prj/{}", UrlBase, ProjectName); + Session.SetUrl({ProjectRequest}); + Session.SetBody(cpr::Body{(const char*)ProjectPayload.GetData(), ProjectPayload.GetSize()}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + }; + + auto MakeOplog = [](cpr::Session& Session, std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName) { + std::string CreateOplogRequest = fmt::format("{}/prj/{}/oplog/{}", UrlBase, ProjectName, OplogName); + Session.SetUrl({CreateOplogRequest}); + Session.SetBody(cpr::Body{}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + }; + + auto MakeOp = [](cpr::Session& Session, + std::string_view UrlBase, + std::string_view ProjectName, + std::string_view OplogName, + const CbPackage& OpPackage) { + std::string CreateOpRequest = fmt::format("{}/prj/{}/oplog/{}/new", UrlBase, ProjectName, OplogName); + Session.SetUrl({CreateOpRequest}); + zen::BinaryWriter MemOut; + legacy::SaveCbPackage(OpPackage, MemOut); + Session.SetBody(cpr::Body{(const char*)MemOut.Data(), MemOut.Size()}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + }; + + cpr::Session Session; + MakeProject(Session, Servers.GetInstance(0).GetBaseUri(), "proj0"); + MakeOplog(Session, Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0"); + + std::unordered_map<Oid, uint32_t, Oid::Hasher> SourceOps; + for (const Oid& OpId : OpIds) + { + CbPackage OpPackage = CreateOplogPackage(OpId, Attachments[OpId]); + CHECK(OpPackage.GetAttachments().size() == Attachments[OpId].size()); + AddOp(OpPackage.GetObject(), SourceOps); + MakeOp(Session, Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0", OpPackage); + } + + std::vector<IoHash> AttachmentHashes; + AttachmentHashes.reserve(Attachments.size()); + for (const auto& AttachmentOplog : Attachments) + { + for (const auto& Attachment : AttachmentOplog.second) + { + AttachmentHashes.emplace_back(Attachment.second.DecodeRawHash()); + } + } + + auto MakeCbObjectPayload = [](std::function<void(CbObjectWriter & Writer)> Write) -> IoBuffer { + CbObjectWriter Writer; + Write(Writer); + IoBuffer Result = Writer.Save().GetBuffer().AsIoBuffer(); + Result.MakeOwned(); + return Result; + }; + + auto ValidateAttachments = [&MakeCbObjectPayload, &AttachmentHashes, &Servers, &Session](int ServerIndex, + std::string_view Project, + std::string_view Oplog) { + std::string GetChunksRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(ServerIndex).GetBaseUri(), Project, Oplog); + Session.SetUrl({GetChunksRequest}); + IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes](CbObjectWriter& Writer) { + Writer << "method"sv + << "getchunks"sv; + Writer << "chunks"sv << BeginArray; + for (const IoHash& Chunk : AttachmentHashes) + { + Writer << Chunk; + } + Writer << EndArray; // chunks + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + CbPackage ResponsePackage = ParsePackageMessage(IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size())); + CHECK(ResponsePackage.GetAttachments().size() == AttachmentHashes.size()); + }; + + auto ValidateOplog = [&SourceOps, &AddOp, &Servers, &Session](int ServerIndex, std::string_view Project, std::string_view Oplog) { + std::unordered_map<Oid, uint32_t, Oid::Hasher> TargetOps; + std::vector<CbObject> ResultingOplog; + + std::string GetOpsRequest = + fmt::format("{}/prj/{}/oplog/{}/entries", Servers.GetInstance(ServerIndex).GetBaseUri(), Project, Oplog); + Session.SetUrl({GetOpsRequest}); + cpr::Response Response = Session.Get(); + CHECK(IsHttpSuccessCode(Response.status_code)); + + IoBuffer Payload(IoBuffer::Wrap, Response.text.data(), Response.text.size()); + CbObject OplogResonse = LoadCompactBinaryObject(Payload); + CbArrayView EntriesArray = OplogResonse["entries"sv].AsArrayView(); + + for (CbFieldView OpEntry : EntriesArray) + { + CbObjectView Core = OpEntry.AsObjectView(); + BinaryWriter Writer; + Core.CopyTo(Writer); + MemoryView OpView = Writer.GetView(); + IoBuffer OpBuffer(IoBuffer::Wrap, OpView.GetData(), OpView.GetSize()); + CbObject Op(SharedBuffer(OpBuffer), CbFieldType::HasFieldType); + AddOp(Op, TargetOps); + } + CHECK(SourceOps == TargetOps); + }; + + SUBCASE("File") + { + ScopedTemporaryDirectory TempDir; + { + std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0"); + Session.SetUrl({SaveOplogRequest}); + + IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) { + Writer << "method"sv + << "export"sv; + Writer << "params" << BeginObject; + { + Writer << "maxblocksize"sv << 3072u; + Writer << "maxchunkembedsize"sv << 1296u; + Writer << "force"sv << false; + Writer << "file"sv << BeginObject; + { + Writer << "path"sv << path; + Writer << "name"sv + << "proj0_oplog0"sv; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + { + MakeProject(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy"); + MakeOplog(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + std::string LoadOplogRequest = + fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + Session.SetUrl({LoadOplogRequest}); + + IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) { + Writer << "method"sv + << "import"sv; + Writer << "params" << BeginObject; + { + Writer << "force"sv << false; + Writer << "file"sv << BeginObject; + { + Writer << "path"sv << path; + Writer << "name"sv + << "proj0_oplog0"sv; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + ValidateAttachments(1, "proj0_copy", "oplog0_copy"); + ValidateOplog(1, "proj0_copy", "oplog0_copy"); + } + + SUBCASE("File disable blocks") + { + ScopedTemporaryDirectory TempDir; + { + std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0"); + Session.SetUrl({SaveOplogRequest}); + + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + Writer << "method"sv + << "export"sv; + Writer << "params" << BeginObject; + { + Writer << "maxblocksize"sv << 3072u; + Writer << "maxchunkembedsize"sv << 1296u; + Writer << "force"sv << false; + Writer << "file"sv << BeginObject; + { + Writer << "path"sv << TempDir.Path().string(); + Writer << "name"sv + << "proj0_oplog0"sv; + Writer << "disableblocks"sv << true; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + { + MakeProject(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy"); + MakeOplog(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + std::string LoadOplogRequest = + fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + Session.SetUrl({LoadOplogRequest}); + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + Writer << "method"sv + << "import"sv; + Writer << "params" << BeginObject; + { + Writer << "force"sv << false; + Writer << "file"sv << BeginObject; + { + Writer << "path"sv << TempDir.Path().string(); + Writer << "name"sv + << "proj0_oplog0"sv; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + ValidateAttachments(1, "proj0_copy", "oplog0_copy"); + ValidateOplog(1, "proj0_copy", "oplog0_copy"); + } + + SUBCASE("File force temp blocks") + { + ScopedTemporaryDirectory TempDir; + { + std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0"); + Session.SetUrl({SaveOplogRequest}); + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + Writer << "method"sv + << "export"sv; + Writer << "params" << BeginObject; + { + Writer << "maxblocksize"sv << 3072u; + Writer << "maxchunkembedsize"sv << 1296u; + Writer << "force"sv << false; + Writer << "file"sv << BeginObject; + { + Writer << "path"sv << TempDir.Path().string(); + Writer << "name"sv + << "proj0_oplog0"sv; + Writer << "enabletempblocks"sv << true; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + { + MakeProject(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy"); + MakeOplog(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + std::string LoadOplogRequest = + fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy"); + Session.SetUrl({LoadOplogRequest}); + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + Writer << "method"sv + << "import"sv; + Writer << "params" << BeginObject; + { + Writer << "force"sv << false; + Writer << "file"sv << BeginObject; + { + Writer << "path"sv << TempDir.Path().string(); + Writer << "name"sv + << "proj0_oplog0"sv; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + ValidateAttachments(1, "proj0_copy", "oplog0_copy"); + ValidateOplog(1, "proj0_copy", "oplog0_copy"); + } + + SUBCASE("Zen") + { + ScopedTemporaryDirectory TempDir; + { + std::string ExportSourceUri = Servers.GetInstance(0).GetBaseUri(); + std::string ExportTargetUri = Servers.GetInstance(1).GetBaseUri(); + MakeProject(Session, ExportTargetUri, "proj0_copy"); + MakeOplog(Session, ExportTargetUri, "proj0_copy", "oplog0_copy"); + + std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", ExportSourceUri, "proj0", "oplog0"); + Session.SetUrl({SaveOplogRequest}); + + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + Writer << "method"sv + << "export"sv; + Writer << "params" << BeginObject; + { + Writer << "maxblocksize"sv << 3072u; + Writer << "maxchunkembedsize"sv << 1296u; + Writer << "force"sv << false; + Writer << "zen"sv << BeginObject; + { + Writer << "url"sv << ExportTargetUri.substr(7); + Writer << "project" + << "proj0_copy"; + Writer << "oplog" + << "oplog0_copy"; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + ValidateAttachments(1, "proj0_copy", "oplog0_copy"); + ValidateOplog(1, "proj0_copy", "oplog0_copy"); + + { + std::string ImportSourceUri = Servers.GetInstance(1).GetBaseUri(); + std::string ImportTargetUri = Servers.GetInstance(2).GetBaseUri(); + MakeProject(Session, ImportTargetUri, "proj1"); + MakeOplog(Session, ImportTargetUri, "proj1", "oplog1"); + std::string LoadOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", ImportTargetUri, "proj1", "oplog1"); + Session.SetUrl({LoadOplogRequest}); + + IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) { + Writer << "method"sv + << "import"sv; + Writer << "params" << BeginObject; + { + Writer << "force"sv << false; + Writer << "zen"sv << BeginObject; + { + Writer << "url"sv << ImportSourceUri.substr(7); + Writer << "project" + << "proj0_copy"; + Writer << "oplog" + << "oplog0_copy"; + } + Writer << EndObject; // "file" + } + Writer << EndObject; // "params" + }); + Session.SetBody(AsBody(Payload)); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}}); + cpr::Response Response = Session.Post(); + CHECK(IsHttpSuccessCode(Response.status_code)); + } + ValidateAttachments(2, "proj1", "oplog1"); + ValidateOplog(2, "proj1", "oplog1"); + } +} + +# if 0 +TEST_CASE("lifetime.owner") +{ + // This test is designed to verify that the hand-over of sponsor processes is handled + // correctly for the case when a second or third process is launched on the same port + // + // Due to the nature of it, it cannot be + + const uint16_t PortNumber = 23456; + + ZenServerInstance Zen1(TestEnv); + std::filesystem::path TestDir1 = TestEnv.CreateNewTestDir(); + Zen1.SetTestDir(TestDir1); + Zen1.SpawnServer(PortNumber); + Zen1.WaitUntilReady(); + Zen1.Detach(); + + ZenServerInstance Zen2(TestEnv); + std::filesystem::path TestDir2 = TestEnv.CreateNewTestDir(); + Zen2.SetTestDir(TestDir2); + Zen2.SpawnServer(PortNumber); + Zen2.WaitUntilReady(); + Zen2.Detach(); +} + +TEST_CASE("lifetime.owner.2") +{ + // This test is designed to verify that the hand-over of sponsor processes is handled + // correctly for the case when a second or third process is launched on the same port + // + // Due to the nature of it, it cannot be + + const uint16_t PortNumber = 13456; + + std::filesystem::path TestDir1 = TestEnv.CreateNewTestDir(); + std::filesystem::path TestDir2 = TestEnv.CreateNewTestDir(); + + ZenServerInstance Zen1(TestEnv); + Zen1.SetTestDir(TestDir1); + Zen1.SpawnServer(PortNumber); + Zen1.WaitUntilReady(); + + ZenServerInstance Zen2(TestEnv); + Zen2.SetTestDir(TestDir2); + Zen2.SetOwnerPid(Zen1.GetPid()); + Zen2.SpawnServer(PortNumber + 1); + Zen2.Detach(); + + ZenServerInstance Zen3(TestEnv); + Zen3.SetTestDir(TestDir2); + Zen3.SetOwnerPid(Zen1.GetPid()); + Zen3.SpawnServer(PortNumber + 1); + Zen3.Detach(); + + ZenServerInstance Zen4(TestEnv); + Zen4.SetTestDir(TestDir2); + Zen4.SetOwnerPid(Zen1.GetPid()); + Zen4.SpawnServer(PortNumber + 1); + Zen4.Detach(); +} +# endif + +} // namespace zen::tests +#else +int +main() +{ +} +#endif diff --git a/src/zenserver/admin/admin.cpp b/src/zenserver/admin/admin.cpp new file mode 100644 index 000000000..7aa1b48d1 --- /dev/null +++ b/src/zenserver/admin/admin.cpp @@ -0,0 +1,101 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "admin.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> +#include <zenstore/gc.h> + +#include <chrono> + +namespace zen { + +HttpAdminService::HttpAdminService(GcScheduler& Scheduler) : m_GcScheduler(Scheduler) +{ + using namespace std::literals; + + m_Router.RegisterRoute( + "health", + [](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "gc", + [this](HttpRouterRequest& Req) { + const GcSchedulerStatus Status = m_GcScheduler.Status(); + + CbObjectWriter Response; + Response << "Status"sv << (GcSchedulerStatus::kIdle == Status ? "Idle"sv : "Running"sv); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "gc", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + GcScheduler::TriggerParams GcParams; + + if (auto Param = Params.GetValue("smallobjects"); Param.empty() == false) + { + GcParams.CollectSmallObjects = Param == "true"sv; + } + + if (auto Param = Params.GetValue("maxcacheduration"); Param.empty() == false) + { + if (auto Value = ParseInt<uint64_t>(Param)) + { + GcParams.MaxCacheDuration = std::chrono::seconds(Value.value()); + } + } + + if (auto Param = Params.GetValue("disksizesoftlimit"); Param.empty() == false) + { + if (auto Value = ParseInt<uint64_t>(Param)) + { + GcParams.DiskSizeSoftLimit = Value.value(); + } + } + + const bool Started = m_GcScheduler.Trigger(GcParams); + + CbObjectWriter Response; + Response << "Status"sv << (Started ? "Started"sv : "Running"sv); + HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "", + [](HttpRouterRequest& Req) { + CbObject Payload = Req.ServerRequest().ReadPayloadObject(); + + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kPost); +} + +HttpAdminService::~HttpAdminService() +{ +} + +const char* +HttpAdminService::BaseUri() const +{ + return "/admin/"; +} + +void +HttpAdminService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen diff --git a/src/zenserver/admin/admin.h b/src/zenserver/admin/admin.h new file mode 100644 index 000000000..9463ffbb3 --- /dev/null +++ b/src/zenserver/admin/admin.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zenhttp/httpserver.h> + +namespace zen { + +class GcScheduler; + +class HttpAdminService : public zen::HttpService +{ +public: + HttpAdminService(GcScheduler& Scheduler); + ~HttpAdminService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + HttpRequestRouter m_Router; + GcScheduler& m_GcScheduler; +}; + +} // namespace zen diff --git a/src/zenserver/auth/authmgr.cpp b/src/zenserver/auth/authmgr.cpp new file mode 100644 index 000000000..4cd6b3362 --- /dev/null +++ b/src/zenserver/auth/authmgr.cpp @@ -0,0 +1,506 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <auth/authmgr.h> +#include <auth/oidc.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/crypto.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> + +#include <condition_variable> +#include <memory> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +#include <fmt/format.h> + +namespace zen { + +using namespace std::literals; + +namespace details { + IoBuffer ReadEncryptedFile(std::filesystem::path Path, + const AesKey256Bit& Key, + const AesIV128Bit& IV, + std::optional<std::string>& Reason) + { + FileContents Result = ReadFile(Path); + + if (Result.ErrorCode) + { + return IoBuffer(); + } + + IoBuffer EncryptedBuffer = Result.Flatten(); + + if (EncryptedBuffer.GetSize() == 0) + { + return IoBuffer(); + } + + std::vector<uint8_t> DecryptionBuffer; + DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize); + + MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason); + + if (DecryptedView.IsEmpty()) + { + return IoBuffer(); + } + + return IoBufferBuilder::MakeCloneFromMemory(DecryptedView); + } + + void WriteEncryptedFile(std::filesystem::path Path, + IoBuffer FileData, + const AesKey256Bit& Key, + const AesIV128Bit& IV, + std::optional<std::string>& Reason) + { + if (FileData.GetSize() == 0) + { + return; + } + + std::vector<uint8_t> EncryptionBuffer; + EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize); + + MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason); + + if (EncryptedView.IsEmpty()) + { + return; + } + + WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize())); + } +} // namespace details + +class AuthMgrImpl final : public AuthMgr +{ + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + using Seconds = std::chrono::seconds; + +public: + AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth")) + { + LoadState(); + + m_BackgroundThread.Interval = Config.UpdateInterval; + m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this); + } + + virtual ~AuthMgrImpl() { Shutdown(); } + + virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final + { + if (OpenIdProviderExist(Params.Name)) + { + ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name); + return; + } + + if (Params.Name.empty()) + { + ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'"); + return; + } + + std::unique_ptr<OidcClient> Client = + std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId}); + + if (const auto InitResult = Client->Initialize(); InitResult.Ok == false) + { + ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason); + return; + } + + std::string NewProviderName = std::string(Params.Name); + + OpenIdProvider* NewProvider = nullptr; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.contains(NewProviderName)) + { + return; + } + + auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>()); + NewProvider = InsertResult.first->second.get(); + } + + NewProvider->Name = std::string(Params.Name); + NewProvider->Url = std::string(Params.Url); + NewProvider->ClientId = std::string(Params.ClientId); + NewProvider->HttpClient = std::move(Client); + + ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url); + } + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final + { + if (Params.ProviderName.empty()) + { + ZEN_WARN("trying add OpenID token with invalid provider name"); + return false; + } + + if (Params.RefreshToken.empty()) + { + ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'"); + return false; + } + + auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken); + + if (RefreshResult.Ok == false) + { + ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason); + return false; + } + + bool IsNew = false; + + { + auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + + std::unique_lock _(m_TokenMutex); + + const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token)); + + IsNew = InsertResult.second; + } + + if (IsNew) + { + ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName); + } + else + { + ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName); + } + + return true; + } + + virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final + { + std::unique_lock _(m_TokenMutex); + + if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end()) + { + const OpenIdToken& Token = It->second; + + return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + } + + return {}; + } + +private: + bool OpenIdProviderExist(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + + return m_OpenIdProviders.contains(std::string(ProviderName)); + } + + OidcClient& GetOpenIdClient(std::string_view ProviderName) + { + std::unique_lock _(m_ProviderMutex); + return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get(); + } + + OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken) + { + if (OpenIdProviderExist(ProviderName) == false) + { + return {.Reason = fmt::format("provider '{}' is missing", ProviderName)}; + } + + OidcClient& Client = GetOpenIdClient(ProviderName); + + return Client.RefreshToken(RefreshToken); + } + + void Shutdown() + { + BackgroundThread::Stop(m_BackgroundThread); + SaveState(); + } + + void LoadState() + { + try + { + std::optional<std::string> Reason; + + IoBuffer Buffer = + details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason); + + if (!Buffer) + { + if (Reason) + { + ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value()); + } + + return; + } + + const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All); + + if (ValidationError != CbValidateError::None) + { + ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'"); + return; + } + + if (CbObject AuthState = LoadCompactBinaryObject(Buffer)) + { + for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv]) + { + CbObjectView ProviderObj = ProviderView.AsObjectView(); + + std::string_view ProviderName = ProviderObj["Name"].AsString(); + std::string_view Url = ProviderObj["Url"].AsString(); + std::string_view ClientId = ProviderObj["ClientId"].AsString(); + + AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId}); + } + + for (CbFieldView TokenView : AuthState["OpenIdTokens"sv]) + { + CbObjectView TokenObj = TokenView.AsObjectView(); + + std::string_view ProviderName = TokenObj["ProviderName"sv].AsString(); + std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString(); + + const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken}); + + if (!Ok) + { + ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName); + } + } + } + } + catch (std::exception& Err) + { + ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what()); + + { + std::unique_lock _(m_ProviderMutex); + m_OpenIdProviders.clear(); + } + + { + std::unique_lock _(m_TokenMutex); + m_OpenIdTokens.clear(); + } + } + } + + void SaveState() + { + try + { + CbObjectWriter AuthState; + + { + std::unique_lock _(m_ProviderMutex); + + if (m_OpenIdProviders.size() > 0) + { + AuthState.BeginArray("OpenIdProviders"); + for (const auto& Kv : m_OpenIdProviders) + { + AuthState.BeginObject(); + AuthState << "Name"sv << Kv.second->Name; + AuthState << "Url"sv << Kv.second->Url; + AuthState << "ClientId"sv << Kv.second->ClientId; + AuthState.EndObject(); + } + AuthState.EndArray(); + } + } + + { + std::unique_lock _(m_TokenMutex); + + AuthState.BeginArray("OpenIdTokens"); + if (m_OpenIdTokens.size() > 0) + { + for (const auto& Kv : m_OpenIdTokens) + { + AuthState.BeginObject(); + AuthState << "ProviderName"sv << Kv.first; + AuthState << "RefreshToken"sv << Kv.second.RefreshToken; + AuthState.EndObject(); + } + } + AuthState.EndArray(); + } + + std::filesystem::create_directories(m_Config.RootDirectory); + + std::optional<std::string> Reason; + + details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv, + AuthState.Save().GetBuffer().AsIoBuffer(), + m_Config.EncryptionKey, + m_Config.EncryptionIV, + Reason); + + if (Reason) + { + ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value()); + } + } + catch (std::exception& Err) + { + ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what()); + } + } + + void BackgroundThreadEntry() + { + for (;;) + { + std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread); + + if (m_BackgroundThread.Running.load() == false) + { + break; + } + + if (SignalStatus != std::cv_status::timeout) + { + continue; + } + + { + // Refresh Open ID token(s) + + std::vector<OpenIdTokenMap::value_type> ExpiredTokens; + + { + std::unique_lock _(m_TokenMutex); + + for (const auto& Kv : m_OpenIdTokens) + { + const Seconds ExpiresIn = std::chrono::duration_cast<Seconds>(Kv.second.ExpireTime - Clock::now()); + const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2); + + if (Expired) + { + ExpiredTokens.push_back(Kv); + } + } + } + + ZEN_DEBUG("refreshing '{}' OpenID token(s)", ExpiredTokens.size()); + + for (const auto& Kv : ExpiredTokens) + { + OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken); + + if (RefreshResult.Ok) + { + ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first); + + auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken, + .RefreshToken = RefreshResult.RefreshToken, + .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken), + .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)}; + + { + std::unique_lock _(m_TokenMutex); + m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token)); + } + } + else + { + ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason); + } + } + } + } + } + + struct BackgroundThread + { + std::chrono::seconds Interval{10}; + std::mutex Mutex; + std::condition_variable Signal; + std::atomic_bool Running{true}; + std::thread Thread; + + static void Stop(BackgroundThread& State) + { + if (State.Running.load()) + { + State.Running.store(false); + State.Signal.notify_one(); + } + + if (State.Thread.joinable()) + { + State.Thread.join(); + } + } + + static std::cv_status WaitForSignal(BackgroundThread& State) + { + std::unique_lock Lock(State.Mutex); + return State.Signal.wait_for(Lock, State.Interval); + } + }; + + struct OpenIdProvider + { + std::string Name; + std::string Url; + std::string ClientId; + std::unique_ptr<OidcClient> HttpClient; + }; + + struct OpenIdToken + { + std::string IdentityToken; + std::string RefreshToken; + std::string AccessToken; + TimePoint ExpireTime{}; + }; + + using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>; + using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>; + + spdlog::logger& Log() { return m_Log; } + + AuthConfig m_Config; + spdlog::logger& m_Log; + BackgroundThread m_BackgroundThread; + OpenIdProviderMap m_OpenIdProviders; + OpenIdTokenMap m_OpenIdTokens; + std::mutex m_ProviderMutex; + std::shared_mutex m_TokenMutex; +}; + +std::unique_ptr<AuthMgr> +AuthMgr::Create(const AuthConfig& Config) +{ + return std::make_unique<AuthMgrImpl>(Config); +} + +} // namespace zen diff --git a/src/zenserver/auth/authmgr.h b/src/zenserver/auth/authmgr.h new file mode 100644 index 000000000..054588ab9 --- /dev/null +++ b/src/zenserver/auth/authmgr.h @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/crypto.h> +#include <zencore/iobuffer.h> +#include <zencore/string.h> + +#include <chrono> +#include <filesystem> +#include <memory> + +namespace zen { + +struct AuthConfig +{ + std::filesystem::path RootDirectory; + std::chrono::seconds UpdateInterval{30}; + AesKey256Bit EncryptionKey; + AesIV128Bit EncryptionIV; +}; + +class AuthMgr +{ +public: + virtual ~AuthMgr() = default; + + struct AddOpenIdProviderParams + { + std::string_view Name; + std::string_view Url; + std::string_view ClientId; + }; + + virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) = 0; + + struct AddOpenIdTokenParams + { + std::string_view ProviderName; + std::string_view RefreshToken; + }; + + virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0; + + struct OpenIdAccessToken + { + std::string AccessToken; + std::chrono::system_clock::time_point ExpireTime{}; + }; + + virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0; + + static std::unique_ptr<AuthMgr> Create(const AuthConfig& Config); +}; + +} // namespace zen diff --git a/src/zenserver/auth/authservice.cpp b/src/zenserver/auth/authservice.cpp new file mode 100644 index 000000000..1cc679540 --- /dev/null +++ b/src/zenserver/auth/authservice.cpp @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <auth/authservice.h> + +#include <auth/authmgr.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr) +{ + m_Router.RegisterRoute( + "oidc/refreshtoken", + [this](HttpRouterRequest& RouterRequest) { + HttpServerRequest& ServerRequest = RouterRequest.ServerRequest(); + + const HttpContentType ContentType = ServerRequest.RequestContentType(); + + if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false) + { + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest); + } + + const IoBuffer Body = ServerRequest.ReadPayload(); + + std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()); + std::string JsonError; + json11::Json TokenInfo = json11::Json::parse(JsonText, JsonError); + + if (!JsonError.empty()) + { + CbObjectWriter Response; + Response << "Result"sv << false; + Response << "Error"sv << JsonError; + + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); + } + + const std::string RefreshToken = TokenInfo["RefreshToken"].string_value(); + std::string ProviderName = TokenInfo["ProviderName"].string_value(); + + if (ProviderName.empty()) + { + ProviderName = "Default"sv; + } + + const bool Ok = + m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = ProviderName, .RefreshToken = RefreshToken}); + + if (Ok) + { + ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest); + } + else + { + CbObjectWriter Response; + Response << "Result"sv << false; + Response << "Error"sv + << "Invalid token"sv; + + ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save()); + } + }, + HttpVerb::kPost); +} + +HttpAuthService::~HttpAuthService() +{ +} + +const char* +HttpAuthService::BaseUri() const +{ + return "/auth/"; +} + +void +HttpAuthService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen diff --git a/src/zenserver/auth/authservice.h b/src/zenserver/auth/authservice.h new file mode 100644 index 000000000..64b86e21f --- /dev/null +++ b/src/zenserver/auth/authservice.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +class AuthMgr; + +class HttpAuthService final : public zen::HttpService +{ +public: + HttpAuthService(AuthMgr& AuthMgr); + virtual ~HttpAuthService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + AuthMgr& m_AuthMgr; + HttpRequestRouter m_Router; +}; + +} // namespace zen diff --git a/src/zenserver/auth/oidc.cpp b/src/zenserver/auth/oidc.cpp new file mode 100644 index 000000000..d2265c22f --- /dev/null +++ b/src/zenserver/auth/oidc.cpp @@ -0,0 +1,127 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <auth/oidc.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace details { + + using StringArray = std::vector<std::string>; + + StringArray ToStringArray(const json11::Json JsonArray) + { + StringArray Result; + + const auto& Items = JsonArray.array_items(); + + for (const auto& Item : Items) + { + Result.push_back(Item.string_value()); + } + + return Result; + } + +} // namespace details + +using namespace std::literals; + +OidcClient::OidcClient(const OidcClient::Options& Options) +{ + m_BaseUrl = std::string(Options.BaseUrl); + m_ClientId = std::string(Options.ClientId); +} + +OidcClient::InitResult +OidcClient::Initialize() +{ + ExtendableStringBuilder<256> Uri; + Uri << m_BaseUrl << "/.well-known/openid-configuration"sv; + + cpr::Session Session; + + Session.SetOption(cpr::Url{Uri.c_str()}); + + cpr::Response Response = Session.Get(); + + if (Response.error) + { + return {.Reason = std::move(Response.error.message)}; + } + + if (Response.status_code != 200) + { + return {.Reason = std::move(Response.reason)}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {.Reason = std::move(JsonError)}; + } + + m_Config = {.Issuer = Json["issuer"].string_value(), + .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(), + .TokenEndpoint = Json["token_endpoint"].string_value(), + .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(), + .RegistrationEndpoint = Json["registration_endpoint"].string_value(), + .JwksUri = Json["jwks_uri"].string_value(), + .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]), + .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]), + .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]), + .SupportedScopes = details::ToStringArray(Json["scopes_supported"]), + .SupportedTokenEndpointAuthMethods = details::ToStringArray(Json["token_endpoint_auth_methods_supported"]), + .SupportedClaims = details::ToStringArray(Json["claims_supported"])}; + + return {.Ok = true}; +} + +OidcClient::RefreshTokenResult +OidcClient::RefreshToken(std::string_view RefreshToken) +{ + const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId); + + cpr::Session Session; + + Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()}); + Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}); + Session.SetBody(cpr::Body{Body.data(), Body.size()}); + + cpr::Response Response = Session.Post(); + + if (Response.error) + { + return {.Reason = std::move(Response.error.message)}; + } + + if (Response.status_code != 200) + { + return {.Reason = fmt::format("{} ({})", Response.reason, Response.text)}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {.Reason = std::move(JsonError)}; + } + + return {.TokenType = Json["token_type"].string_value(), + .AccessToken = Json["access_token"].string_value(), + .RefreshToken = Json["refresh_token"].string_value(), + .IdentityToken = Json["id_token"].string_value(), + .Scope = Json["scope"].string_value(), + .ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()), + .Ok = true}; +} + +} // namespace zen diff --git a/src/zenserver/auth/oidc.h b/src/zenserver/auth/oidc.h new file mode 100644 index 000000000..f43ae3cd7 --- /dev/null +++ b/src/zenserver/auth/oidc.h @@ -0,0 +1,76 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/string.h> + +#include <vector> + +namespace zen { + +class OidcClient +{ +public: + struct Options + { + std::string_view BaseUrl; + std::string_view ClientId; + }; + + OidcClient(const Options& Options); + ~OidcClient() = default; + + OidcClient(const OidcClient&) = delete; + OidcClient& operator=(const OidcClient&) = delete; + + struct Result + { + std::string Reason; + bool Ok = false; + }; + + using InitResult = Result; + + InitResult Initialize(); + + struct RefreshTokenResult + { + std::string TokenType; + std::string AccessToken; + std::string RefreshToken; + std::string IdentityToken; + std::string Scope; + std::string Reason; + int64_t ExpiresInSeconds{}; + bool Ok = false; + }; + + RefreshTokenResult RefreshToken(std::string_view RefreshToken); + +private: + using StringArray = std::vector<std::string>; + + struct OpenIdConfiguration + { + std::string Issuer; + std::string AuthorizationEndpoint; + std::string TokenEndpoint; + std::string UserInfoEndpoint; + std::string RegistrationEndpoint; + std::string EndSessionEndpoint; + std::string DeviceAuthorizationEndpoint; + std::string JwksUri; + StringArray SupportedResponseTypes; + StringArray SupportedResponseModes; + StringArray SupportedGrantTypes; + StringArray SupportedScopes; + StringArray SupportedTokenEndpointAuthMethods; + StringArray SupportedClaims; + }; + + std::string m_BaseUrl; + std::string m_ClientId; + OpenIdConfiguration m_Config; +}; + +} // namespace zen diff --git a/src/zenserver/cache/cachetracking.cpp b/src/zenserver/cache/cachetracking.cpp new file mode 100644 index 000000000..9119e3122 --- /dev/null +++ b/src/zenserver/cache/cachetracking.cpp @@ -0,0 +1,376 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "cachetracking.h" + +#if ZEN_USE_CACHE_TRACKER + +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinaryvalue.h> +# include <zencore/endian.h> +# include <zencore/filesystem.h> +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/string.h> + +# include <zencore/testing.h> +# include <zencore/testutils.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# pragma comment(lib, "Rpcrt4.lib") // RocksDB made me do this +# include <fmt/format.h> +# include <rocksdb/db.h> +# include <tsl/robin_map.h> +# include <tsl/robin_set.h> +# include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace rocksdb = ROCKSDB_NAMESPACE; + +static constinit auto Epoch = std::chrono::time_point<std::chrono::system_clock>{}; + +static uint64_t +GetCurrentCacheTimeStamp() +{ + auto Duration = std::chrono::system_clock::now() - Epoch; + uint64_t Millis = std::chrono::duration_cast<std::chrono::milliseconds>(Duration).count(); + + return Millis; +} + +struct CacheAccessSnapshot +{ +public: + void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) + { + BucketTracker* Tracker = GetBucket(std::string(BucketSegment)); + + Tracker->Track(HashKey); + } + + bool SerializeSnapshot(CbObjectWriter& Cbo) + { + bool Serialized = false; + RwLock::ExclusiveLockScope _(m_Lock); + + for (const auto& Kv : m_BucketMapping) + { + if (m_Buckets[Kv.second]->Size()) + { + Cbo.BeginArray(Kv.first); + m_Buckets[Kv.second]->SerializeSnapshotAndClear(Cbo); + Cbo.EndArray(); + Serialized = true; + } + } + + return Serialized; + } + +private: + struct BucketTracker + { + mutable RwLock Lock; + tsl::robin_set<IoHash> AccessedKeys; + + void Track(const IoHash& HashKey) + { + if (RwLock::SharedLockScope _(Lock); AccessedKeys.contains(HashKey)) + { + return; + } + + RwLock::ExclusiveLockScope _(Lock); + + AccessedKeys.insert(HashKey); + } + + void SerializeSnapshotAndClear(CbObjectWriter& Cbo) + { + RwLock::ExclusiveLockScope _(Lock); + + for (const IoHash& Hash : AccessedKeys) + { + Cbo.AddHash(Hash); + } + + AccessedKeys.clear(); + } + + size_t Size() const + { + RwLock::SharedLockScope _(Lock); + return AccessedKeys.size(); + } + }; + + BucketTracker* GetBucket(const std::string& BucketName) + { + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end()) + { + _.ReleaseNow(); + + return AddNewBucket(BucketName); + } + else + { + return m_Buckets[It->second].get(); + } + } + + BucketTracker* AddNewBucket(const std::string& BucketName) + { + RwLock::ExclusiveLockScope _(m_Lock); + + if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end()) + { + const uint32_t BucketIndex = gsl::narrow<uint32_t>(m_Buckets.size()); + m_Buckets.emplace_back(std::make_unique<BucketTracker>()); + m_BucketMapping[BucketName] = BucketIndex; + + return m_Buckets[BucketIndex].get(); + } + else + { + return m_Buckets[It->second].get(); + } + } + + RwLock m_Lock; + std::vector<std::unique_ptr<BucketTracker>> m_Buckets; + tsl::robin_map<std::string, uint32_t> m_BucketMapping; +}; + +struct ZenCacheTracker::Impl +{ + Impl(std::filesystem::path StateDirectory) + { + std::filesystem::path StatsDbPath{StateDirectory / ".zdb"}; + + std::string RocksdbPath = StatsDbPath.string(); + + ZEN_DEBUG("opening tracker db at '{}'", RocksdbPath); + + rocksdb::DB* Db = nullptr; + rocksdb::DBOptions Options; + Options.create_if_missing = true; + + std::vector<std::string> ExistingColumnFamilies; + rocksdb::Status Status = rocksdb::DB::ListColumnFamilies(Options, RocksdbPath, &ExistingColumnFamilies); + + std::vector<rocksdb::ColumnFamilyDescriptor> ColumnDescriptors; + + if (Status.IsPathNotFound()) + { + ColumnDescriptors.emplace_back(rocksdb::ColumnFamilyDescriptor{rocksdb::kDefaultColumnFamilyName, {}}); + } + else if (Status.ok()) + { + for (const std::string& Column : ExistingColumnFamilies) + { + rocksdb::ColumnFamilyDescriptor ColumnFamily; + ColumnFamily.name = Column; + ColumnDescriptors.push_back(ColumnFamily); + } + } + else + { + throw std::runtime_error(fmt::format("column family iteration failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str()); + } + + Status = rocksdb::DB::Open(Options, RocksdbPath, ColumnDescriptors, &m_RocksDbColumnHandles, &Db); + + if (!Status.ok()) + { + throw std::runtime_error(fmt::format("database open failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str()); + } + + m_RocksDb.reset(Db); + } + + ~Impl() + { + for (auto* Column : m_RocksDbColumnHandles) + { + delete Column; + } + + m_RocksDbColumnHandles.clear(); + } + + struct KeyStruct + { + uint64_t TimestampLittleEndian; + }; + + void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) { m_CurrentSnapshot.TrackAccess(BucketSegment, HashKey); } + + void SaveSnapshot() + { + CbObjectWriter Cbo; + + if (m_CurrentSnapshot.SerializeSnapshot(Cbo)) + { + IoBuffer SnapshotBuffer = Cbo.Save().GetBuffer().AsIoBuffer(); + + const KeyStruct Key{.TimestampLittleEndian = ToNetworkOrder(GetCurrentCacheTimeStamp())}; + rocksdb::Slice KeySlice{(const char*)&Key, sizeof Key}; + rocksdb::Slice ValueSlice{(char*)SnapshotBuffer.Data(), SnapshotBuffer.Size()}; + + rocksdb::WriteOptions Wo; + m_RocksDb->Put(Wo, KeySlice, ValueSlice); + } + } + + void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback) + { + rocksdb::ManagedSnapshot Snap(m_RocksDb.get()); + + rocksdb::ReadOptions Ro; + Ro.snapshot = Snap.snapshot(); + + std::unique_ptr<rocksdb::Iterator> It{m_RocksDb->NewIterator(Ro)}; + + const KeyStruct ZeroKey{.TimestampLittleEndian = 0}; + rocksdb::Slice ZeroKeySlice{(const char*)&ZeroKey, sizeof ZeroKey}; + + It->Seek(ZeroKeySlice); + + while (It->Valid()) + { + rocksdb::Slice KeySlice = It->key(); + rocksdb::Slice ValueSlice = It->value(); + + if (KeySlice.size() == sizeof(KeyStruct)) + { + IoBuffer ValueBuffer(IoBuffer::Wrap, ValueSlice.data(), ValueSlice.size()); + + CbObject Value = LoadCompactBinaryObject(ValueBuffer); + + uint64_t Key = FromNetworkOrder(*reinterpret_cast<const uint64_t*>(KeySlice.data())); + + Callback(Key, Value); + } + + It->Next(); + } + } + + std::unique_ptr<rocksdb::DB> m_RocksDb; + std::vector<rocksdb::ColumnFamilyHandle*> m_RocksDbColumnHandles; + CacheAccessSnapshot m_CurrentSnapshot; +}; + +ZenCacheTracker::ZenCacheTracker(std::filesystem::path StateDirectory) : m_Impl(new Impl(StateDirectory)) +{ +} + +ZenCacheTracker::~ZenCacheTracker() +{ + delete m_Impl; +} + +void +ZenCacheTracker::TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) +{ + m_Impl->TrackAccess(BucketSegment, HashKey); +} + +void +ZenCacheTracker::SaveSnapshot() +{ + m_Impl->SaveSnapshot(); +} + +void +ZenCacheTracker::IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback) +{ + m_Impl->IterateSnapshots(std::move(Callback)); +} + +# if ZEN_WITH_TESTS + +TEST_CASE("z$.tracker") +{ + using namespace std::literals; + + const uint64_t t0 = GetCurrentCacheTimeStamp(); + + ScopedTemporaryDirectory TempDir; + + ZenCacheTracker Zcs(TempDir.Path()); + + tsl::robin_set<IoHash> KeyHashes; + + for (int i = 0; i < 10000; ++i) + { + IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i); + + KeyHashes.insert(KeyHash); + + Zcs.TrackAccess("foo"sv, KeyHash); + } + + for (int i = 0; i < 10000; ++i) + { + IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i); + + Zcs.TrackAccess("foo"sv, KeyHash); + } + + Zcs.SaveSnapshot(); + + for (int n = 0; n < 10; ++n) + { + for (int i = 0; i < 1000; ++i) + { + const int Index = i + n * 1000; + IoHash KeyHash = IoHash::HashBuffer(&Index, sizeof Index); + + Zcs.TrackAccess("foo"sv, KeyHash); + } + + Zcs.SaveSnapshot(); + } + + Zcs.SaveSnapshot(); + + const uint64_t t1 = GetCurrentCacheTimeStamp(); + + int SnapshotCount = 0; + + Zcs.IterateSnapshots([&](uint64_t TimeStamp, CbObject Snapshot) { + CHECK(TimeStamp >= t0); + CHECK(TimeStamp <= t1); + + for (auto& Field : Snapshot) + { + CHECK_EQ(Field.GetName(), "foo"sv); + + const CbArray& Array = Field.AsArray(); + + for (const auto& Element : Array) + { + CHECK(KeyHashes.contains(Element.GetValue().AsHash())); + } + } + + ++SnapshotCount; + }); + + CHECK_EQ(SnapshotCount, 11); +} + +# endif + +void +cachetracker_forcelink() +{ +} + +} // namespace zen + +#endif // ZEN_USE_CACHE_TRACKER diff --git a/src/zenserver/cache/cachetracking.h b/src/zenserver/cache/cachetracking.h new file mode 100644 index 000000000..fdfe1a4c7 --- /dev/null +++ b/src/zenserver/cache/cachetracking.h @@ -0,0 +1,41 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> + +#include <stdint.h> +#include <filesystem> +#include <functional> + +namespace zen { + +#define ZEN_USE_CACHE_TRACKER 0 +#if ZEN_USE_CACHE_TRACKER + +class CbObject; + +/** + */ + +class ZenCacheTracker +{ +public: + ZenCacheTracker(std::filesystem::path StateDirectory); + ~ZenCacheTracker(); + + void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey); + void SaveSnapshot(); + void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback); + +private: + struct Impl; + + Impl* m_Impl = nullptr; +}; + +void cachetracker_forcelink(); + +#endif // ZEN_USE_CACHE_TRACKER + +} // namespace zen diff --git a/src/zenserver/cache/structuredcache.cpp b/src/zenserver/cache/structuredcache.cpp new file mode 100644 index 000000000..90e905bf6 --- /dev/null +++ b/src/zenserver/cache/structuredcache.cpp @@ -0,0 +1,3159 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "structuredcache.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/compress.h> +#include <zencore/enumflags.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zencore/workthreadpool.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/httpshared.h> +#include <zenutil/cache/cache.h> +#include <zenutil/cache/rpcrecording.h> + +#include "monitoring/httpstats.h" +#include "structuredcachestore.h" +#include "upstream/jupiter.h" +#include "upstream/upstreamcache.h" +#include "upstream/zen.h" +#include "zenstore/cidstore.h" +#include "zenstore/scrubcontext.h" + +#include <algorithm> +#include <atomic> +#include <filesystem> +#include <queue> +#include <thread> + +#include <cpr/cpr.h> +#include <gsl/gsl-lite.hpp> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif + +namespace zen { + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +CachePolicy +ParseCachePolicy(const HttpServerRequest::QueryParams& QueryParams) +{ + std::string_view PolicyText = QueryParams.GetValue("Policy"sv); + return !PolicyText.empty() ? zen::ParseCachePolicy(PolicyText) : CachePolicy::Default; +} + +CacheRecordPolicy +LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default) +{ + OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object); + return Policy ? std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy); +} + +struct AttachmentCount +{ + uint32_t New = 0; + uint32_t Valid = 0; + uint32_t Invalid = 0; + uint32_t Total = 0; +}; + +struct PutRequestData +{ + std::string Namespace; + CacheKey Key; + CbObjectView RecordObject; + CacheRecordPolicy Policy; +}; + +namespace { + static constinit std::string_view HttpZCacheRPCPrefix = "$rpc"sv; + static constinit std::string_view HttpZCacheUtilStartRecording = "exec$/start-recording"sv; + static constinit std::string_view HttpZCacheUtilStopRecording = "exec$/stop-recording"sv; + static constinit std::string_view HttpZCacheUtilReplayRecording = "exec$/replay-recording"sv; + static constinit std::string_view HttpZCacheDetailsPrefix = "details$"sv; + + struct HttpRequestData + { + std::optional<std::string> Namespace; + std::optional<std::string> Bucket; + std::optional<IoHash> HashKey; + std::optional<IoHash> ValueContentId; + }; + + constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + + std::optional<std::string> GetValidNamespaceName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Namespace is invalid, empty namespace is not allowed"); + return {}; + } + + if (Name.length() > 64) + { + ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet)) + { + ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<std::string> GetValidBucketName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed"); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet)) + { + ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<IoHash> GetValidIoHash(std::string_view Hash) + { + if (Hash.length() != IoHash::StringLength) + { + return {}; + } + + IoHash KeyHash; + if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash)) + { + return {}; + } + return KeyHash; + } + + bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data) + { + std::vector<std::string_view> Tokens; + uint32_t TokenCount = ForEachStrTok(Key, '/', [&](const std::string_view& Token) { + Tokens.push_back(Token); + return true; + }); + + switch (TokenCount) + { + case 0: + return true; + case 1: + Data.Namespace = GetValidNamespaceName(Tokens[0]); + return Data.Namespace.has_value(); + case 2: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + Data.Namespace = ZenCacheStore::DefaultNamespace; + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + return true; + } + case 3: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key/valueid request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + Data.ValueContentId = GetValidIoHash(Tokens[2]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + Data.Namespace = ZenCacheStore::DefaultNamespace; + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey) + { + return false; + } + return true; + } + case 4: + { + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey.has_value()) + { + return false; + } + + Data.ValueContentId = GetValidIoHash(Tokens[3]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + return true; + } + default: + return false; + } + } + + std::optional<std::string> GetRpcRequestNamespace(const CbObjectView Params) + { + CbFieldView NamespaceField = Params["Namespace"sv]; + if (!NamespaceField) + { + return std::string(ZenCacheStore::DefaultNamespace); + } + + if (NamespaceField.HasError()) + { + return {}; + } + if (!NamespaceField.IsString()) + { + return {}; + } + return GetValidNamespaceName(NamespaceField.AsString()); + } + + bool GetRpcRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key) + { + CbFieldView BucketField = KeyView["Bucket"sv]; + if (BucketField.HasError()) + { + return false; + } + if (!BucketField.IsString()) + { + return false; + } + std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString()); + if (!Bucket.has_value()) + { + return false; + } + CbFieldView HashField = KeyView["Hash"sv]; + if (HashField.HasError()) + { + return false; + } + if (!HashField.IsHash()) + { + return false; + } + IoHash Hash = HashField.AsHash(); + Key = CacheKey::Create(*Bucket, Hash); + return true; + } + +} // namespace + +////////////////////////////////////////////////////////////////////////// + +HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCacheStore, + CidStore& InCidStore, + HttpStatsService& StatsService, + HttpStatusService& StatusService, + UpstreamCache& UpstreamCache) +: m_Log(logging::Get("cache")) +, m_CacheStore(InCacheStore) +, m_StatsService(StatsService) +, m_StatusService(StatusService) +, m_CidStore(InCidStore) +, m_UpstreamCache(UpstreamCache) +{ + m_StatsService.RegisterHandler("z$", *this); + m_StatusService.RegisterHandler("z$", *this); +} + +HttpStructuredCacheService::~HttpStructuredCacheService() +{ + ZEN_INFO("closing structured cache"); + m_RequestRecorder.reset(); + + m_StatsService.UnregisterHandler("z$", *this); + m_StatusService.UnregisterHandler("z$", *this); +} + +const char* +HttpStructuredCacheService::BaseUri() const +{ + return "/z$/"; +} + +void +HttpStructuredCacheService::Flush() +{ + m_CacheStore.Flush(); +} + +void +HttpStructuredCacheService::Scrub(ScrubContext& Ctx) +{ + if (m_LastScrubTime == Ctx.ScrubTimestamp()) + { + return; + } + + m_LastScrubTime = Ctx.ScrubTimestamp(); + + m_CidStore.Scrub(Ctx); + m_CacheStore.Scrub(Ctx); +} + +void +HttpStructuredCacheService::HandleDetailsRequest(HttpServerRequest& Request) +{ + std::string_view Key = Request.RelativeUri(); + std::vector<std::string> Tokens; + uint32_t TokenCount = ForEachStrTok(Key, '/', [&Tokens](std::string_view Token) { + Tokens.push_back(std::string(Token)); + return true; + }); + std::string FilterNamespace; + std::string FilterBucket; + std::string FilterValue; + switch (TokenCount) + { + case 1: + break; + case 2: + { + FilterNamespace = Tokens[1]; + if (FilterNamespace.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + } + break; + case 3: + { + FilterNamespace = Tokens[1]; + if (FilterNamespace.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + FilterBucket = Tokens[2]; + if (FilterBucket.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + } + break; + case 4: + { + FilterNamespace = Tokens[1]; + if (FilterNamespace.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + FilterBucket = Tokens[2]; + if (FilterBucket.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + FilterValue = Tokens[3]; + if (FilterValue.empty()) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + } + break; + default: + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + + HttpServerRequest::QueryParams Params = Request.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + std::chrono::seconds NowSeconds = std::chrono::duration_cast<std::chrono::seconds>(GcClock::Now().time_since_epoch()); + CacheValueDetails ValueDetails = m_CacheStore.GetValueDetails(FilterNamespace, FilterBucket, FilterValue); + + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + if (AttachmentDetails) + { + CSVWriter << "Namespace, Bucket, Key, Cid, Size"; + } + else if (Details) + { + CSVWriter << "Namespace, Bucket, Key, Size, RawSize, RawHash, ContentType, Age, AttachmentsCount, AttachmentsSize"; + } + else + { + CSVWriter << "Namespace, Bucket, Key"; + } + for (const auto& NamespaceIt : ValueDetails.Namespaces) + { + const std::string& Namespace = NamespaceIt.first; + for (const auto& BucketIt : NamespaceIt.second.Buckets) + { + const std::string& Bucket = BucketIt.first; + for (const auto& ValueIt : BucketIt.second.Values) + { + if (AttachmentDetails) + { + for (const IoHash& Hash : ValueIt.second.Attachments) + { + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + CSVWriter << "\r\n" + << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << Hash.ToHexString() + << ", " << gsl::narrow<uint64_t>(Payload.GetSize()); + } + } + else if (Details) + { + std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>( + GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch()); + CSVWriter << "\r\n" + << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << ValueIt.second.Size << "," + << ValueIt.second.RawSize << "," << ValueIt.second.RawHash.ToHexString() << ", " + << ToString(ValueIt.second.ContentType) << ", " << (NowSeconds.count() - LastAccessedSeconds.count()) + << ", " << gsl::narrow<uint64_t>(ValueIt.second.Attachments.size()); + size_t AttachmentsSize = 0; + for (const IoHash& Hash : ValueIt.second.Attachments) + { + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + AttachmentsSize += Payload.GetSize(); + } + CSVWriter << ", " << gsl::narrow<uint64_t>(AttachmentsSize); + } + else + { + CSVWriter << "\r\n" << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString(); + } + } + } + } + return Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("namespaces"); + { + for (const auto& NamespaceIt : ValueDetails.Namespaces) + { + const std::string& Namespace = NamespaceIt.first; + Cbo.BeginObject(); + { + Cbo.AddString("name", Namespace); + Cbo.BeginArray("buckets"); + { + for (const auto& BucketIt : NamespaceIt.second.Buckets) + { + const std::string& Bucket = BucketIt.first; + Cbo.BeginObject(); + { + Cbo.AddString("name", Bucket); + Cbo.BeginArray("values"); + { + for (const auto& ValueIt : BucketIt.second.Values) + { + std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>( + GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch()); + Cbo.BeginObject(); + { + Cbo.AddHash("key", ValueIt.first); + if (Details) + { + Cbo.AddInteger("size", ValueIt.second.Size); + if (ValueIt.second.Size > 0 && ValueIt.second.RawSize != 0 && + ValueIt.second.RawSize != ValueIt.second.Size) + { + Cbo.AddInteger("rawsize", ValueIt.second.RawSize); + Cbo.AddHash("rawhash", ValueIt.second.RawHash); + } + Cbo.AddString("contenttype", ToString(ValueIt.second.ContentType)); + Cbo.AddInteger("age", NowSeconds.count() - LastAccessedSeconds.count()); + if (ValueIt.second.Attachments.size() > 0) + { + if (AttachmentDetails) + { + Cbo.BeginArray("attachments"); + { + for (const IoHash& Hash : ValueIt.second.Attachments) + { + Cbo.BeginObject(); + Cbo.AddHash("cid", Hash); + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + Cbo.AddInteger("size", gsl::narrow<uint64_t>(Payload.GetSize())); + Cbo.EndObject(); + } + } + Cbo.EndArray(); + } + else + { + Cbo.AddInteger("attachmentcount", + gsl::narrow<uint64_t>(ValueIt.second.Attachments.size())); + size_t AttachmentsSize = 0; + for (const IoHash& Hash : ValueIt.second.Attachments) + { + IoBuffer Payload = m_CidStore.FindChunkByCid(Hash); + AttachmentsSize += Payload.GetSize(); + } + Cbo.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize)); + } + } + } + } + Cbo.EndObject(); + } + } + Cbo.EndArray(); + } + Cbo.EndObject(); + } + } + Cbo.EndArray(); + } + Cbo.EndObject(); + } + } + Cbo.EndArray(); + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } +} + +void +HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request) +{ + metrics::OperationTiming::Scope $(m_HttpRequests); + + std::string_view Key = Request.RelativeUri(); + if (Key == HttpZCacheRPCPrefix) + { + return HandleRpcRequest(Request); + } + + if (Key == HttpZCacheUtilStartRecording) + { + m_RequestRecorder.reset(); + HttpServerRequest::QueryParams Params = Request.GetQueryParams(); + std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path"))); + m_RequestRecorder = cache::MakeDiskRequestRecorder(RecordPath); + Request.WriteResponse(HttpResponseCode::OK); + return; + } + if (Key == HttpZCacheUtilStopRecording) + { + m_RequestRecorder.reset(); + Request.WriteResponse(HttpResponseCode::OK); + return; + } + if (Key == HttpZCacheUtilReplayRecording) + { + m_RequestRecorder.reset(); + HttpServerRequest::QueryParams Params = Request.GetQueryParams(); + std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path"))); + uint32_t ThreadCount = std::thread::hardware_concurrency(); + if (auto Param = Params.GetValue("thread_count"); Param.empty() == false) + { + if (auto Value = ParseInt<uint64_t>(Param)) + { + ThreadCount = gsl::narrow<uint32_t>(Value.value()); + } + } + std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(RecordPath, false)); + ReplayRequestRecorder(*Replayer, ThreadCount < 1 ? 1 : ThreadCount); + Request.WriteResponse(HttpResponseCode::OK); + return; + } + if (Key.starts_with(HttpZCacheDetailsPrefix)) + { + HandleDetailsRequest(Request); + return; + } + + HttpRequestData RequestData; + if (!HttpRequestParseRelativeUri(Key, RequestData)) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL + } + + if (RequestData.ValueContentId.has_value()) + { + ZEN_ASSERT(RequestData.Namespace.has_value()); + ZEN_ASSERT(RequestData.Bucket.has_value()); + ZEN_ASSERT(RequestData.HashKey.has_value()); + CacheRef Ref = {.Namespace = RequestData.Namespace.value(), + .BucketSegment = RequestData.Bucket.value(), + .HashKey = RequestData.HashKey.value(), + .ValueContentId = RequestData.ValueContentId.value()}; + return HandleCacheChunkRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams())); + } + + if (RequestData.HashKey.has_value()) + { + ZEN_ASSERT(RequestData.Namespace.has_value()); + ZEN_ASSERT(RequestData.Bucket.has_value()); + CacheRef Ref = {.Namespace = RequestData.Namespace.value(), + .BucketSegment = RequestData.Bucket.value(), + .HashKey = RequestData.HashKey.value(), + .ValueContentId = IoHash::Zero}; + return HandleCacheRecordRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams())); + } + + if (RequestData.Bucket.has_value()) + { + ZEN_ASSERT(RequestData.Namespace.has_value()); + return HandleCacheBucketRequest(Request, RequestData.Namespace.value(), RequestData.Bucket.value()); + } + + if (RequestData.Namespace.has_value()) + { + return HandleCacheNamespaceRequest(Request, RequestData.Namespace.value()); + } + return HandleCacheRequest(Request); +} + +void +HttpStructuredCacheService::HandleCacheRequest(HttpServerRequest& Request) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + ZenCacheStore::Info Info = m_CacheStore.GetInfo(); + + CbObjectWriter ResponseWriter; + + ResponseWriter.BeginObject("Configuration"); + { + ExtendableStringBuilder<128> BasePathString; + BasePathString << Info.Config.BasePath.u8string(); + ResponseWriter.AddString("BasePath"sv, BasePathString.ToView()); + ResponseWriter.AddBool("AllowAutomaticCreationOfNamespaces", Info.Config.AllowAutomaticCreationOfNamespaces); + } + ResponseWriter.EndObject(); + + std::sort(begin(Info.NamespaceNames), end(Info.NamespaceNames), [](std::string_view L, std::string_view R) { + return L.compare(R) < 0; + }); + ResponseWriter.BeginArray("Namespaces"); + for (const std::string& NamespaceName : Info.NamespaceNames) + { + ResponseWriter.AddString(NamespaceName); + } + ResponseWriter.EndArray(); + ResponseWriter.BeginObject("StorageSize"); + { + ResponseWriter.AddInteger("DiskSize", Info.StorageSize.DiskSize); + ResponseWriter.AddInteger("MemorySize", Info.StorageSize.MemorySize); + } + + ResponseWriter.EndObject(); + + ResponseWriter.AddInteger("DiskEntryCount", Info.DiskEntryCount); + ResponseWriter.AddInteger("MemoryEntryCount", Info.MemoryEntryCount); + + return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + } + break; + } +} + +void +HttpStructuredCacheService::HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view NamespaceName) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + std::optional<ZenCacheNamespace::Info> Info = m_CacheStore.GetNamespaceInfo(NamespaceName); + if (!Info.has_value()) + { + return Request.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter ResponseWriter; + + ResponseWriter.BeginObject("Configuration"); + { + ExtendableStringBuilder<128> BasePathString; + BasePathString << Info->Config.RootDir.u8string(); + ResponseWriter.AddString("RootDir"sv, BasePathString.ToView()); + ResponseWriter.AddInteger("DiskLayerThreshold"sv, Info->Config.DiskLayerThreshold); + } + ResponseWriter.EndObject(); + + std::sort(begin(Info->BucketNames), end(Info->BucketNames), [](std::string_view L, std::string_view R) { + return L.compare(R) < 0; + }); + + ResponseWriter.BeginArray("Buckets"sv); + for (const std::string& BucketName : Info->BucketNames) + { + ResponseWriter.AddString(BucketName); + } + ResponseWriter.EndArray(); + + ResponseWriter.BeginObject("StorageSize"sv); + { + ResponseWriter.AddInteger("DiskSize"sv, Info->DiskLayerInfo.TotalSize); + ResponseWriter.AddInteger("MemorySize"sv, Info->MemoryLayerInfo.TotalSize); + } + ResponseWriter.EndObject(); + + ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount); + ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount); + + return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + } + break; + + case HttpVerb::kDelete: + // Drop namespace + { + if (m_CacheStore.DropNamespace(NamespaceName)) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + else + { + return Request.WriteResponse(HttpResponseCode::NotFound); + } + } + break; + + default: + break; + } +} + +void +HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request, + std::string_view NamespaceName, + std::string_view BucketName) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + std::optional<ZenCacheNamespace::BucketInfo> Info = m_CacheStore.GetBucketInfo(NamespaceName, BucketName); + if (!Info.has_value()) + { + return Request.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter ResponseWriter; + + ResponseWriter.BeginObject("StorageSize"); + { + ResponseWriter.AddInteger("DiskSize", Info->DiskLayerInfo.TotalSize); + ResponseWriter.AddInteger("MemorySize", Info->MemoryLayerInfo.TotalSize); + } + ResponseWriter.EndObject(); + + ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount); + ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount); + + return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save()); + } + break; + + case HttpVerb::kDelete: + // Drop bucket + { + if (m_CacheStore.DropBucket(NamespaceName, BucketName)) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + else + { + return Request.WriteResponse(HttpResponseCode::NotFound); + } + } + break; + + default: + break; + } +} + +void +HttpStructuredCacheService::HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + HandleGetCacheRecord(Request, Ref, PolicyFromUrl); + } + break; + + case HttpVerb::kPut: + HandlePutCacheRecord(Request, Ref, PolicyFromUrl); + break; + default: + break; + } +} + +void +HttpStructuredCacheService::HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + const ZenContentType AcceptType = Request.AcceptContentType(); + const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData); + const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord); + + bool Success = false; + ZenCacheValue ClientResultValue; + if (!EnumHasAnyFlags(PolicyFromUrl, CachePolicy::Query)) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + + Stopwatch Timer; + + if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal) && + m_CacheStore.Get(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue)) + { + Success = true; + ZenContentType ContentType = ClientResultValue.Value.GetContentType(); + + if (AcceptType == ZenContentType::kCbPackage) + { + if (ContentType == ZenContentType::kCbObject) + { + CbPackage Package; + uint32_t MissingCount = 0; + + CbObjectView CacheRecord(ClientResultValue.Value.Data()); + CacheRecord.IterateAttachments([this, &MissingCount, &Package, SkipData](CbFieldView AttachmentHash) { + if (SkipData) + { + if (!m_CidStore.ContainsChunk(AttachmentHash.AsHash())) + { + MissingCount++; + } + } + else + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(AttachmentHash.AsHash())) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk)); + Package.AddAttachment(CbAttachment(Compressed, AttachmentHash.AsHash())); + } + else + { + MissingCount++; + } + } + }); + + Success = MissingCount == 0 || PartialRecord; + + if (Success) + { + Package.SetObject(LoadCompactBinaryObject(ClientResultValue.Value)); + + BinaryWriter MemStream; + Package.Save(MemStream); + + ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + ClientResultValue.Value.SetContentType(HttpContentType::kCbPackage); + } + } + else + { + Success = false; + } + } + else if (AcceptType != ClientResultValue.Value.GetContentType() && AcceptType != ZenContentType::kUnknownContentType && + AcceptType != ZenContentType::kBinary) + { + Success = false; + } + } + + if (Success) + { + ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (LOCAL) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(ClientResultValue.Value.Size()), + ToString(ClientResultValue.Value.GetContentType()), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + m_CacheStats.HitCount++; + if (SkipData && AcceptType != ZenContentType::kCbPackage && AcceptType != ZenContentType::kCbObject) + { + return Request.WriteResponse(HttpResponseCode::OK); + } + else + { + // kCbPackage handled SkipData when constructing the ClientResultValue, kcbObject ignores SkipData + return Request.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value); + } + } + else if (!EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryRemote)) + { + ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + return Request.WriteResponse(HttpResponseCode::NotFound); + } + + // Issue upstream query asynchronously in order to keep requests flowing without + // hogging I/O servicing threads with blocking work + + uint64_t LocalElapsedTimeUs = Timer.GetElapsedTimeUs(); + + Request.WriteResponseAsync([this, AcceptType, PolicyFromUrl, Ref, LocalElapsedTimeUs](HttpServerRequest& AsyncRequest) { + Stopwatch Timer; + bool Success = false; + const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord); + const bool QueryLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal); + const bool StoreLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreLocal); + const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData); + ZenCacheValue ClientResultValue; + + metrics::OperationTiming::Scope $(m_UpstreamGetRequestTiming); + + if (GetUpstreamCacheSingleResult UpstreamResult = + m_UpstreamCache.GetCacheRecord(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, AcceptType); + UpstreamResult.Status.Success) + { + Success = true; + + ClientResultValue.Value = UpstreamResult.Value; + ClientResultValue.Value.SetContentType(AcceptType); + + if (AcceptType == ZenContentType::kBinary || AcceptType == ZenContentType::kCbObject) + { + if (AcceptType == ZenContentType::kCbObject) + { + const CbValidateError ValidationResult = ValidateCompactBinary(UpstreamResult.Value, CbValidateMode::All); + if (ValidationResult != CbValidateError::None) + { + Success = false; + ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid compact binary object from upstream", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType)); + } + + // We do not do anything to the returned object for SkipData, only package attachments are cut when skipping data + } + + if (Success && StoreLocal) + { + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue); + } + } + else if (AcceptType == ZenContentType::kCbPackage) + { + CbPackage Package; + if (Package.TryLoad(ClientResultValue.Value)) + { + CbObject CacheRecord = Package.GetObject(); + AttachmentCount Count; + size_t NumAttachments = Package.GetAttachments().size(); + std::vector<const CbAttachment*> AttachmentsToStoreLocally; + AttachmentsToStoreLocally.reserve(NumAttachments); + + CacheRecord.IterateAttachments( + [this, &Package, &Ref, &AttachmentsToStoreLocally, &Count, QueryLocal, StoreLocal, SkipData](CbFieldView HashView) { + IoHash Hash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package.FindAttachment(Hash)) + { + if (Attachment->IsCompressedBinary()) + { + if (StoreLocal) + { + AttachmentsToStoreLocally.emplace_back(Attachment); + } + Count.Valid++; + } + else + { + ZEN_WARN("Uncompressed value '{}' from upstream cache record '{}/{}'", + Hash, + Ref.BucketSegment, + Ref.HashKey); + Count.Invalid++; + } + } + else if (QueryLocal) + { + if (SkipData) + { + if (m_CidStore.ContainsChunk(Hash)) + { + Count.Valid++; + } + } + else if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Hash)) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk)); + if (Compressed) + { + Package.AddAttachment(CbAttachment(Compressed, Hash)); + Count.Valid++; + } + else + { + ZEN_WARN("Uncompressed value '{}' stored in local cache '{}/{}'", + Hash, + Ref.BucketSegment, + Ref.HashKey); + Count.Invalid++; + } + } + } + Count.Total++; + }); + + if ((Count.Valid == Count.Total) || PartialRecord) + { + ZenCacheValue CacheValue; + CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer(); + CacheValue.Value.SetContentType(ZenContentType::kCbObject); + + if (StoreLocal) + { + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue); + } + + for (const CbAttachment* Attachment : AttachmentsToStoreLocally) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + if (InsertResult.New) + { + Count.New++; + } + } + + BinaryWriter MemStream; + if (SkipData) + { + // Save a package containing only the object. + CbPackage(Package.GetObject()).Save(MemStream); + } + else + { + Package.Save(MemStream); + } + + ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + ClientResultValue.Value.SetContentType(ZenContentType::kCbPackage); + } + else + { + Success = false; + ZEN_WARN("Get - '{}/{}' '{}' FAILED, attachments missing in upstream package", + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType)); + } + } + else + { + Success = false; + ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid upstream package", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType)); + } + } + } + + if (Success) + { + ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (UPSTREAM) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(ClientResultValue.Value.Size()), + ToString(ClientResultValue.Value.GetContentType()), + NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000)); + + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount++; + + if (SkipData && AcceptType == ZenContentType::kBinary) + { + AsyncRequest.WriteResponse(HttpResponseCode::OK); + } + else + { + // Other methods modify ClientResultValue to a version that has skipped the data but keeps the Object and optionally + // metadata. + AsyncRequest.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value); + } + } + else + { + ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(AcceptType), + NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000)); + m_CacheStats.MissCount++; + AsyncRequest.WriteResponse(HttpResponseCode::NotFound); + } + }); +} + +void +HttpStructuredCacheService::HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + IoBuffer Body = Request.ReadPayload(); + + if (!Body || Body.Size() == 0) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + const HttpContentType ContentType = Request.RequestContentType(); + + Body.SetContentType(ContentType); + + Stopwatch Timer; + + if (ContentType == HttpContentType::kBinary || ContentType == HttpContentType::kCompressedBinary) + { + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = Body.GetSize(); + if (ContentType == HttpContentType::kCompressedBinary) + { + if (!CompressedBuffer::ValidateCompressedHeader(Body, RawHash, RawSize)) + { + return Request.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Payload is not a valid compressed binary"sv); + } + } + else + { + RawHash = IoHash::HashBuffer(SharedBuffer(Body)); + } + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body, .RawSize = RawSize, .RawHash = RawHash}); + + if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreRemote)) + { + m_UpstreamCache.EnqueueUpstream({.Type = ContentType, .Namespace = Ref.Namespace, .Key = {Ref.BucketSegment, Ref.HashKey}}); + } + + ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(Body.Size()), + ToString(ContentType), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + Request.WriteResponse(HttpResponseCode::Created); + } + else if (ContentType == HttpContentType::kCbObject) + { + const CbValidateError ValidationResult = ValidateCompactBinary(MemoryView(Body.GetData(), Body.GetSize()), CbValidateMode::All); + + if (ValidationResult != CbValidateError::None) + { + ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid compact binary", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(ContentType)); + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Compact binary validation failed"sv); + } + + Body.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body}); + + CbObjectView CacheRecord(Body.Data()); + std::vector<IoHash> ValidAttachments; + int32_t TotalCount = 0; + + CacheRecord.IterateAttachments([this, &TotalCount, &ValidAttachments](CbFieldView AttachmentHash) { + const IoHash Hash = AttachmentHash.AsHash(); + if (m_CidStore.ContainsChunk(Hash)) + { + ValidAttachments.emplace_back(Hash); + } + TotalCount++; + }); + + ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' attachments '{}/{}' (valid/total) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(Body.Size()), + ToString(ContentType), + TotalCount, + ValidAttachments.size(), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const bool IsPartialRecord = TotalCount != static_cast<int32_t>(ValidAttachments.size()); + + CachePolicy Policy = PolicyFromUrl; + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbObject, + .Namespace = Ref.Namespace, + .Key = {Ref.BucketSegment, Ref.HashKey}, + .ValueContentIds = std::move(ValidAttachments)}); + } + + Request.WriteResponse(HttpResponseCode::Created); + } + else if (ContentType == HttpContentType::kCbPackage) + { + CbPackage Package; + + if (!Package.TryLoad(Body)) + { + ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid package", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(ContentType)); + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"sv); + } + CachePolicy Policy = PolicyFromUrl; + + CbObject CacheRecord = Package.GetObject(); + + AttachmentCount Count; + size_t NumAttachments = Package.GetAttachments().size(); + std::vector<IoHash> ValidAttachments; + std::vector<const CbAttachment*> AttachmentsToStoreLocally; + ValidAttachments.reserve(NumAttachments); + AttachmentsToStoreLocally.reserve(NumAttachments); + + CacheRecord.IterateAttachments([this, &Ref, &Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count](CbFieldView HashView) { + const IoHash Hash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package.FindAttachment(Hash)) + { + if (Attachment->IsCompressedBinary()) + { + AttachmentsToStoreLocally.emplace_back(Attachment); + ValidAttachments.emplace_back(Hash); + Count.Valid++; + } + else + { + ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + ToString(HttpContentType::kCbPackage), + Hash); + Count.Invalid++; + } + } + else if (m_CidStore.ContainsChunk(Hash)) + { + ValidAttachments.emplace_back(Hash); + Count.Valid++; + } + Count.Total++; + }); + + if (Count.Invalid > 0) + { + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid attachment(s)"sv); + } + + ZenCacheValue CacheValue; + CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer(); + CacheValue.Value.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue); + + for (const CbAttachment* Attachment : AttachmentsToStoreLocally) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + if (InsertResult.New) + { + Count.New++; + } + } + + ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}', attachments '{}/{}/{}' (new/valid/total) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + NiceBytes(Body.GetSize()), + ToString(ContentType), + Count.New, + Count.Valid, + Count.Total, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const bool IsPartialRecord = Count.Valid != Count.Total; + + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage, + .Namespace = Ref.Namespace, + .Key = {Ref.BucketSegment, Ref.HashKey}, + .ValueContentIds = std::move(ValidAttachments)}); + } + + Request.WriteResponse(HttpResponseCode::Created); + } + else + { + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Content-Type invalid"sv); + } +} + +void +HttpStructuredCacheService::HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + HandleGetCacheChunk(Request, Ref, PolicyFromUrl); + break; + case HttpVerb::kPut: + HandlePutCacheChunk(Request, Ref, PolicyFromUrl); + break; + default: + break; + } +} + +void +HttpStructuredCacheService::HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + Stopwatch Timer; + + IoBuffer Value = m_CidStore.FindChunkByCid(Ref.ValueContentId); + const UpstreamEndpointInfo* Source = nullptr; + CachePolicy Policy = PolicyFromUrl; + { + const bool QueryUpstream = !Value && EnumHasAllFlags(Policy, CachePolicy::QueryRemote); + + if (QueryUpstream) + { + if (GetUpstreamCacheSingleResult UpstreamResult = + m_UpstreamCache.GetCacheChunk(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, Ref.ValueContentId); + UpstreamResult.Status.Success) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(UpstreamResult.Value, RawHash, RawSize)) + { + if (RawHash == Ref.ValueContentId) + { + m_CidStore.AddChunk(UpstreamResult.Value, RawHash); + Source = UpstreamResult.Source; + } + else + { + ZEN_WARN("got missmatching upstream cache value"); + } + } + else + { + ZEN_WARN("got uncompressed upstream cache value"); + } + } + } + } + + if (!Value) + { + ZEN_DEBUG("GETCACHECHUNK MISS - '{}/{}/{}/{}' '{}' in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + Ref.ValueContentId, + ToString(Request.AcceptContentType()), + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + return Request.WriteResponse(HttpResponseCode::NotFound); + } + + ZEN_DEBUG("GETCACHECHUNK HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + Ref.ValueContentId, + NiceBytes(Value.Size()), + ToString(Value.GetContentType()), + Source ? Source->Url : "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + m_CacheStats.HitCount++; + if (Source) + { + m_CacheStats.UpstreamHitCount++; + } + + if (EnumHasAllFlags(Policy, CachePolicy::SkipData)) + { + Request.WriteResponse(HttpResponseCode::OK); + } + else + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); + } +} + +void +HttpStructuredCacheService::HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl) +{ + // Note: Individual cacherecord values are not propagated upstream until a valid cache record has been stored + ZEN_UNUSED(PolicyFromUrl); + + Stopwatch Timer; + + IoBuffer Body = Request.ReadPayload(); + + if (!Body || Body.Size() == 0) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + Body.SetContentType(Request.RequestContentType()); + + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Body, RawHash, RawSize)) + { + return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Attachments must be compressed"sv); + } + + if (RawHash != Ref.ValueContentId) + { + return Request.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "ValueContentId does not match attachment hash"sv); + } + + CidStore::InsertResult Result = m_CidStore.AddChunk(Body, RawHash); + + ZEN_DEBUG("PUTCACHECHUNK - '{}/{}/{}/{}' {} '{}' ({}) in {}", + Ref.Namespace, + Ref.BucketSegment, + Ref.HashKey, + Ref.ValueContentId, + NiceBytes(Body.Size()), + ToString(Body.GetContentType()), + Result.New ? "NEW" : "OLD", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const HttpResponseCode ResponseCode = Result.New ? HttpResponseCode::Created : HttpResponseCode::OK; + + Request.WriteResponse(ResponseCode); +} + +CbPackage +HttpStructuredCacheService::HandleRpcRequest(const ZenContentType ContentType, + IoBuffer&& Body, + uint32_t& OutAcceptMagic, + RpcAcceptOptions& OutAcceptFlags, + int& OutTargetProcessId) +{ + CbPackage Package; + CbObjectView Object; + CbObject ObjectBuffer; + if (ContentType == ZenContentType::kCbObject) + { + ObjectBuffer = LoadCompactBinaryObject(std::move(Body)); + Object = ObjectBuffer; + } + else + { + Package = ParsePackageMessage(Body); + Object = Package.GetObject(); + } + OutAcceptMagic = Object["Accept"sv].AsUInt32(); + OutAcceptFlags = static_cast<RpcAcceptOptions>(Object["AcceptFlags"sv].AsUInt16(0u)); + OutTargetProcessId = Object["Pid"sv].AsInt32(0); + + const std::string_view Method = Object["Method"sv].AsString(); + + if (Method == "PutCacheRecords"sv) + { + return HandleRpcPutCacheRecords(Package); + } + else if (Method == "GetCacheRecords"sv) + { + return HandleRpcGetCacheRecords(Object); + } + else if (Method == "PutCacheValues"sv) + { + return HandleRpcPutCacheValues(Package); + } + else if (Method == "GetCacheValues"sv) + { + return HandleRpcGetCacheValues(Object); + } + else if (Method == "GetCacheChunks"sv) + { + return HandleRpcGetCacheChunks(Object); + } + return CbPackage{}; +} + +void +HttpStructuredCacheService::ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount) +{ + WorkerThreadPool WorkerPool(ThreadCount); + uint64_t RequestCount = Replayer.GetRequestCount(); + Stopwatch Timer; + auto _ = MakeGuard([&]() { ZEN_INFO("Replayed {} requests in {}", RequestCount, NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); }); + Latch JobLatch(RequestCount); + ZEN_INFO("Replaying {} requests", RequestCount); + for (uint64_t RequestIndex = 0; RequestIndex < RequestCount; ++RequestIndex) + { + WorkerPool.ScheduleWork([this, &JobLatch, &Replayer, RequestIndex]() { + IoBuffer Body; + std::pair<ZenContentType, ZenContentType> ContentType = Replayer.GetRequest(RequestIndex, Body); + if (Body) + { + uint32_t AcceptMagic = 0; + RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone; + int TargetPid = 0; + CbPackage RpcResult = HandleRpcRequest(ContentType.first, std::move(Body), AcceptMagic, AcceptFlags, TargetPid); + if (AcceptMagic == kCbPkgMagic) + { + FormatFlags Flags = FormatFlags::kDefault; + if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences)) + { + Flags |= FormatFlags::kAllowLocalReferences; + if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences)) + { + Flags |= FormatFlags::kDenyPartialLocalReferences; + } + } + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetPid); + ZEN_ASSERT(RpcResponseBuffer.GetSize() > 0); + } + else + { + BinaryWriter MemStream; + RpcResult.Save(MemStream); + IoBuffer RpcResponseBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()); + ZEN_ASSERT(RpcResponseBuffer.Size() > 0); + } + } + JobLatch.CountDown(); + }); + } + while (!JobLatch.Wait(10000)) + { + ZEN_INFO("Replayed {} of {} requests, elapsed {}", + RequestCount - JobLatch.Remaining(), + RequestCount, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + } +} + +void +HttpStructuredCacheService::HandleRpcRequest(HttpServerRequest& Request) +{ + switch (Request.RequestVerb()) + { + case HttpVerb::kPost: + { + const HttpContentType ContentType = Request.RequestContentType(); + const HttpContentType AcceptType = Request.AcceptContentType(); + + if ((ContentType != HttpContentType::kCbObject && ContentType != HttpContentType::kCbPackage) || + AcceptType != HttpContentType::kCbPackage) + { + return Request.WriteResponse(HttpResponseCode::BadRequest); + } + + Request.WriteResponseAsync( + [this, Body = Request.ReadPayload(), ContentType, AcceptType](HttpServerRequest& AsyncRequest) mutable { + std::uint64_t RequestIndex = + m_RequestRecorder ? m_RequestRecorder->RecordRequest(ContentType, AcceptType, Body) : ~0ull; + uint32_t AcceptMagic = 0; + RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone; + int TargetProcessId = 0; + CbPackage RpcResult = HandleRpcRequest(ContentType, std::move(Body), AcceptMagic, AcceptFlags, TargetProcessId); + if (RpcResult.IsNull()) + { + AsyncRequest.WriteResponse(HttpResponseCode::BadRequest); + return; + } + if (AcceptMagic == kCbPkgMagic) + { + FormatFlags Flags = FormatFlags::kDefault; + if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences)) + { + Flags |= FormatFlags::kAllowLocalReferences; + if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences)) + { + Flags |= FormatFlags::kDenyPartialLocalReferences; + } + } + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetProcessId); + if (RequestIndex != ~0ull) + { + ZEN_ASSERT(m_RequestRecorder); + m_RequestRecorder->RecordResponse(RequestIndex, HttpContentType::kCbPackage, RpcResponseBuffer); + } + AsyncRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer); + } + else + { + BinaryWriter MemStream; + RpcResult.Save(MemStream); + + if (RequestIndex != ~0ull) + { + ZEN_ASSERT(m_RequestRecorder); + m_RequestRecorder->RecordResponse(RequestIndex, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); + } + AsyncRequest.WriteResponse(HttpResponseCode::OK, + HttpContentType::kCbPackage, + IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize())); + } + }); + } + break; + default: + Request.WriteResponse(HttpResponseCode::BadRequest); + break; + } +} + +CbPackage +HttpStructuredCacheService::HandleRpcPutCacheRecords(const CbPackage& BatchRequest) +{ + ZEN_TRACE_CPU("Z$::RpcPutCacheRecords"); + CbObjectView BatchObject = BatchRequest.GetObject(); + ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheRecords"sv); + + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + CachePolicy DefaultPolicy; + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::vector<bool> Results; + for (CbFieldView RequestField : Params["Requests"sv]) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView RecordObject = RequestObject["Record"sv].AsObjectView(); + CbObjectView KeyView = RecordObject["Key"sv].AsObjectView(); + + CacheKey Key; + if (!GetRpcRequestCacheKey(KeyView, Key)) + { + return CbPackage{}; + } + CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + PutRequestData PutRequest{*Namespace, std::move(Key), RecordObject, std::move(Policy)}; + + PutResult Result = PutCacheRecord(PutRequest, &BatchRequest); + + if (Result == PutResult::Invalid) + { + return CbPackage{}; + } + Results.push_back(Result == PutResult::Success); + } + if (Results.empty()) + { + return CbPackage{}; + } + + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + CbPackage RpcResponse; + RpcResponse.SetObject(ResponseObject.Save()); + return RpcResponse; +} + +HttpStructuredCacheService::PutResult +HttpStructuredCacheService::PutCacheRecord(PutRequestData& Request, const CbPackage* Package) +{ + CbObjectView Record = Request.RecordObject; + uint64_t RecordObjectSize = Record.GetSize(); + uint64_t TransferredSize = RecordObjectSize; + + AttachmentCount Count; + size_t NumAttachments = Package->GetAttachments().size(); + std::vector<IoHash> ValidAttachments; + std::vector<const CbAttachment*> AttachmentsToStoreLocally; + ValidAttachments.reserve(NumAttachments); + AttachmentsToStoreLocally.reserve(NumAttachments); + + Stopwatch Timer; + + Request.RecordObject.IterateAttachments( + [this, &Request, Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count, &TransferredSize](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package ? Package->FindAttachment(ValueHash) : nullptr) + { + if (Attachment->IsCompressedBinary()) + { + AttachmentsToStoreLocally.emplace_back(Attachment); + ValidAttachments.emplace_back(ValueHash); + Count.Valid++; + } + else + { + ZEN_WARN("PUTCACEHRECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed", + Request.Namespace, + Request.Key.Bucket, + Request.Key.Hash, + ToString(HttpContentType::kCbPackage), + ValueHash); + Count.Invalid++; + } + } + else if (m_CidStore.ContainsChunk(ValueHash)) + { + ValidAttachments.emplace_back(ValueHash); + Count.Valid++; + } + Count.Total++; + }); + + if (Count.Invalid > 0) + { + return PutResult::Invalid; + } + + ZenCacheValue CacheValue; + CacheValue.Value = IoBuffer(Record.GetSize()); + Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize())); + CacheValue.Value.SetContentType(ZenContentType::kCbObject); + m_CacheStore.Put(Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue); + + for (const CbAttachment* Attachment : AttachmentsToStoreLocally) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + if (InsertResult.New) + { + Count.New++; + } + TransferredSize += Chunk.GetCompressedSize(); + } + + ZEN_DEBUG("PUTCACEHRECORD - '{}/{}/{}' {}, attachments '{}/{}/{}' (new/valid/total) in {}", + Request.Namespace, + Request.Key.Bucket, + Request.Key.Hash, + NiceBytes(TransferredSize), + Count.New, + Count.Valid, + Count.Total, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + + const bool IsPartialRecord = Count.Valid != Count.Total; + + if (EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreRemote) && !IsPartialRecord) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage, + .Namespace = Request.Namespace, + .Key = Request.Key, + .ValueContentIds = std::move(ValidAttachments)}); + } + return PutResult::Success; +} + +CbPackage +HttpStructuredCacheService::HandleRpcGetCacheRecords(CbObjectView RpcRequest) +{ + ZEN_TRACE_CPU("Z$::RpcGetCacheRecords"); + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + + struct ValueRequestData + { + Oid ValueId; + IoHash ContentId; + CompressedBuffer Payload; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool ReadFromUpstream = false; + }; + struct RecordRequestData + { + CacheKeyRequest Upstream; + CbObjectView RecordObject; + IoBuffer RecordCacheValue; + CacheRecordPolicy DownstreamPolicy; + std::vector<ValueRequestData> Values; + bool Complete = false; + const UpstreamEndpointInfo* Source = nullptr; + uint64_t ElapsedTimeUs; + }; + + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + std::vector<RecordRequestData> Requests; + std::vector<size_t> UpstreamIndexes; + CbArrayView RequestsArray = Params["Requests"sv].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + + auto ParseValues = [](RecordRequestData& Request) { + CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView(); + Request.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + Request.Values.push_back({ValueId, RawHash}); + Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId); + } + } + }; + + for (CbFieldView RequestField : RequestsArray) + { + Stopwatch Timer; + RecordRequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + CacheKey& Key = Request.Upstream.Key; + if (!GetRpcRequestCacheKey(KeyObject, Key)) + { + return CbPackage{}; + } + + Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy); + const CacheRecordPolicy& Policy = Request.DownstreamPolicy; + + ZenCacheValue CacheValue; + bool NeedUpstreamAttachment = false; + bool FoundLocalInvalid = false; + ZenCacheValue RecordCacheValue; + + if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) && + m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, RecordCacheValue)) + { + Request.RecordCacheValue = std::move(RecordCacheValue.Value); + if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject) + { + FoundLocalInvalid = true; + } + else + { + Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData()); + ParseValues(Request); + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) + { + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal)) + { + // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we + // didn't ask for it and thus the record is complete in its absence. + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + Value.Exists = true; + } + else + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + Request.Complete = false; + } + } + else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + if (m_CidStore.ContainsChunk(Value.ContentId)) + { + Value.Exists = true; + } + else + { + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + } + Request.Complete = false; + } + } + else + { + if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId)) + { + ZEN_ASSERT(Chunk.GetSize() > 0); + Value.Payload = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk)); + Value.Exists = true; + } + else + { + if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + NeedUpstreamAttachment = true; + Value.ReadFromUpstream = true; + } + Request.Complete = false; + } + } + } + } + } + if (!Request.Complete) + { + bool NeedUpstreamRecord = + !Request.RecordObject && !FoundLocalInvalid && EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote); + if (NeedUpstreamRecord || NeedUpstreamAttachment) + { + UpstreamIndexes.push_back(Requests.size() - 1); + } + } + Request.ElapsedTimeUs = Timer.GetElapsedTimeUs(); + } + if (Requests.empty()) + { + return CbPackage{}; + } + + if (!UpstreamIndexes.empty()) + { + std::vector<CacheKeyRequest*> UpstreamRequests; + UpstreamRequests.reserve(UpstreamIndexes.size()); + for (size_t Index : UpstreamIndexes) + { + RecordRequestData& Request = Requests[Index]; + UpstreamRequests.push_back(&Request.Upstream); + + if (Request.Values.size()) + { + // We will be returning the local object and know all the value Ids that exist in it + // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have. + CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta; + CacheRecordPolicyBuilder Builder(UpstreamBasePolicy); + for (ValueRequestData& Value : Request.Values) + { + CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy); + UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None; + Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy); + } + Request.Upstream.Policy = Builder.Build(); + } + else + { + // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants, + // and convert the CacheRecordPolicy to an upstream policy + Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream(); + } + } + + const auto OnCacheRecordGetComplete = [this, Namespace, &ParseValues](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + + RecordRequestData& Request = + *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream)); + Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + const CacheKey& Key = Request.Upstream.Key; + Stopwatch Timer; + auto TimeGuard = MakeGuard([&Timer, &Request]() { Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); }); + if (!Request.RecordObject) + { + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject); + Request.RecordObject = ObjectBuffer; + if (EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal)) + { + m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}}); + } + ParseValues(Request); + Request.Source = Params.Source; + } + + Request.Complete = true; + for (ValueRequestData& Value : Request.Values) + { + if (Value.Exists) + { + continue; + } + CachePolicy ValuePolicy = Value.DownstreamPolicy; + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote)) + { + Request.Complete = false; + continue; + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId)) + { + if (CompressedBuffer Compressed = Attachment->AsCompressedBinary()) + { + Request.Source = Params.Source; + Value.Exists = true; + if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal)) + { + m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash()); + } + if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Value.Payload = Compressed; + } + } + else + { + ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}/{}'", + Value.ContentId, + *Namespace, + Key.Bucket, + Key.Hash); + } + } + if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData)) + { + Request.Complete = false; + } + // Request.Complete does not need to be set to false for upstream SkipData attachments. + // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment + // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of + // any missing SkipData attachments. + } + Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } + }; + + m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete)); + } + + CbPackage ResponsePackage; + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"sv); + for (RecordRequestData& Request : Requests) + { + const CacheKey& Key = Request.Upstream.Key; + if (Request.Complete || + (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord))) + { + ResponseObject << Request.RecordObject; + for (ValueRequestData& Value : Request.Values) + { + if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload) + { + ResponsePackage.AddAttachment(CbAttachment(Value.Payload, Value.ContentId)); + } + } + + ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {}{} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(Request.RecordCacheValue.Size()), + Request.Complete ? ""sv : " (PARTIAL)"sv, + Request.Source ? Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount += Request.Source ? 1 : 0; + } + else + { + ResponseObject.AddNull(); + + if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("GETCACHERECORD DISABLEDQUERY - '{}/{}/{}' in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + } + else + { + ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}'{} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + Request.RecordObject ? ""sv : " (PARTIAL)"sv, + Request.Source ? Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.MissCount++; + } + } + } + ResponseObject.EndArray(); + ResponsePackage.SetObject(ResponseObject.Save()); + return ResponsePackage; +} + +CbPackage +HttpStructuredCacheService::HandleRpcPutCacheValues(const CbPackage& BatchRequest) +{ + CbObjectView BatchObject = BatchRequest.GetObject(); + CbObjectView Params = BatchObject["Params"sv].AsObjectView(); + + std::string_view PolicyText = Params["DefaultPolicy"].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + std::vector<bool> Results; + for (CbFieldView RequestField : Params["Requests"sv]) + { + Stopwatch Timer; + + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyView = RequestObject["Key"sv].AsObjectView(); + + CacheKey Key; + if (!GetRpcRequestCacheKey(KeyView, Key)) + { + return CbPackage{}; + } + + PolicyText = RequestObject["Policy"sv].AsString(); + CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment(); + uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64(); + bool Succeeded = false; + uint64_t TransferredSize = 0; + + if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash)) + { + if (Attachment->IsCompressedBinary()) + { + CompressedBuffer Chunk = Attachment->AsCompressedBinary(); + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + // TODO: Implement upstream puts of CacheValues with StoreLocal == false. + // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream. + Policy |= CachePolicy::StoreLocal; + } + + if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) + { + IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer(); + Value.SetContentType(ZenContentType::kCompressedBinary); + if (RawSize == 0) + { + RawSize = Chunk.DecodeRawSize(); + } + m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = Value, .RawSize = RawSize, .RawHash = RawHash}); + TransferredSize = Chunk.GetCompressedSize(); + } + Succeeded = true; + } + else + { + ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}/{}' FAILED, value is not compressed", *Namespace, Key.Bucket, Key.Hash, RawHash); + return CbPackage{}; + } + } + else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + ZenCacheValue ExistingValue; + if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, ExistingValue) && + IsCompressedBinary(ExistingValue.Value.GetContentType())) + { + Succeeded = true; + } + } + // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put. + // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream. + + if (Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCompressedBinary, .Namespace = *Namespace, .Key = Key}); + } + Results.push_back(Succeeded); + ZEN_DEBUG("PUTCACHEVALUES - '{}/{}/{}' {}, '{}' in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(TransferredSize), + Succeeded ? "Added"sv : "Invalid", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + } + if (Results.empty()) + { + return CbPackage{}; + } + + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (bool Value : Results) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + CbPackage RpcResponse; + RpcResponse.SetObject(ResponseObject.Save()); + + return RpcResponse; +} + +CbPackage +HttpStructuredCacheService::HandleRpcGetCacheValues(CbObjectView RpcRequest) +{ + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + std::string_view PolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default; + std::optional<std::string> Namespace = GetRpcRequestNamespace(Params); + if (!Namespace) + { + return CbPackage{}; + } + + struct RequestData + { + CacheKey Key; + CachePolicy Policy; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + CompressedBuffer Result; + }; + std::vector<RequestData> Requests; + + std::vector<size_t> RemoteRequestIndexes; + + for (CbFieldView RequestField : Params["Requests"sv]) + { + Stopwatch Timer; + + RequestData& Request = Requests.emplace_back(); + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + if (!GetRpcRequestCacheKey(KeyObject, Request.Key)) + { + return CbPackage{}; + } + + PolicyText = RequestObject["Policy"sv].AsString(); + Request.Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + + CacheKey& Key = Request.Key; + CachePolicy Policy = Request.Policy; + + ZenCacheValue CacheValue; + if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal)) + { + if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, CacheValue) && IsCompressedBinary(CacheValue.Value.GetContentType())) + { + Request.RawHash = CacheValue.RawHash; + Request.RawSize = CacheValue.RawSize; + Request.Result = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value)); + } + } + if (Request.Result) + { + ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + NiceBytes(Request.Result.GetCompressed().GetSize()), + "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.HitCount++; + } + else if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + RemoteRequestIndexes.push_back(Requests.size() - 1); + } + else if (!EnumHasAnyFlags(Policy, CachePolicy::Query)) + { + // If they requested no query, do not record this as a miss + ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}/{}'", *Namespace, Key.Bucket, Key.Hash); + } + else + { + ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}", + *Namespace, + Key.Bucket, + Key.Hash, + "LOCAL"sv, + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + } + } + + if (!RemoteRequestIndexes.empty()) + { + std::vector<CacheValueRequest> RequestedRecordsData; + std::vector<CacheValueRequest*> CacheValueRequests; + RequestedRecordsData.reserve(RemoteRequestIndexes.size()); + CacheValueRequests.reserve(RemoteRequestIndexes.size()); + for (size_t Index : RemoteRequestIndexes) + { + RequestData& Request = Requests[Index]; + RequestedRecordsData.push_back({.Key = {Request.Key.Bucket, Request.Key.Hash}, .Policy = ConvertToUpstream(Request.Policy)}); + CacheValueRequests.push_back(&RequestedRecordsData.back()); + } + Stopwatch Timer; + m_UpstreamCache.GetCacheValues( + *Namespace, + CacheValueRequests, + [this, Namespace, &RequestedRecordsData, &Requests, &RemoteRequestIndexes, &Timer](CacheValueGetCompleteParams&& Params) { + CacheValueRequest& ChunkRequest = Params.Request; + if (Params.RawHash != IoHash::Zero) + { + size_t RequestOffset = std::distance(RequestedRecordsData.data(), &ChunkRequest); + size_t RequestIndex = RemoteRequestIndexes[RequestOffset]; + RequestData& Request = Requests[RequestIndex]; + Request.RawHash = Params.RawHash; + Request.RawSize = Params.RawSize; + const bool HasData = IsCompressedBinary(Params.Value.GetContentType()); + const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData); + const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal); + const bool IsHit = SkipData || HasData; + if (IsHit) + { + if (HasData && !SkipData) + { + Request.Result = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value)); + } + + if (HasData && StoreData) + { + m_CacheStore.Put(*Namespace, + Request.Key.Bucket, + Request.Key.Hash, + ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash}); + } + + ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}", + *Namespace, + ChunkRequest.Key.Bucket, + ChunkRequest.Key.Hash, + NiceBytes(Request.Result.GetCompressed().GetSize()), + Params.Source ? Params.Source->Url : "UPSTREAM", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.HitCount++; + m_CacheStats.UpstreamHitCount++; + return; + } + } + ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}", + *Namespace, + ChunkRequest.Key.Bucket, + ChunkRequest.Key.Hash, + Params.Source ? Params.Source->Url : "UPSTREAM", + NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); + m_CacheStats.MissCount++; + }); + } + + if (Requests.empty()) + { + return CbPackage{}; + } + + CbPackage RpcResponse; + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"sv); + for (const RequestData& Request : Requests) + { + ResponseObject.BeginObject(); + { + const CompressedBuffer& Result = Request.Result; + if (Result) + { + ResponseObject.AddHash("RawHash"sv, Request.RawHash); + if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Result, Request.RawHash)); + } + else + { + ResponseObject.AddInteger("RawSize"sv, Request.RawSize); + } + } + else if (Request.RawHash != IoHash::Zero) + { + ResponseObject.AddHash("RawHash"sv, Request.RawHash); + ResponseObject.AddInteger("RawSize"sv, Request.RawSize); + } + } + ResponseObject.EndObject(); + } + ResponseObject.EndArray(); + + RpcResponse.SetObject(ResponseObject.Save()); + return RpcResponse; +} + +namespace cache::detail { + + struct RecordValue + { + Oid ValueId; + IoHash ContentId; + uint64_t RawSize; + }; + struct RecordBody + { + IoBuffer CacheValue; + std::vector<RecordValue> Values; + const UpstreamEndpointInfo* Source = nullptr; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool HasRequest = false; + bool ValuesRead = false; + }; + struct ChunkRequest + { + CacheChunkRequest* Key = nullptr; + RecordBody* Record = nullptr; + CompressedBuffer Value; + const UpstreamEndpointInfo* Source = nullptr; + uint64_t RawSize = 0; + uint64_t RequestedSize = 0; + uint64_t RequestedOffset = 0; + CachePolicy DownstreamPolicy; + bool Exists = false; + bool RawSizeKnown = false; + bool IsRecordRequest = false; + uint64_t ElapsedTimeUs = 0; + }; + +} // namespace cache::detail + +CbPackage +HttpStructuredCacheService::HandleRpcGetCacheChunks(CbObjectView RpcRequest) +{ + using namespace cache::detail; + + std::string Namespace; + std::vector<CacheKeyRequest> RecordKeys; // Data about a Record necessary to identify it to the upstream + std::vector<RecordBody> Records; // Scratch-space data about a Record when fulfilling RecordRequests + std::vector<CacheChunkRequest> RequestKeys; // Data about a ChunkRequest necessary to identify it to the upstream + std::vector<ChunkRequest> Requests; // Intermediate and result data about a ChunkRequest + std::vector<ChunkRequest*> RecordRequests; // The ChunkRequests that are requesting a subvalue from a Record Key + std::vector<ChunkRequest*> ValueRequests; // The ChunkRequests that are requesting a Value Key + std::vector<CacheChunkRequest*> UpstreamChunks; // ChunkRequests that we need to send to the upstream + + // Parse requests from the CompactBinary body of the RpcRequest and divide it into RecordRequests and ValueRequests + if (!ParseGetCacheChunksRequest(Namespace, RecordKeys, Records, RequestKeys, Requests, RecordRequests, ValueRequests, RpcRequest)) + { + return CbPackage{}; + } + + // For each Record request, load the Record if necessary to find the Chunk's ContentId, load its Payloads if we + // have it locally, and otherwise append a request for the payload to UpstreamChunks + GetLocalCacheRecords(Namespace, RecordKeys, Records, RecordRequests, UpstreamChunks); + + // For each Value request, load the Value if we have it locally and otherwise append a request for the payload to UpstreamChunks + GetLocalCacheValues(Namespace, ValueRequests, UpstreamChunks); + + // Call GetCacheChunks on the upstream for any payloads we do not have locally + GetUpstreamCacheChunks(Namespace, UpstreamChunks, RequestKeys, Requests); + + // Send the payload and descriptive data about each chunk to the client + return WriteGetCacheChunksResponse(Namespace, Requests); +} + +bool +HttpStructuredCacheService::ParseGetCacheChunksRequest(std::string& Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + CbObjectView RpcRequest) +{ + using namespace cache::detail; + + ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv); + + CbObjectView Params = RpcRequest["Params"sv].AsObjectView(); + std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString(); + CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default; + + std::optional<std::string> NamespaceText = GetRpcRequestNamespace(Params); + if (!NamespaceText) + { + ZEN_WARN("GetCacheChunks: Invalid namespace in ChunkRequest."); + return false; + } + Namespace = *NamespaceText; + + CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView(); + size_t NumRequests = static_cast<size_t>(ChunkRequestsArray.Num()); + + // Note that these reservations allow us to take pointers to the elements while populating them. If the reservation is removed, + // we will need to change the pointers to indexes to handle reallocations. + RecordKeys.reserve(NumRequests); + Records.reserve(NumRequests); + RequestKeys.reserve(NumRequests); + Requests.reserve(NumRequests); + RecordRequests.reserve(NumRequests); + ValueRequests.reserve(NumRequests); + + CacheKeyRequest* PreviousRecordKey = nullptr; + RecordBody* PreviousRecord = nullptr; + + for (CbFieldView RequestView : ChunkRequestsArray) + { + CbObjectView RequestObject = RequestView.AsObjectView(); + CacheChunkRequest& RequestKey = RequestKeys.emplace_back(); + ChunkRequest& Request = Requests.emplace_back(); + CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView(); + + Request.Key = &RequestKey; + if (!GetRpcRequestCacheKey(KeyObject, Request.Key->Key)) + { + ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest."); + return false; + } + + RequestKey.ChunkId = RequestObject["ChunkId"sv].AsHash(); + RequestKey.ValueId = RequestObject["ValueId"sv].AsObjectId(); + RequestKey.RawOffset = RequestObject["RawOffset"sv].AsUInt64(); + RequestKey.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX); + Request.RequestedSize = RequestKey.RawSize; + Request.RequestedOffset = RequestKey.RawOffset; + std::string_view PolicyText = RequestObject["Policy"sv].AsString(); + Request.DownstreamPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy; + Request.IsRecordRequest = (bool)RequestKey.ValueId; + + if (!Request.IsRecordRequest) + { + ValueRequests.push_back(&Request); + } + else + { + RecordRequests.push_back(&Request); + CacheKeyRequest* RecordKey = nullptr; + RecordBody* Record = nullptr; + + if (!PreviousRecordKey || PreviousRecordKey->Key < RequestKey.Key) + { + RecordKey = &RecordKeys.emplace_back(); + PreviousRecordKey = RecordKey; + Record = &Records.emplace_back(); + PreviousRecord = Record; + RecordKey->Key = RequestKey.Key; + } + else if (RequestKey.Key == PreviousRecordKey->Key) + { + RecordKey = PreviousRecordKey; + Record = PreviousRecord; + } + else + { + ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.", + RequestKey.Key.Bucket, + RequestKey.Key.Hash, + PreviousRecordKey->Key.Bucket, + PreviousRecordKey->Key.Hash); + return false; + } + Request.Record = Record; + if (RequestKey.ChunkId == RequestKey.ChunkId.Zero) + { + Record->DownstreamPolicy = + Record->HasRequest ? Union(Record->DownstreamPolicy, Request.DownstreamPolicy) : Request.DownstreamPolicy; + Record->HasRequest = true; + } + } + } + if (Requests.empty()) + { + return false; + } + return true; +} + +void +HttpStructuredCacheService::GetLocalCacheRecords(std::string_view Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks) +{ + using namespace cache::detail; + + std::vector<CacheKeyRequest*> UpstreamRecordRequests; + for (size_t RecordIndex = 0; RecordIndex < Records.size(); ++RecordIndex) + { + Stopwatch Timer; + CacheKeyRequest& RecordKey = RecordKeys[RecordIndex]; + RecordBody& Record = Records[RecordIndex]; + if (Record.HasRequest) + { + Record.DownstreamPolicy |= CachePolicy::SkipData | CachePolicy::SkipMeta; + + if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal)) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Namespace, RecordKey.Key.Bucket, RecordKey.Key.Hash, CacheValue)) + { + Record.Exists = true; + Record.CacheValue = std::move(CacheValue.Value); + } + } + if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryRemote)) + { + RecordKey.Policy = CacheRecordPolicy(ConvertToUpstream(Record.DownstreamPolicy)); + UpstreamRecordRequests.push_back(&RecordKey); + } + RecordRequests[RecordIndex]->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } + } + + if (!UpstreamRecordRequests.empty()) + { + const auto OnCacheRecordGetComplete = + [this, Namespace, &RecordKeys, &Records, &RecordRequests](CacheRecordGetCompleteParams&& Params) { + if (!Params.Record) + { + return; + } + CacheKeyRequest& RecordKey = Params.Request; + size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey); + RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + RecordBody& Record = Records[RecordIndex]; + + const CacheKey& Key = RecordKey.Key; + Record.Exists = true; + CbObject ObjectBuffer = CbObject::Clone(Params.Record); + Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer(); + Record.CacheValue.SetContentType(ZenContentType::kCbObject); + Record.Source = Params.Source; + + if (EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal)) + { + m_CacheStore.Put(Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue}); + } + }; + m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete)); + } + + std::vector<CacheChunkRequest*> UpstreamPayloadRequests; + for (ChunkRequest* Request : RecordRequests) + { + Stopwatch Timer; + if (Request->Key->ChunkId == IoHash::Zero) + { + // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId) + // is missing, parse the cache record and try to find the raw hash from the ValueId. + RecordBody& Record = *Request->Record; + if (!Record.ValuesRead) + { + Record.ValuesRead = true; + if (Record.CacheValue && Record.CacheValue.GetContentType() == ZenContentType::kCbObject) + { + CbObjectView RecordObject = CbObjectView(Record.CacheValue.GetData()); + CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView(); + Record.Values.reserve(ValuesArray.Num()); + for (CbFieldView ValueField : ValuesArray) + { + CbObjectView ValueObject = ValueField.AsObjectView(); + Oid ValueId = ValueObject["Id"sv].AsObjectId(); + CbFieldView RawHashField = ValueObject["RawHash"sv]; + IoHash RawHash = RawHashField.AsBinaryAttachment(); + if (ValueId && !RawHashField.HasError()) + { + Record.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()}); + } + } + } + } + + for (const RecordValue& Value : Record.Values) + { + if (Value.ValueId == Request->Key->ValueId) + { + Request->Key->ChunkId = Value.ContentId; + Request->RawSize = Value.RawSize; + Request->RawSizeKnown = true; + break; + } + } + } + + // Now load the ContentId from the local ContentIdStore or from the upstream + if (Request->Key->ChunkId != IoHash::Zero) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal)) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData) && Request->RawSizeKnown) + { + if (m_CidStore.ContainsChunk(Request->Key->ChunkId)) + { + Request->Exists = true; + } + } + else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Request->Key->ChunkId)) + { + if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData)) + { + Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(Payload)); + if (Request->Value) + { + Request->Exists = true; + Request->RawSizeKnown = false; + } + } + else + { + IoHash _; + if (CompressedBuffer::ValidateCompressedHeader(Payload, _, Request->RawSize)) + { + Request->Exists = true; + Request->RawSizeKnown = true; + } + } + } + } + if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote)) + { + Request->Key->Policy = ConvertToUpstream(Request->DownstreamPolicy); + OutUpstreamChunks.push_back(Request->Key); + } + } + Request->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } +} + +void +HttpStructuredCacheService::GetLocalCacheValues(std::string_view Namespace, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks) +{ + using namespace cache::detail; + + for (ChunkRequest* Request : ValueRequests) + { + Stopwatch Timer; + if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal)) + { + ZenCacheValue CacheValue; + if (m_CacheStore.Get(Namespace, Request->Key->Key.Bucket, Request->Key->Key.Hash, CacheValue)) + { + if (IsCompressedBinary(CacheValue.Value.GetContentType())) + { + Request->Key->ChunkId = CacheValue.RawHash; + Request->Exists = true; + Request->RawSize = CacheValue.RawSize; + Request->RawSizeKnown = true; + if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData)) + { + Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value)); + } + } + } + } + if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote)) + { + if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::StoreLocal)) + { + // Convert the Offset,Size request into a request for the entire value; we will need it all to be able to store it locally + Request->Key->RawOffset = 0; + Request->Key->RawSize = UINT64_MAX; + } + OutUpstreamChunks.push_back(Request->Key); + } + Request->ElapsedTimeUs += Timer.GetElapsedTimeUs(); + } +} + +void +HttpStructuredCacheService::GetUpstreamCacheChunks(std::string_view Namespace, + std::vector<CacheChunkRequest*>& UpstreamChunks, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests) +{ + using namespace cache::detail; + + if (!UpstreamChunks.empty()) + { + const auto OnCacheChunksGetComplete = [this, Namespace, &RequestKeys, &Requests](CacheChunkGetCompleteParams&& Params) { + if (Params.RawHash == Params.RawHash.Zero) + { + return; + } + + CacheChunkRequest& Key = Params.Request; + size_t RequestIndex = std::distance(RequestKeys.data(), &Key); + ChunkRequest& Request = Requests[RequestIndex]; + Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0); + if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) || + !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value)); + if (!Compressed) + { + return; + } + + if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal)) + { + if (Request.IsRecordRequest) + { + m_CidStore.AddChunk(Params.Value, Params.RawHash); + } + else + { + m_CacheStore.Put(Namespace, + Key.Key.Bucket, + Key.Key.Hash, + {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash}); + } + } + if (!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + Request.Value = std::move(Compressed); + } + } + Key.ChunkId = Params.RawHash; + Request.Exists = true; + Request.RawSize = Params.RawSize; + Request.RawSizeKnown = true; + Request.Source = Params.Source; + + m_CacheStats.UpstreamHitCount++; + }; + + m_UpstreamCache.GetCacheChunks(Namespace, UpstreamChunks, std::move(OnCacheChunksGetComplete)); + } +} + +CbPackage +HttpStructuredCacheService::WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests) +{ + using namespace cache::detail; + + CbPackage RpcResponse; + CbObjectWriter Writer; + + Writer.BeginArray("Result"sv); + for (ChunkRequest& Request : Requests) + { + Writer.BeginObject(); + { + if (Request.Exists) + { + Writer.AddHash("RawHash"sv, Request.Key->ChunkId); + if (Request.Value && !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData)) + { + RpcResponse.AddAttachment(CbAttachment(Request.Value, Request.Key->ChunkId)); + } + else + { + Writer.AddInteger("RawSize"sv, Request.RawSize); + } + + ZEN_DEBUG("GETCACHECHUNKS HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceBytes(Request.RawSize), + Request.IsRecordRequest ? "Record"sv : "Value"sv, + Request.Source ? Request.Source->Url : "LOCAL"sv, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.HitCount++; + } + else if (!EnumHasAnyFlags(Request.DownstreamPolicy, CachePolicy::Query)) + { + ZEN_DEBUG("GETCACHECHUNKS DISABLEDQUERY - '{}/{}/{}/{}' in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + } + else + { + ZEN_DEBUG("GETCACHECHUNKS MISS - '{}/{}/{}/{}' in {}", + Namespace, + Request.Key->Key.Bucket, + Request.Key->Key.Hash, + Request.Key->ValueId, + NiceLatencyNs(Request.ElapsedTimeUs * 1000)); + m_CacheStats.MissCount++; + } + } + Writer.EndObject(); + } + Writer.EndArray(); + + RpcResponse.SetObject(Writer.Save()); + return RpcResponse; +} + +void +HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + + EmitSnapshot("requests", m_HttpRequests, Cbo); + EmitSnapshot("upstream_gets", m_UpstreamGetRequestTiming, Cbo); + + const uint64_t HitCount = m_CacheStats.HitCount; + const uint64_t UpstreamHitCount = m_CacheStats.UpstreamHitCount; + const uint64_t MissCount = m_CacheStats.MissCount; + const uint64_t TotalCount = HitCount + MissCount; + + const CidStoreSize CidSize = m_CidStore.TotalSize(); + const GcStorageSize CacheSize = m_CacheStore.StorageSize(); + + Cbo.BeginObject("cache"); + { + Cbo.BeginObject("size"); + { + Cbo << "disk" << CacheSize.DiskSize; + Cbo << "memory" << CacheSize.MemorySize; + } + Cbo.EndObject(); + + Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); + Cbo << "hits" << HitCount << "misses" << MissCount; + Cbo << "hit_ratio" << (TotalCount > 0 ? (double(HitCount) / double(TotalCount)) : 0.0); + Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount; + Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0); + } + Cbo.EndObject(); + + Cbo.BeginObject("upstream"); + { + m_UpstreamCache.GetStatus(Cbo); + } + Cbo.EndObject(); + + Cbo.BeginObject("cid"); + { + Cbo.BeginObject("size"); + { + Cbo << "tiny" << CidSize.TinySize; + Cbo << "small" << CidSize.SmallSize; + Cbo << "large" << CidSize.LargeSize; + Cbo << "total" << CidSize.TotalSize; + } + Cbo.EndObject(); + } + Cbo.EndObject(); + + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +void +HttpStructuredCacheService::HandleStatusRequest(HttpServerRequest& Request) +{ + CbObjectWriter Cbo; + Cbo << "ok" << true; + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +#if ZEN_WITH_TESTS + +TEST_CASE("z$service.parse.relative.Uri") +{ + HttpRequestData RootRequest; + CHECK(HttpRequestParseRelativeUri("", RootRequest)); + CHECK(!RootRequest.Namespace.has_value()); + CHECK(!RootRequest.Bucket.has_value()); + CHECK(!RootRequest.HashKey.has_value()); + CHECK(!RootRequest.ValueContentId.has_value()); + + RootRequest = {}; + CHECK(HttpRequestParseRelativeUri("/", RootRequest)); + CHECK(!RootRequest.Namespace.has_value()); + CHECK(!RootRequest.Bucket.has_value()); + CHECK(!RootRequest.HashKey.has_value()); + CHECK(!RootRequest.ValueContentId.has_value()); + + HttpRequestData LegacyBucketRequestBecomesNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest)); + CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test"sv); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData LegacyHashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest)); + CHECK(LegacyHashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace); + CHECK(LegacyHashKeyRequest.Bucket == "test"sv); + CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(!LegacyHashKeyRequest.ValueContentId.has_value()); + + HttpRequestData LegacyValueContentIdRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + LegacyValueContentIdRequest)); + CHECK(LegacyValueContentIdRequest.Namespace == ZenCacheStore::DefaultNamespace); + CHECK(LegacyValueContentIdRequest.Bucket == "test"sv); + CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv)); + + HttpRequestData V2DefaultNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest)); + CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc"); + CHECK(!V2DefaultNamespaceRequest.Bucket.has_value()); + CHECK(!V2DefaultNamespaceRequest.HashKey.has_value()); + CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2NamespaceRequest; + CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest)); + CHECK(V2NamespaceRequest.Namespace == "nicenamespace"sv); + CHECK(!V2NamespaceRequest.Bucket.has_value()); + CHECK(!V2NamespaceRequest.HashKey.has_value()); + CHECK(!V2NamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithDefaultNamespace; + CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace)); + CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc"); + CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test"sv); + CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithNamespace; + CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace)); + CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace"sv); + CHECK(V2BucketRequestWithNamespace.Bucket == "test"sv); + CHECK(!V2BucketRequestWithNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value()); + + HttpRequestData V2HashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest)); + CHECK(V2HashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace); + CHECK(V2HashKeyRequest.Bucket == "test"); + CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(!V2HashKeyRequest.ValueContentId.has_value()); + + HttpRequestData V2ValueContentIdRequest; + CHECK( + HttpRequestParseRelativeUri("nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + V2ValueContentIdRequest)); + CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace"sv); + CHECK(V2ValueContentIdRequest.Bucket == "test"sv); + CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv)); + CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv)); + + HttpRequestData Invalid; + CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid)); + CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789", + Invalid)); +} + +#endif + +void +z$service_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenserver/cache/structuredcache.h b/src/zenserver/cache/structuredcache.h new file mode 100644 index 000000000..4e7b98ac9 --- /dev/null +++ b/src/zenserver/cache/structuredcache.h @@ -0,0 +1,187 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/stats.h> +#include <zenhttp/httpserver.h> + +#include "monitoring/httpstats.h" +#include "monitoring/httpstatus.h" + +#include <memory> +#include <vector> + +namespace spdlog { +class logger; +} + +namespace zen { + +struct CacheChunkRequest; +struct CacheKeyRequest; +class CidStore; +class CbObjectView; +struct PutRequestData; +class ScrubContext; +class UpstreamCache; +class ZenCacheStore; +enum class CachePolicy : uint32_t; +enum class RpcAcceptOptions : uint16_t; + +namespace cache { + class IRpcRequestReplayer; + class IRpcRequestRecorder; + namespace detail { + struct RecordBody; + struct ChunkRequest; + } // namespace detail +} // namespace cache + +/** + * Structured cache service. Imposes constraints on keys, supports blobs and + * structured values + * + * Keys are structured as: + * + * {BucketId}/{KeyHash} + * + * Where BucketId is a lower-case alphanumeric string, and KeyHash is a 40-character + * hexadecimal sequence. The hash value may be derived in any number of ways, it's + * up to the application to pick an approach. + * + * Values may be structured or unstructured. Structured values are encoded using Unreal + * Engine's compact binary encoding (see CbObject) + * + * Additionally, attachments may be addressed as: + * + * {BucketId}/{KeyHash}/{ValueHash} + * + * Where the two initial components are the same as for the main endpoint + * + * The storage strategy is as follows: + * + * - Structured values are stored in a dedicated backing store per bucket + * - Unstructured values and attachments are stored in the CAS pool + * + */ + +class HttpStructuredCacheService : public HttpService, public IHttpStatsProvider, public IHttpStatusProvider +{ +public: + HttpStructuredCacheService(ZenCacheStore& InCacheStore, + CidStore& InCidStore, + HttpStatsService& StatsService, + HttpStatusService& StatusService, + UpstreamCache& UpstreamCache); + ~HttpStructuredCacheService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + void Flush(); + void Scrub(ScrubContext& Ctx); + +private: + struct CacheRef + { + std::string Namespace; + std::string BucketSegment; + IoHash HashKey; + IoHash ValueContentId; + }; + + struct CacheStats + { + std::atomic_uint64_t HitCount{}; + std::atomic_uint64_t UpstreamHitCount{}; + std::atomic_uint64_t MissCount{}; + }; + enum class PutResult + { + Success, + Fail, + Invalid, + }; + + void HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl); + void HandleRpcRequest(HttpServerRequest& Request); + void HandleDetailsRequest(HttpServerRequest& Request); + + CbPackage HandleRpcPutCacheRecords(const CbPackage& BatchRequest); + CbPackage HandleRpcGetCacheRecords(CbObjectView BatchRequest); + CbPackage HandleRpcPutCacheValues(const CbPackage& BatchRequest); + CbPackage HandleRpcGetCacheValues(CbObjectView BatchRequest); + CbPackage HandleRpcGetCacheChunks(CbObjectView BatchRequest); + CbPackage HandleRpcRequest(const ZenContentType ContentType, + IoBuffer&& Body, + uint32_t& OutAcceptMagic, + RpcAcceptOptions& OutAcceptFlags, + int& OutTargetProcessId); + + void HandleCacheRequest(HttpServerRequest& Request); + void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace); + void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket); + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + virtual void HandleStatusRequest(HttpServerRequest& Request) override; + PutResult PutCacheRecord(PutRequestData& Request, const CbPackage* Package); + + /** HandleRpcGetCacheChunks Helper: Parse the Body object into RecordValue Requests and Value Requests. */ + bool ParseGetCacheChunksRequest(std::string& Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + CbObjectView RpcRequest); + /** HandleRpcGetCacheChunks Helper: Load records to get ContentId for RecordRequests, and load their payloads if they exist locally. */ + void GetLocalCacheRecords(std::string_view Namespace, + std::vector<CacheKeyRequest>& RecordKeys, + std::vector<cache::detail::RecordBody>& Records, + std::vector<cache::detail::ChunkRequest*>& RecordRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks); + /** HandleRpcGetCacheChunks Helper: For ValueRequests, load their payloads if they exist locally. */ + void GetLocalCacheValues(std::string_view Namespace, + std::vector<cache::detail::ChunkRequest*>& ValueRequests, + std::vector<CacheChunkRequest*>& OutUpstreamChunks); + /** HandleRpcGetCacheChunks Helper: Load payloads from upstream that did not exist locally. */ + void GetUpstreamCacheChunks(std::string_view Namespace, + std::vector<CacheChunkRequest*>& UpstreamChunks, + std::vector<CacheChunkRequest>& RequestKeys, + std::vector<cache::detail::ChunkRequest>& Requests); + /** HandleRpcGetCacheChunks Helper: Send response message containing all chunk results. */ + CbPackage WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests); + + spdlog::logger& Log() { return m_Log; } + spdlog::logger& m_Log; + ZenCacheStore& m_CacheStore; + HttpStatsService& m_StatsService; + HttpStatusService& m_StatusService; + CidStore& m_CidStore; + UpstreamCache& m_UpstreamCache; + uint64_t m_LastScrubTime = 0; + metrics::OperationTiming m_HttpRequests; + metrics::OperationTiming m_UpstreamGetRequestTiming; + CacheStats m_CacheStats; + + void ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount); + + std::unique_ptr<cache::IRpcRequestRecorder> m_RequestRecorder; +}; + +/** Recognize both kBinary and kCompressedBinary as kCompressedBinary for structured cache value keys. + * We need this until the content type is preserved for kCompressedBinary when passing to and from upstream servers. */ +inline bool +IsCompressedBinary(ZenContentType Type) +{ + return Type == ZenContentType::kBinary || Type == ZenContentType::kCompressedBinary; +} + +void z$service_forcelink(); + +} // namespace zen diff --git a/src/zenserver/cache/structuredcachestore.cpp b/src/zenserver/cache/structuredcachestore.cpp new file mode 100644 index 000000000..26e970073 --- /dev/null +++ b/src/zenserver/cache/structuredcachestore.cpp @@ -0,0 +1,3648 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "structuredcachestore.h" + +#include <zencore/except.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenstore/cidstore.h> +#include <zenstore/scrubcontext.h> + +#include <xxhash.h> + +#include <limits> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/core.h> +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zencore/workthreadpool.h> +# include <random> +#endif + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +namespace { + +#pragma pack(push) +#pragma pack(1) + + struct CacheBucketIndexHeader + { + static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx'; + static constexpr uint32_t Version2 = 2; + static constexpr uint32_t CurrentVersion = Version2; + + uint32_t Magic = ExpectedMagic; + uint32_t Version = CurrentVersion; + uint64_t EntryCount = 0; + uint64_t LogPosition = 0; + uint32_t PayloadAlignment = 0; + uint32_t Checksum = 0; + + static uint32_t ComputeChecksum(const CacheBucketIndexHeader& Header) + { + return XXH32(&Header.Magic, sizeof(CacheBucketIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA); + } + }; + + static_assert(sizeof(CacheBucketIndexHeader) == 32); + +#pragma pack(pop) + + const char* IndexExtension = ".uidx"; + const char* LogExtension = ".slog"; + + std::filesystem::path GetIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName) + { + return BucketDir / (BucketName + IndexExtension); + } + + std::filesystem::path GetTempIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName) + { + return BucketDir / (BucketName + ".tmp"); + } + + std::filesystem::path GetLogPath(const std::filesystem::path& BucketDir, const std::string& BucketName) + { + return BucketDir / (BucketName + LogExtension); + } + + bool ValidateEntry(const DiskIndexEntry& Entry, std::string& OutReason) + { + if (Entry.Key == IoHash::Zero) + { + OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString()); + return false; + } + if (Entry.Location.GetFlags() & + ~(DiskLocation::kStandaloneFile | DiskLocation::kStructured | DiskLocation::kTombStone | DiskLocation::kCompressed)) + { + OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Location.GetFlags(), Entry.Key.ToHexString()); + return false; + } + if (Entry.Location.IsFlagSet(DiskLocation::kTombStone)) + { + return true; + } + if (Entry.Location.Reserved != 0) + { + OutReason = fmt::format("Invalid reserved field {} for entry {}", Entry.Location.Reserved, Entry.Key.ToHexString()); + return false; + } + uint64_t Size = Entry.Location.Size(); + if (Size == 0) + { + OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString()); + return false; + } + return true; + } + + bool MoveAndDeleteDirectory(const std::filesystem::path& Dir) + { + int DropIndex = 0; + do + { + if (!std::filesystem::exists(Dir)) + { + return false; + } + + std::string DroppedName = fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex); + std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName; + if (std::filesystem::exists(DroppedBucketPath)) + { + DropIndex++; + continue; + } + + std::error_code Ec; + std::filesystem::rename(Dir, DroppedBucketPath, Ec); + if (!Ec) + { + DeleteDirectories(DroppedBucketPath); + return true; + } + // TODO: Do we need to bail at some point? + zen::Sleep(100); + } while (true); + } + +} // namespace + +namespace fs = std::filesystem; + +static CbObject +LoadCompactBinaryObject(const fs::path& Path) +{ + FileContents Result = ReadFile(Path); + + if (!Result.ErrorCode) + { + IoBuffer Buffer = Result.Flatten(); + if (CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); Error == CbValidateError::None) + { + return LoadCompactBinaryObject(Buffer); + } + } + + return CbObject(); +} + +static void +SaveCompactBinaryObject(const fs::path& Path, const CbObject& Object) +{ + WriteFile(Path, Object.GetBuffer().AsIoBuffer()); +} + +ZenCacheNamespace::ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir) +: GcStorage(Gc) +, GcContributor(Gc) +, m_RootDir(RootDir) +, m_DiskLayer(RootDir) +{ + ZEN_INFO("initializing structured cache at '{}'", RootDir); + CreateDirectories(RootDir); + + m_DiskLayer.DiscoverBuckets(); + +#if ZEN_USE_CACHE_TRACKER + m_AccessTracker.reset(new ZenCacheTracker(RootDir)); +#endif +} + +ZenCacheNamespace::~ZenCacheNamespace() +{ +} + +bool +ZenCacheNamespace::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue) +{ + ZEN_TRACE_CPU("Z$::Get"); + + bool Ok = m_MemLayer.Get(InBucket, HashKey, OutValue); + +#if ZEN_USE_CACHE_TRACKER + auto _ = MakeGuard([&] { + if (!Ok) + return; + + m_AccessTracker->TrackAccess(InBucket, HashKey); + }); +#endif + + if (Ok) + { + ZEN_ASSERT(OutValue.Value.Size()); + + return true; + } + + Ok = m_DiskLayer.Get(InBucket, HashKey, OutValue); + + if (Ok) + { + ZEN_ASSERT(OutValue.Value.Size()); + + if (OutValue.Value.Size() <= m_DiskLayerSizeThreshold) + { + m_MemLayer.Put(InBucket, HashKey, OutValue); + } + } + + return Ok; +} + +void +ZenCacheNamespace::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + ZEN_TRACE_CPU("Z$::Put"); + + // Store value and index + + ZEN_ASSERT(Value.Value.Size()); + + m_DiskLayer.Put(InBucket, HashKey, Value); + +#if ZEN_USE_REF_TRACKING + if (Value.Value.GetContentType() == ZenContentType::kCbObject) + { + if (ValidateCompactBinary(Value.Value, CbValidateMode::All) == CbValidateError::None) + { + CbObject Object{SharedBuffer(Value.Value)}; + + uint8_t TempBuffer[8 * sizeof(IoHash)]; + std::pmr::monotonic_buffer_resource Linear{TempBuffer, sizeof TempBuffer}; + std::pmr::polymorphic_allocator Allocator{&Linear}; + std::pmr::vector<IoHash> CidReferences{Allocator}; + + Object.IterateAttachments([&](CbFieldView Field) { CidReferences.push_back(Field.AsAttachment()); }); + + m_Gc.OnNewCidReferences(CidReferences); + } + } +#endif + + if (Value.Value.Size() <= m_DiskLayerSizeThreshold) + { + m_MemLayer.Put(InBucket, HashKey, Value); + } +} + +bool +ZenCacheNamespace::DropBucket(std::string_view Bucket) +{ + ZEN_INFO("dropping bucket '{}'", Bucket); + + // TODO: should ensure this is done atomically across all layers + + const bool MemDropped = m_MemLayer.DropBucket(Bucket); + const bool DiskDropped = m_DiskLayer.DropBucket(Bucket); + const bool AnyDropped = MemDropped || DiskDropped; + + ZEN_INFO("bucket '{}' was {}", Bucket, AnyDropped ? "dropped" : "not found"); + + return AnyDropped; +} + +bool +ZenCacheNamespace::Drop() +{ + m_MemLayer.Drop(); + return m_DiskLayer.Drop(); +} + +void +ZenCacheNamespace::Flush() +{ + m_DiskLayer.Flush(); +} + +void +ZenCacheNamespace::Scrub(ScrubContext& Ctx) +{ + if (m_LastScrubTime == Ctx.ScrubTimestamp()) + { + return; + } + + m_LastScrubTime = Ctx.ScrubTimestamp(); + + m_DiskLayer.Scrub(Ctx); + m_MemLayer.Scrub(Ctx); +} + +void +ZenCacheNamespace::GatherReferences(GcContext& GcCtx) +{ + Stopwatch Timer; + const auto Guard = + MakeGuard([&] { ZEN_DEBUG("cache gathered all references from '{}' in {}", m_RootDir, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); }); + + access_tracking::AccessTimes AccessTimes; + m_MemLayer.GatherAccessTimes(AccessTimes); + + m_DiskLayer.UpdateAccessTimes(AccessTimes); + m_DiskLayer.GatherReferences(GcCtx); +} + +void +ZenCacheNamespace::CollectGarbage(GcContext& GcCtx) +{ + m_MemLayer.Reset(); + m_DiskLayer.CollectGarbage(GcCtx); +} + +GcStorageSize +ZenCacheNamespace::StorageSize() const +{ + return {.DiskSize = m_DiskLayer.TotalSize(), .MemorySize = m_MemLayer.TotalSize()}; +} + +ZenCacheNamespace::Info +ZenCacheNamespace::GetInfo() const +{ + ZenCacheNamespace::Info Info = {.Config = {.RootDir = m_RootDir, .DiskLayerThreshold = m_DiskLayerSizeThreshold}, + .DiskLayerInfo = m_DiskLayer.GetInfo(), + .MemoryLayerInfo = m_MemLayer.GetInfo()}; + std::unordered_set<std::string> BucketNames; + for (const std::string& BucketName : Info.DiskLayerInfo.BucketNames) + { + BucketNames.insert(BucketName); + } + for (const std::string& BucketName : Info.MemoryLayerInfo.BucketNames) + { + BucketNames.insert(BucketName); + } + Info.BucketNames.insert(Info.BucketNames.end(), BucketNames.begin(), BucketNames.end()); + return Info; +} + +std::optional<ZenCacheNamespace::BucketInfo> +ZenCacheNamespace::GetBucketInfo(std::string_view Bucket) const +{ + std::optional<ZenCacheDiskLayer::BucketInfo> DiskBucketInfo = m_DiskLayer.GetBucketInfo(Bucket); + if (!DiskBucketInfo.has_value()) + { + return {}; + } + ZenCacheNamespace::BucketInfo Info = {.DiskLayerInfo = *DiskBucketInfo, + .MemoryLayerInfo = m_MemLayer.GetBucketInfo(Bucket).value_or(ZenCacheMemoryLayer::BucketInfo{})}; + return Info; +} + +CacheValueDetails::NamespaceDetails +ZenCacheNamespace::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const +{ + return m_DiskLayer.GetValueDetails(BucketFilter, ValueFilter); +} + +////////////////////////////////////////////////////////////////////////// + +ZenCacheMemoryLayer::ZenCacheMemoryLayer() +{ +} + +ZenCacheMemoryLayer::~ZenCacheMemoryLayer() +{ +} + +bool +ZenCacheMemoryLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue) +{ + RwLock::SharedLockScope _(m_Lock); + + auto It = m_Buckets.find(std::string(InBucket)); + + if (It == m_Buckets.end()) + { + return false; + } + + CacheBucket* Bucket = It->second.get(); + + _.ReleaseNow(); + + // There's a race here. Since the lock is released early to allow + // inserts, the bucket delete path could end up deleting the + // underlying data structure + + return Bucket->Get(HashKey, OutValue); +} + +void +ZenCacheMemoryLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + const auto BucketName = std::string(InBucket); + CacheBucket* Bucket = nullptr; + + { + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + } + + if (Bucket == nullptr) + { + // New bucket + + RwLock::ExclusiveLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + else + { + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>()); + Bucket = InsertResult.first->second.get(); + } + } + + // Note that since the underlying IoBuffer is retained, the content type is also + + Bucket->Put(HashKey, Value); +} + +bool +ZenCacheMemoryLayer::DropBucket(std::string_view InBucket) +{ + RwLock::ExclusiveLockScope _(m_Lock); + + auto It = m_Buckets.find(std::string(InBucket)); + + if (It != m_Buckets.end()) + { + CacheBucket& Bucket = *It->second; + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It); + Bucket.Drop(); + return true; + } + return false; +} + +void +ZenCacheMemoryLayer::Drop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + std::vector<std::unique_ptr<CacheBucket>> Buckets; + Buckets.reserve(m_Buckets.size()); + while (!m_Buckets.empty()) + { + const auto& It = m_Buckets.begin(); + CacheBucket& Bucket = *It->second; + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It->first); + Bucket.Drop(); + } +} + +void +ZenCacheMemoryLayer::Scrub(ScrubContext& Ctx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + Kv.second->Scrub(Ctx); + } +} + +void +ZenCacheMemoryLayer::GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes) +{ + using namespace zen::access_tracking; + + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + std::vector<KeyAccessTime>& Bucket = AccessTimes.Buckets[Kv.first]; + Kv.second->GatherAccessTimes(Bucket); + } +} + +void +ZenCacheMemoryLayer::Reset() +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Buckets.clear(); +} + +uint64_t +ZenCacheMemoryLayer::TotalSize() const +{ + uint64_t TotalSize{}; + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + TotalSize += Kv.second->TotalSize(); + } + + return TotalSize; +} + +ZenCacheMemoryLayer::Info +ZenCacheMemoryLayer::GetInfo() const +{ + ZenCacheMemoryLayer::Info Info = {.Config = m_Configuration, .TotalSize = TotalSize()}; + + RwLock::SharedLockScope _(m_Lock); + Info.BucketNames.reserve(m_Buckets.size()); + for (auto& Kv : m_Buckets) + { + Info.BucketNames.push_back(Kv.first); + Info.EntryCount += Kv.second->EntryCount(); + } + return Info; +} + +std::optional<ZenCacheMemoryLayer::BucketInfo> +ZenCacheMemoryLayer::GetBucketInfo(std::string_view Bucket) const +{ + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end()) + { + return ZenCacheMemoryLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()}; + } + return {}; +} + +void +ZenCacheMemoryLayer::CacheBucket::Scrub(ScrubContext& Ctx) +{ + RwLock::SharedLockScope _(m_BucketLock); + + std::vector<IoHash> BadHashes; + + auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) { + if (ContentType == ZenContentType::kCbObject) + { + CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); + return Error == CbValidateError::None; + } + if (ContentType == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize)) + { + return false; + } + if (Hash != RawHash) + { + return false; + } + } + return true; + }; + + for (auto& Kv : m_CacheMap) + { + const BucketPayload& Payload = m_Payloads[Kv.second]; + if (!ValidateEntry(Kv.first, Payload.Payload.GetContentType(), Payload.Payload)) + { + BadHashes.push_back(Kv.first); + } + } + + if (!BadHashes.empty()) + { + Ctx.ReportBadCidChunks(BadHashes); + } +} + +void +ZenCacheMemoryLayer::CacheBucket::GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes) +{ + RwLock::SharedLockScope _(m_BucketLock); + std::transform(m_CacheMap.begin(), m_CacheMap.end(), std::back_inserter(AccessTimes), [this](const auto& Kv) { + return access_tracking::KeyAccessTime{.Key = Kv.first, .LastAccess = m_AccessTimes[Kv.second]}; + }); +} + +bool +ZenCacheMemoryLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue) +{ + RwLock::SharedLockScope _(m_BucketLock); + + if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end()) + { + uint32_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size()); + ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size()); + + const BucketPayload& Payload = m_Payloads[EntryIndex]; + OutValue = {.Value = Payload.Payload, .RawSize = Payload.RawSize, .RawHash = Payload.RawHash}; + m_AccessTimes[EntryIndex] = GcClock::TickCount(); + + return true; + } + + return false; +} + +void +ZenCacheMemoryLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value) +{ + size_t PayloadSize = Value.Value.GetSize(); + { + GcClock::Tick AccessTime = GcClock::TickCount(); + RwLock::ExclusiveLockScope _(m_BucketLock); + if (m_CacheMap.size() == std::numeric_limits<uint32_t>::max()) + { + // No more space in our memory cache! + return; + } + if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end()) + { + uint32_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size()); + + m_TotalSize.fetch_sub(PayloadSize, std::memory_order::relaxed); + BucketPayload& Payload = m_Payloads[EntryIndex]; + Payload.Payload = Value.Value; + Payload.RawHash = Value.RawHash; + Payload.RawSize = gsl::narrow<uint32_t>(Value.RawSize); + m_AccessTimes[EntryIndex] = AccessTime; + } + else + { + uint32_t EntryIndex = gsl::narrow<uint32_t>(m_Payloads.size()); + m_Payloads.emplace_back( + BucketPayload{.Payload = Value.Value, .RawSize = gsl::narrow<uint32_t>(Value.RawSize), .RawHash = Value.RawHash}); + m_AccessTimes.emplace_back(AccessTime); + m_CacheMap.insert_or_assign(HashKey, EntryIndex); + } + ZEN_ASSERT_SLOW(m_Payloads.size() == m_CacheMap.size()); + ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size()); + } + + m_TotalSize.fetch_add(PayloadSize, std::memory_order::relaxed); +} + +void +ZenCacheMemoryLayer::CacheBucket::Drop() +{ + RwLock::ExclusiveLockScope _(m_BucketLock); + m_CacheMap.clear(); + m_AccessTimes.clear(); + m_Payloads.clear(); + m_TotalSize.store(0); +} + +uint64_t +ZenCacheMemoryLayer::CacheBucket::EntryCount() const +{ + RwLock::SharedLockScope _(m_BucketLock); + return static_cast<uint64_t>(m_CacheMap.size()); +} + +////////////////////////////////////////////////////////////////////////// + +ZenCacheDiskLayer::CacheBucket::CacheBucket(std::string BucketName) : m_BucketName(std::move(BucketName)), m_BucketId(Oid::Zero) +{ +} + +ZenCacheDiskLayer::CacheBucket::~CacheBucket() +{ +} + +bool +ZenCacheDiskLayer::CacheBucket::OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate) +{ + using namespace std::literals; + + m_BlocksBasePath = BucketDir / "blocks"; + m_BucketDir = BucketDir; + + CreateDirectories(m_BucketDir); + + std::filesystem::path ManifestPath{m_BucketDir / "zen_manifest"}; + + bool IsNew = false; + + CbObject Manifest = LoadCompactBinaryObject(ManifestPath); + + if (Manifest) + { + m_BucketId = Manifest["BucketId"sv].AsObjectId(); + if (m_BucketId == Oid::Zero) + { + return false; + } + } + else if (AllowCreate) + { + m_BucketId.Generate(); + + CbObjectWriter Writer; + Writer << "BucketId"sv << m_BucketId; + Manifest = Writer.Save(); + SaveCompactBinaryObject(ManifestPath, Manifest); + IsNew = true; + } + else + { + return false; + } + + OpenLog(IsNew); + + if (!IsNew) + { + Stopwatch Timer; + const auto _ = + MakeGuard([&] { ZEN_INFO("read store manifest '{}' in {}", ManifestPath, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); }); + + for (CbFieldView Entry : Manifest["Timestamps"sv]) + { + const CbObjectView Obj = Entry.AsObjectView(); + const IoHash Key = Obj["Key"sv].AsHash(); + + if (auto It = m_Index.find(Key); It != m_Index.end()) + { + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_AccessTimes[EntryIndex] = Obj["LastAccess"sv].AsInt64(); + } + } + for (CbFieldView Entry : Manifest["RawInfo"sv]) + { + const CbObjectView Obj = Entry.AsObjectView(); + const IoHash Key = Obj["Key"sv].AsHash(); + if (auto It = m_Index.find(Key); It != m_Index.end()) + { + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size()); + m_Payloads[EntryIndex].RawHash = Obj["RawHash"sv].AsHash(); + m_Payloads[EntryIndex].RawSize = Obj["RawSize"sv].AsUInt64(); + } + } + } + + return true; +} + +void +ZenCacheDiskLayer::CacheBucket::MakeIndexSnapshot() +{ + uint64_t LogCount = m_SlogFile.GetLogCount(); + if (m_LogFlushPosition == LogCount) + { + return; + } + + ZEN_DEBUG("write store snapshot for '{}'", m_BucketDir / m_BucketName); + uint64_t EntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}", + m_BucketDir / m_BucketName, + EntryCount, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + namespace fs = std::filesystem; + + fs::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName); + fs::path STmpIndexPath = GetTempIndexPath(m_BucketDir, m_BucketName); + + // Move index away, we keep it if something goes wrong + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(STmpIndexPath); + } + if (fs::is_regular_file(IndexPath)) + { + fs::rename(IndexPath, STmpIndexPath); + } + + try + { + // Write the current state of the location map to a new index state + std::vector<DiskIndexEntry> Entries; + + { + Entries.resize(m_Index.size()); + + uint64_t EntryIndex = 0; + for (auto& Entry : m_Index) + { + DiskIndexEntry& IndexEntry = Entries[EntryIndex++]; + IndexEntry.Key = Entry.first; + IndexEntry.Location = m_Payloads[Entry.second].Location; + } + } + + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate); + CacheBucketIndexHeader Header = {.EntryCount = Entries.size(), + .LogPosition = LogCount, + .PayloadAlignment = gsl::narrow<uint32_t>(m_PayloadAlignment)}; + + Header.Checksum = CacheBucketIndexHeader::ComputeChecksum(Header); + + ObjectIndexFile.Write(&Header, sizeof(CacheBucketIndexHeader), 0); + ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(DiskIndexEntry), sizeof(CacheBucketIndexHeader)); + ObjectIndexFile.Flush(); + ObjectIndexFile.Close(); + EntryCount = Entries.size(); + m_LogFlushPosition = LogCount; + } + catch (std::exception& Err) + { + ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what()); + + // Restore any previous snapshot + + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(IndexPath); + fs::rename(STmpIndexPath, IndexPath); + } + } + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(STmpIndexPath); + } +} + +uint64_t +ZenCacheDiskLayer::CacheBucket::ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion) +{ + if (std::filesystem::is_regular_file(IndexPath)) + { + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead); + uint64_t Size = ObjectIndexFile.FileSize(); + if (Size >= sizeof(CacheBucketIndexHeader)) + { + CacheBucketIndexHeader Header; + ObjectIndexFile.Read(&Header, sizeof(Header), 0); + if ((Header.Magic == CacheBucketIndexHeader::ExpectedMagic) && + (Header.Checksum == CacheBucketIndexHeader::ComputeChecksum(Header)) && (Header.PayloadAlignment > 0)) + { + switch (Header.Version) + { + case CacheBucketIndexHeader::Version2: + { + uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CacheBucketIndexHeader))) / sizeof(DiskIndexEntry); + if (Header.EntryCount > ExpectedEntryCount) + { + break; + } + size_t EntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' index containing {} entries in {}", + IndexPath, + EntryCount, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + m_PayloadAlignment = Header.PayloadAlignment; + + std::vector<DiskIndexEntry> Entries; + Entries.resize(Header.EntryCount); + ObjectIndexFile.Read(Entries.data(), + Header.EntryCount * sizeof(DiskIndexEntry), + sizeof(CacheBucketIndexHeader)); + + m_Payloads.reserve(Header.EntryCount); + m_AccessTimes.reserve(Header.EntryCount); + m_Index.reserve(Header.EntryCount); + + std::string InvalidEntryReason; + for (const DiskIndexEntry& Entry : Entries) + { + if (!ValidateEntry(Entry, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason); + continue; + } + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Entry.Location, .RawSize = 0, .RawHash = IoHash::Zero}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert_or_assign(Entry.Key, EntryIndex); + EntryCount++; + } + OutVersion = CacheBucketIndexHeader::Version2; + return Header.LogPosition; + } + break; + default: + break; + } + } + } + ZEN_WARN("skipping invalid index file '{}'", IndexPath); + } + return 0; +} + +uint64_t +ZenCacheDiskLayer::CacheBucket::ReadLog(const std::filesystem::path& LogPath, uint64_t SkipEntryCount) +{ + if (std::filesystem::is_regular_file(LogPath)) + { + uint64_t LogEntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' log containing {} entries in {}", LogPath, LogEntryCount, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + TCasLogFile<DiskIndexEntry> CasLog; + CasLog.Open(LogPath, CasLogFile::Mode::kRead); + if (CasLog.Initialize()) + { + uint64_t EntryCount = CasLog.GetLogCount(); + if (EntryCount < SkipEntryCount) + { + ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath); + SkipEntryCount = 0; + } + LogEntryCount = EntryCount - SkipEntryCount; + m_Index.reserve(LogEntryCount); + uint64_t InvalidEntryCount = 0; + CasLog.Replay( + [&](const DiskIndexEntry& Record) { + std::string InvalidEntryReason; + if (Record.Location.Flags & DiskLocation::kTombStone) + { + m_Index.erase(Record.Key); + return; + } + if (!ValidateEntry(Record, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason); + ++InvalidEntryCount; + return; + } + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Record.Location, .RawSize = 0u, .RawHash = IoHash::Zero}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert_or_assign(Record.Key, EntryIndex); + }, + SkipEntryCount); + if (InvalidEntryCount) + { + ZEN_WARN("found {} invalid entries in '{}'", InvalidEntryCount, m_BucketDir / m_BucketName); + } + return LogEntryCount; + } + } + return 0; +}; + +void +ZenCacheDiskLayer::CacheBucket::OpenLog(const bool IsNew) +{ + m_TotalStandaloneSize = 0; + + m_Index.clear(); + m_Payloads.clear(); + m_AccessTimes.clear(); + + std::filesystem::path LogPath = GetLogPath(m_BucketDir, m_BucketName); + std::filesystem::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName); + + if (IsNew) + { + fs::remove(LogPath); + fs::remove(IndexPath); + fs::remove_all(m_BlocksBasePath); + } + + uint64_t LogEntryCount = 0; + { + uint32_t IndexVersion = 0; + m_LogFlushPosition = ReadIndexFile(IndexPath, IndexVersion); + if (IndexVersion == 0 && std::filesystem::is_regular_file(IndexPath)) + { + ZEN_WARN("removing invalid index file at '{}'", IndexPath); + fs::remove(IndexPath); + } + + if (TCasLogFile<DiskIndexEntry>::IsValid(LogPath)) + { + LogEntryCount = ReadLog(LogPath, m_LogFlushPosition); + } + else + { + ZEN_WARN("removing invalid cas log at '{}'", LogPath); + fs::remove(LogPath); + } + } + + CreateDirectories(m_BucketDir); + + m_SlogFile.Open(LogPath, CasLogFile::Mode::kWrite); + + std::vector<BlockStoreLocation> KnownLocations; + KnownLocations.reserve(m_Index.size()); + for (const auto& Entry : m_Index) + { + size_t EntryIndex = Entry.second; + const BucketPayload& Payload = m_Payloads[EntryIndex]; + const DiskLocation& Location = Payload.Location; + + if (Location.IsFlagSet(DiskLocation::kStandaloneFile)) + { + m_TotalStandaloneSize.fetch_add(Location.Size(), std::memory_order::relaxed); + continue; + } + const BlockStoreLocation& BlockLocation = Location.GetBlockLocation(m_PayloadAlignment); + KnownLocations.push_back(BlockLocation); + } + + m_BlockStore.Initialize(m_BlocksBasePath, MaxBlockSize, BlockStoreDiskLocation::MaxBlockIndex + 1, KnownLocations); + if (IsNew || LogEntryCount > 0) + { + MakeIndexSnapshot(); + } + // TODO: should validate integrity of container files here +} + +void +ZenCacheDiskLayer::CacheBucket::BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const +{ + char HexString[sizeof(HashKey.Hash) * 2]; + ToHexBytes(HashKey.Hash, sizeof HashKey.Hash, HexString); + + Path.Append(m_BucketDir); + Path.AppendSeparator(); + Path.Append(L"blob"); + Path.AppendSeparator(); + Path.AppendAsciiRange(HexString, HexString + 3); + Path.AppendSeparator(); + Path.AppendAsciiRange(HexString + 3, HexString + 5); + Path.AppendSeparator(); + Path.AppendAsciiRange(HexString + 5, HexString + sizeof(HexString)); +} + +IoBuffer +ZenCacheDiskLayer::CacheBucket::GetInlineCacheValue(const DiskLocation& Loc) const +{ + BlockStoreLocation Location = Loc.GetBlockLocation(m_PayloadAlignment); + + IoBuffer Value = m_BlockStore.TryGetChunk(Location); + if (Value) + { + Value.SetContentType(Loc.GetContentType()); + } + + return Value; +} + +IoBuffer +ZenCacheDiskLayer::CacheBucket::GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const +{ + ExtendablePathBuilder<256> DataFilePath; + BuildPath(DataFilePath, HashKey); + + RwLock::SharedLockScope ValueLock(LockForHash(HashKey)); + + if (IoBuffer Data = IoBufferBuilder::MakeFromFile(DataFilePath.ToPath())) + { + Data.SetContentType(Loc.GetContentType()); + + return Data; + } + + return {}; +} + +bool +ZenCacheDiskLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue) +{ + RwLock::SharedLockScope _(m_IndexLock); + auto It = m_Index.find(HashKey); + if (It == m_Index.end()) + { + return false; + } + size_t EntryIndex = It.value(); + const BucketPayload& Payload = m_Payloads[EntryIndex]; + m_AccessTimes[EntryIndex] = GcClock::TickCount(); + DiskLocation Location = Payload.Location; + OutValue.RawSize = Payload.RawSize; + OutValue.RawHash = Payload.RawHash; + if (Location.IsFlagSet(DiskLocation::kStandaloneFile)) + { + // We don't need to hold the index lock when we read a standalone file + _.ReleaseNow(); + OutValue.Value = GetStandaloneCacheValue(Location, HashKey); + } + else + { + OutValue.Value = GetInlineCacheValue(Location); + } + _.ReleaseNow(); + + if (!Location.IsFlagSet(DiskLocation::kStructured)) + { + if (OutValue.RawHash == IoHash::Zero && OutValue.RawSize == 0 && OutValue.Value.GetSize() > 0) + { + if (Location.IsFlagSet(DiskLocation::kCompressed)) + { + (void)CompressedBuffer::FromCompressed(SharedBuffer(OutValue.Value), OutValue.RawHash, OutValue.RawSize); + } + else + { + OutValue.RawHash = IoHash::HashBuffer(OutValue.Value); + OutValue.RawSize = OutValue.Value.GetSize(); + } + RwLock::ExclusiveLockScope __(m_IndexLock); + if (auto WriteIt = m_Index.find(HashKey); WriteIt != m_Index.end()) + { + BucketPayload& WritePayload = m_Payloads[WriteIt.value()]; + WritePayload.RawHash = OutValue.RawHash; + WritePayload.RawSize = OutValue.RawSize; + + m_LogFlushPosition = 0; // Force resave of index on exit + } + } + } + + return (bool)OutValue.Value; +} + +void +ZenCacheDiskLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value) +{ + if (Value.Value.Size() >= m_LargeObjectThreshold) + { + return PutStandaloneCacheValue(HashKey, Value); + } + PutInlineCacheValue(HashKey, Value); +} + +bool +ZenCacheDiskLayer::CacheBucket::Drop() +{ + RwLock::ExclusiveLockScope _(m_IndexLock); + + std::vector<std::unique_ptr<RwLock::ExclusiveLockScope>> ShardLocks; + ShardLocks.reserve(256); + for (RwLock& Lock : m_ShardedLocks) + { + ShardLocks.push_back(std::make_unique<RwLock::ExclusiveLockScope>(Lock)); + } + m_BlockStore.Close(); + m_SlogFile.Close(); + + bool Deleted = MoveAndDeleteDirectory(m_BucketDir); + + m_Index.clear(); + m_Payloads.clear(); + m_AccessTimes.clear(); + return Deleted; +} + +void +ZenCacheDiskLayer::CacheBucket::Flush() +{ + m_BlockStore.Flush(); + + RwLock::SharedLockScope _(m_IndexLock); + m_SlogFile.Flush(); + MakeIndexSnapshot(); + SaveManifest(); +} + +void +ZenCacheDiskLayer::CacheBucket::SaveManifest() +{ + using namespace std::literals; + + CbObjectWriter Writer; + Writer << "BucketId"sv << m_BucketId; + + if (!m_Index.empty()) + { + Writer.BeginArray("Timestamps"sv); + for (auto& Kv : m_Index) + { + const IoHash& Key = Kv.first; + GcClock::Tick AccessTime = m_AccessTimes[Kv.second]; + + Writer.BeginObject(); + Writer << "Key"sv << Key; + Writer << "LastAccess"sv << AccessTime; + Writer.EndObject(); + } + Writer.EndArray(); + + Writer.BeginArray("RawInfo"sv); + { + for (auto& Kv : m_Index) + { + const IoHash& Key = Kv.first; + const BucketPayload& Payload = m_Payloads[Kv.second]; + if (Payload.RawHash != IoHash::Zero) + { + Writer.BeginObject(); + Writer << "Key"sv << Key; + Writer << "RawHash"sv << Payload.RawHash; + Writer << "RawSize"sv << Payload.RawSize; + Writer.EndObject(); + } + } + } + Writer.EndArray(); + } + + SaveCompactBinaryObject(m_BucketDir / "zen_manifest", Writer.Save()); +} + +void +ZenCacheDiskLayer::CacheBucket::Scrub(ScrubContext& Ctx) +{ + std::vector<IoHash> BadKeys; + uint64_t ChunkCount{0}, ChunkBytes{0}; + std::vector<BlockStoreLocation> ChunkLocations; + std::vector<IoHash> ChunkIndexToChunkHash; + + auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) { + if (ContentType == ZenContentType::kCbObject) + { + CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); + return Error == CbValidateError::None; + } + if (ContentType == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize)) + { + return false; + } + if (RawHash != Hash) + { + return false; + } + } + return true; + }; + + RwLock::SharedLockScope _(m_IndexLock); + + const size_t BlockChunkInitialCount = m_Index.size() / 4; + ChunkLocations.reserve(BlockChunkInitialCount); + ChunkIndexToChunkHash.reserve(BlockChunkInitialCount); + + for (auto& Kv : m_Index) + { + const IoHash& HashKey = Kv.first; + const BucketPayload& Payload = m_Payloads[Kv.second]; + const DiskLocation& Loc = Payload.Location; + + if (Loc.IsFlagSet(DiskLocation::kStandaloneFile)) + { + ++ChunkCount; + ChunkBytes += Loc.Size(); + if (Loc.GetContentType() == ZenContentType::kBinary) + { + ExtendablePathBuilder<256> DataFilePath; + BuildPath(DataFilePath, HashKey); + + RwLock::SharedLockScope ValueLock(LockForHash(HashKey)); + + std::error_code Ec; + uintmax_t size = std::filesystem::file_size(DataFilePath.ToPath(), Ec); + if (Ec) + { + BadKeys.push_back(HashKey); + } + if (size != Loc.Size()) + { + BadKeys.push_back(HashKey); + } + continue; + } + IoBuffer Buffer = GetStandaloneCacheValue(Loc, HashKey); + if (!Buffer) + { + BadKeys.push_back(HashKey); + continue; + } + if (!ValidateEntry(HashKey, Loc.GetContentType(), Buffer)) + { + BadKeys.push_back(HashKey); + continue; + } + } + else + { + ChunkLocations.emplace_back(Loc.GetBlockLocation(m_PayloadAlignment)); + ChunkIndexToChunkHash.push_back(HashKey); + continue; + } + } + + const auto ValidateSmallChunk = [&](size_t ChunkIndex, const void* Data, uint64_t Size) { + ++ChunkCount; + ChunkBytes += Size; + const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex]; + if (!Data) + { + // ChunkLocation out of range of stored blocks + BadKeys.push_back(Hash); + return; + } + IoBuffer Buffer(IoBuffer::Wrap, Data, Size); + if (!Buffer) + { + BadKeys.push_back(Hash); + return; + } + const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)]; + ZenContentType ContentType = Payload.Location.GetContentType(); + if (!ValidateEntry(Hash, ContentType, Buffer)) + { + BadKeys.push_back(Hash); + return; + } + }; + + const auto ValidateLargeChunk = [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) { + ++ChunkCount; + ChunkBytes += Size; + const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex]; + // TODO: Add API to verify compressed buffer and possible structure data without having to memorymap the whole file + IoBuffer Buffer(IoBuffer::BorrowedFile, File.GetBasicFile().Handle(), Offset, Size); + if (!Buffer) + { + BadKeys.push_back(Hash); + return; + } + const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)]; + ZenContentType ContentType = Payload.Location.GetContentType(); + if (!ValidateEntry(Hash, ContentType, Buffer)) + { + BadKeys.push_back(Hash); + return; + } + }; + + m_BlockStore.IterateChunks(ChunkLocations, ValidateSmallChunk, ValidateLargeChunk); + + _.ReleaseNow(); + + Ctx.ReportScrubbed(ChunkCount, ChunkBytes); + + if (!BadKeys.empty()) + { + ZEN_WARN("Scrubbing found {} bad chunks in '{}'", BadKeys.size(), m_BucketDir / m_BucketName); + + if (Ctx.RunRecovery()) + { + // Deal with bad chunks by removing them from our lookup map + + std::vector<DiskIndexEntry> LogEntries; + LogEntries.reserve(BadKeys.size()); + + { + RwLock::ExclusiveLockScope __(m_IndexLock); + for (const IoHash& BadKey : BadKeys) + { + // Log a tombstone and delete the in-memory index for the bad entry + const auto It = m_Index.find(BadKey); + const BucketPayload& Payload = m_Payloads[It->second]; + DiskLocation Location = Payload.Location; + Location.Flags |= DiskLocation::kTombStone; + LogEntries.push_back(DiskIndexEntry{.Key = BadKey, .Location = Location}); + m_Index.erase(BadKey); + } + } + for (const DiskIndexEntry& Entry : LogEntries) + { + if (Entry.Location.IsFlagSet(DiskLocation::kStandaloneFile)) + { + ExtendablePathBuilder<256> Path; + BuildPath(Path, Entry.Key); + fs::path FilePath = Path.ToPath(); + RwLock::ExclusiveLockScope ValueLock(LockForHash(Entry.Key)); + if (fs::is_regular_file(FilePath)) + { + ZEN_DEBUG("deleting bad standalone cache file '{}'", Path.ToUtf8()); + std::error_code Ec; + fs::remove(FilePath, Ec); // We don't care if we fail, we are no longer tracking this file... + } + m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed); + } + } + m_SlogFile.Append(LogEntries); + + // Clean up m_AccessTimes and m_Payloads vectors + { + std::vector<BucketPayload> Payloads; + std::vector<AccessTime> AccessTimes; + IndexMap Index; + + { + RwLock::ExclusiveLockScope __(m_IndexLock); + size_t EntryCount = m_Index.size(); + Payloads.reserve(EntryCount); + AccessTimes.reserve(EntryCount); + Index.reserve(EntryCount); + for (auto It : m_Index) + { + size_t EntryIndex = Payloads.size(); + Payloads.push_back(m_Payloads[EntryIndex]); + AccessTimes.push_back(m_AccessTimes[EntryIndex]); + Index.insert({It.first, EntryIndex}); + } + m_Index.swap(Index); + m_Payloads.swap(Payloads); + m_AccessTimes.swap(AccessTimes); + } + } + } + } + + // Let whomever it concerns know about the bad chunks. This could + // be used to invalidate higher level data structures more efficiently + // than a full validation pass might be able to do + Ctx.ReportBadCidChunks(BadKeys); + + ZEN_INFO("cache bucket scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes)); +} + +void +ZenCacheDiskLayer::CacheBucket::GatherReferences(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::GatherReferences"); + + uint64_t WriteBlockTimeUs = 0; + uint64_t WriteBlockLongestTimeUs = 0; + uint64_t ReadBlockTimeUs = 0; + uint64_t ReadBlockLongestTimeUs = 0; + + Stopwatch TotalTimer; + const auto _ = MakeGuard([&] { + ZEN_DEBUG("gathered references from '{}' in {} write lock: {} ({}), read lock: {} ({})", + m_BucketDir / m_BucketName, + NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()), + NiceLatencyNs(WriteBlockTimeUs), + NiceLatencyNs(WriteBlockLongestTimeUs), + NiceLatencyNs(ReadBlockTimeUs), + NiceLatencyNs(ReadBlockLongestTimeUs)); + }); + + const GcClock::TimePoint ExpireTime = GcCtx.ExpireTime(); + + const GcClock::Tick ExpireTicks = ExpireTime.time_since_epoch().count(); + + IndexMap Index; + std::vector<AccessTime> AccessTimes; + std::vector<BucketPayload> Payloads; + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + Index = m_Index; + AccessTimes = m_AccessTimes; + Payloads = m_Payloads; + } + + std::vector<IoHash> ExpiredKeys; + ExpiredKeys.reserve(1024); + + std::vector<IoHash> Cids; + Cids.reserve(1024); + + for (const auto& Entry : Index) + { + const IoHash& Key = Entry.first; + GcClock::Tick AccessTime = AccessTimes[Entry.second]; + if (AccessTime < ExpireTicks) + { + ExpiredKeys.push_back(Key); + continue; + } + + const DiskLocation& Loc = Payloads[Entry.second].Location; + + if (Loc.IsFlagSet(DiskLocation::kStructured)) + { + if (Cids.size() > 1024) + { + GcCtx.AddRetainedCids(Cids); + Cids.clear(); + } + + IoBuffer Buffer; + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + if (Loc.IsFlagSet(DiskLocation::kStandaloneFile)) + { + // We don't need to hold the index lock when we read a standalone file + __.ReleaseNow(); + if (Buffer = GetStandaloneCacheValue(Loc, Key); !Buffer) + { + continue; + } + } + else if (Buffer = GetInlineCacheValue(Loc); !Buffer) + { + continue; + } + } + + ZEN_ASSERT(Buffer); + ZEN_ASSERT(Buffer.GetContentType() == ZenContentType::kCbObject); + CbObject Obj(SharedBuffer{Buffer}); + Obj.IterateAttachments([&Cids](CbFieldView Field) { Cids.push_back(Field.AsAttachment()); }); + } + } + + GcCtx.AddRetainedCids(Cids); + GcCtx.SetExpiredCacheKeys(m_BucketDir.string(), std::move(ExpiredKeys)); +} + +void +ZenCacheDiskLayer::CacheBucket::CollectGarbage(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::CollectGarbage"); + + ZEN_DEBUG("collecting garbage from '{}'", m_BucketDir / m_BucketName); + + Stopwatch TotalTimer; + uint64_t WriteBlockTimeUs = 0; + uint64_t WriteBlockLongestTimeUs = 0; + uint64_t ReadBlockTimeUs = 0; + uint64_t ReadBlockLongestTimeUs = 0; + uint64_t TotalChunkCount = 0; + uint64_t DeletedSize = 0; + uint64_t OldTotalSize = TotalSize(); + + std::unordered_set<IoHash> DeletedChunks; + uint64_t MovedCount = 0; + + const auto _ = MakeGuard([&] { + ZEN_DEBUG( + "garbage collect from '{}' DONE after {}, write lock: {} ({}), read lock: {} ({}), collected {} bytes, deleted {} and moved " + "{} " + "of {} " + "entires ({}).", + m_BucketDir / m_BucketName, + NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()), + NiceLatencyNs(WriteBlockTimeUs), + NiceLatencyNs(WriteBlockLongestTimeUs), + NiceLatencyNs(ReadBlockTimeUs), + NiceLatencyNs(ReadBlockLongestTimeUs), + NiceBytes(DeletedSize), + DeletedChunks.size(), + MovedCount, + TotalChunkCount, + NiceBytes(OldTotalSize)); + RwLock::SharedLockScope _(m_IndexLock); + SaveManifest(); + }); + + m_SlogFile.Flush(); + + std::span<const IoHash> ExpiredCacheKeys = GcCtx.ExpiredCacheKeys(m_BucketDir.string()); + std::vector<IoHash> DeleteCacheKeys; + DeleteCacheKeys.reserve(ExpiredCacheKeys.size()); + GcCtx.FilterCids(ExpiredCacheKeys, [&](const IoHash& ChunkHash, bool Keep) { + if (Keep) + { + return; + } + DeleteCacheKeys.push_back(ChunkHash); + }); + if (DeleteCacheKeys.empty()) + { + ZEN_DEBUG("garbage collect SKIPPED, for '{}', no expired cache keys found", m_BucketDir / m_BucketName); + return; + } + + auto __ = MakeGuard([&]() { + if (!DeletedChunks.empty()) + { + // Clean up m_AccessTimes and m_Payloads vectors + std::vector<BucketPayload> Payloads; + std::vector<AccessTime> AccessTimes; + IndexMap Index; + + { + RwLock::ExclusiveLockScope _(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + size_t EntryCount = m_Index.size(); + Payloads.reserve(EntryCount); + AccessTimes.reserve(EntryCount); + Index.reserve(EntryCount); + for (auto It : m_Index) + { + size_t EntryIndex = Payloads.size(); + Payloads.push_back(m_Payloads[EntryIndex]); + AccessTimes.push_back(m_AccessTimes[EntryIndex]); + Index.insert({It.first, EntryIndex}); + } + m_Index.swap(Index); + m_Payloads.swap(Payloads); + m_AccessTimes.swap(AccessTimes); + } + GcCtx.AddDeletedCids(std::vector<IoHash>(DeletedChunks.begin(), DeletedChunks.end())); + } + }); + + std::vector<DiskIndexEntry> ExpiredStandaloneEntries; + IndexMap Index; + BlockStore::ReclaimSnapshotState BlockStoreState; + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + if (m_Index.empty()) + { + ZEN_DEBUG("garbage collect SKIPPED, for '{}', container is empty", m_BucketDir / m_BucketName); + return; + } + BlockStoreState = m_BlockStore.GetReclaimSnapshotState(); + + SaveManifest(); + Index = m_Index; + + for (const IoHash& Key : DeleteCacheKeys) + { + if (auto It = Index.find(Key); It != Index.end()) + { + const BucketPayload& Payload = m_Payloads[It->second]; + DiskIndexEntry Entry = {.Key = It->first, .Location = Payload.Location}; + if (Entry.Location.Flags & DiskLocation::kStandaloneFile) + { + Entry.Location.Flags |= DiskLocation::kTombStone; + ExpiredStandaloneEntries.push_back(Entry); + } + } + } + if (GcCtx.IsDeletionMode()) + { + for (const auto& Entry : ExpiredStandaloneEntries) + { + m_Index.erase(Entry.Key); + m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed); + DeletedChunks.insert(Entry.Key); + } + m_SlogFile.Append(ExpiredStandaloneEntries); + } + } + + if (GcCtx.IsDeletionMode()) + { + std::error_code Ec; + ExtendablePathBuilder<256> Path; + + for (const auto& Entry : ExpiredStandaloneEntries) + { + const IoHash& Key = Entry.Key; + const DiskLocation& Loc = Entry.Location; + + Path.Reset(); + BuildPath(Path, Key); + fs::path FilePath = Path.ToPath(); + + { + RwLock::SharedLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + if (m_Index.contains(Key)) + { + // Someone added it back, let the file on disk be + ZEN_DEBUG("skipping z$ delete standalone of file '{}' FAILED, it has been added back", Path.ToUtf8()); + continue; + } + __.ReleaseNow(); + + RwLock::ExclusiveLockScope ValueLock(LockForHash(Key)); + if (fs::is_regular_file(FilePath)) + { + ZEN_DEBUG("deleting standalone cache file '{}'", Path.ToUtf8()); + fs::remove(FilePath, Ec); + } + } + + if (Ec) + { + ZEN_WARN("delete expired z$ standalone file '{}' FAILED, reason: '{}'", Path.ToUtf8(), Ec.message()); + Ec.clear(); + DiskLocation RestoreLocation = Loc; + RestoreLocation.Flags &= ~DiskLocation::kTombStone; + + RwLock::ExclusiveLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + if (m_Index.contains(Key)) + { + continue; + } + m_SlogFile.Append(DiskIndexEntry{.Key = Key, .Location = RestoreLocation}); + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = RestoreLocation}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert({Key, EntryIndex}); + m_TotalStandaloneSize.fetch_add(RestoreLocation.Size(), std::memory_order::relaxed); + DeletedChunks.erase(Key); + continue; + } + DeletedSize += Entry.Location.Size(); + } + } + + TotalChunkCount = Index.size(); + + std::vector<IoHash> TotalChunkHashes; + TotalChunkHashes.reserve(TotalChunkCount); + for (const auto& Entry : Index) + { + const DiskLocation& Location = m_Payloads[Entry.second].Location; + + if (Location.Flags & DiskLocation::kStandaloneFile) + { + continue; + } + TotalChunkHashes.push_back(Entry.first); + } + + if (TotalChunkHashes.empty()) + { + return; + } + TotalChunkCount = TotalChunkHashes.size(); + + std::vector<BlockStoreLocation> ChunkLocations; + BlockStore::ChunkIndexArray KeepChunkIndexes; + std::vector<IoHash> ChunkIndexToChunkHash; + ChunkLocations.reserve(TotalChunkCount); + ChunkLocations.reserve(TotalChunkCount); + ChunkIndexToChunkHash.reserve(TotalChunkCount); + + GcCtx.FilterCids(TotalChunkHashes, [&](const IoHash& ChunkHash, bool Keep) { + auto KeyIt = Index.find(ChunkHash); + const DiskLocation& DiskLocation = m_Payloads[KeyIt->second].Location; + BlockStoreLocation Location = DiskLocation.GetBlockLocation(m_PayloadAlignment); + size_t ChunkIndex = ChunkLocations.size(); + ChunkLocations.push_back(Location); + ChunkIndexToChunkHash[ChunkIndex] = ChunkHash; + if (Keep) + { + KeepChunkIndexes.push_back(ChunkIndex); + } + }); + + size_t DeleteCount = TotalChunkCount - KeepChunkIndexes.size(); + + const bool PerformDelete = GcCtx.IsDeletionMode() && GcCtx.CollectSmallObjects(); + if (!PerformDelete) + { + m_BlockStore.ReclaimSpace(BlockStoreState, ChunkLocations, KeepChunkIndexes, m_PayloadAlignment, true); + uint64_t CurrentTotalSize = TotalSize(); + ZEN_DEBUG("garbage collect from '{}' DISABLED, found {} chunks of total {} {}", + m_BucketDir / m_BucketName, + DeleteCount, + TotalChunkCount, + NiceBytes(CurrentTotalSize)); + return; + } + + m_BlockStore.ReclaimSpace( + BlockStoreState, + ChunkLocations, + KeepChunkIndexes, + m_PayloadAlignment, + false, + [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& RemovedChunks) { + std::vector<DiskIndexEntry> LogEntries; + LogEntries.reserve(MovedChunks.size() + RemovedChunks.size()); + for (const auto& Entry : MovedChunks) + { + size_t ChunkIndex = Entry.first; + const BlockStoreLocation& NewLocation = Entry.second; + const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex]; + const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]]; + const DiskLocation& OldDiskLocation = OldPayload.Location; + LogEntries.push_back( + {.Key = ChunkHash, .Location = DiskLocation(NewLocation, m_PayloadAlignment, OldDiskLocation.GetFlags())}); + } + for (const size_t ChunkIndex : RemovedChunks) + { + const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex]; + const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]]; + const DiskLocation& OldDiskLocation = OldPayload.Location; + LogEntries.push_back({.Key = ChunkHash, + .Location = DiskLocation(OldDiskLocation.GetBlockLocation(m_PayloadAlignment), + m_PayloadAlignment, + OldDiskLocation.GetFlags() | DiskLocation::kTombStone)}); + DeletedChunks.insert(ChunkHash); + } + + m_SlogFile.Append(LogEntries); + m_SlogFile.Flush(); + { + RwLock::ExclusiveLockScope __(m_IndexLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + for (const DiskIndexEntry& Entry : LogEntries) + { + if (Entry.Location.GetFlags() & DiskLocation::kTombStone) + { + m_Index.erase(Entry.Key); + continue; + } + m_Payloads[m_Index[Entry.Key]].Location = Entry.Location; + } + } + }, + [&]() { return GcCtx.CollectSmallObjects(); }); +} + +void +ZenCacheDiskLayer::CacheBucket::UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes) +{ + using namespace access_tracking; + + for (const KeyAccessTime& KeyTime : AccessTimes) + { + if (auto It = m_Index.find(KeyTime.Key); It != m_Index.end()) + { + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_AccessTimes[EntryIndex] = KeyTime.LastAccess; + } + } +} + +uint64_t +ZenCacheDiskLayer::CacheBucket::EntryCount() const +{ + RwLock::SharedLockScope _(m_IndexLock); + return static_cast<uint64_t>(m_Index.size()); +} + +CacheValueDetails::ValueDetails +ZenCacheDiskLayer::CacheBucket::GetValueDetails(const IoHash& Key, size_t Index) const +{ + std::vector<IoHash> Attachments; + const BucketPayload& Payload = m_Payloads[Index]; + if (Payload.Location.IsFlagSet(DiskLocation::kStructured)) + { + IoBuffer Value = Payload.Location.IsFlagSet(DiskLocation::kStandaloneFile) ? GetStandaloneCacheValue(Payload.Location, Key) + : GetInlineCacheValue(Payload.Location); + CbObject Obj(SharedBuffer{Value}); + Obj.IterateAttachments([&Attachments](CbFieldView Field) { Attachments.emplace_back(Field.AsAttachment()); }); + } + return CacheValueDetails::ValueDetails{.Size = Payload.Location.Size(), + .RawSize = Payload.RawSize, + .RawHash = Payload.RawHash, + .LastAccess = m_AccessTimes[Index], + .Attachments = std::move(Attachments), + .ContentType = Payload.Location.GetContentType()}; +} + +CacheValueDetails::BucketDetails +ZenCacheDiskLayer::CacheBucket::GetValueDetails(const std::string_view ValueFilter) const +{ + CacheValueDetails::BucketDetails Details; + RwLock::SharedLockScope _(m_IndexLock); + if (ValueFilter.empty()) + { + Details.Values.reserve(m_Index.size()); + for (const auto& It : m_Index) + { + Details.Values.insert_or_assign(It.first, GetValueDetails(It.first, It.second)); + } + } + else + { + IoHash Key = IoHash::FromHexString(ValueFilter); + if (auto It = m_Index.find(Key); It != m_Index.end()) + { + Details.Values.insert_or_assign(It->first, GetValueDetails(It->first, It->second)); + } + } + return Details; +} + +void +ZenCacheDiskLayer::CollectGarbage(GcContext& GcCtx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + CacheBucket& Bucket = *Kv.second; + Bucket.CollectGarbage(GcCtx); + } +} + +void +ZenCacheDiskLayer::UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes) +{ + RwLock::SharedLockScope _(m_Lock); + + for (const auto& Kv : AccessTimes.Buckets) + { + if (auto It = m_Buckets.find(Kv.first); It != m_Buckets.end()) + { + CacheBucket& Bucket = *It->second; + Bucket.UpdateAccessTimes(Kv.second); + } + } +} + +void +ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value) +{ + uint64_t NewFileSize = Value.Value.Size(); + + TemporaryFile DataFile; + + std::error_code Ec; + DataFile.CreateTemporary(m_BucketDir.c_str(), Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to open temporary file for put in '{}'", m_BucketDir)); + } + + bool CleanUpTempFile = false; + auto __ = MakeGuard([&] { + if (CleanUpTempFile) + { + std::error_code Ec; + std::filesystem::remove(DataFile.GetPath(), Ec); + if (Ec) + { + ZEN_WARN("Failed to clean up temporary file '{}' for put in '{}', reason '{}'", + DataFile.GetPath(), + m_BucketDir, + Ec.message()); + } + } + }); + + DataFile.WriteAll(Value.Value, Ec); + if (Ec) + { + throw std::system_error(Ec, + fmt::format("Failed to write payload ({} bytes) to temporary file '{}' for put in '{}'", + NiceBytes(NewFileSize), + DataFile.GetPath().string(), + m_BucketDir)); + } + + ExtendablePathBuilder<256> DataFilePath; + BuildPath(DataFilePath, HashKey); + std::filesystem::path FsPath{DataFilePath.ToPath()}; + + RwLock::ExclusiveLockScope ValueLock(LockForHash(HashKey)); + + // We do a speculative remove of the file instead of probing with a exists call and check the error code instead + std::filesystem::remove(FsPath, Ec); + if (Ec) + { + if (Ec.value() != ENOENT) + { + ZEN_WARN("Failed to remove file '{}' for put in '{}', reason: '{}', retrying.", FsPath, m_BucketDir, Ec.message()); + Sleep(100); + Ec.clear(); + std::filesystem::remove(FsPath, Ec); + if (Ec && Ec.value() != ENOENT) + { + throw std::system_error(Ec, fmt::format("Failed to remove file '{}' for put in '{}'", FsPath, m_BucketDir)); + } + } + } + + DataFile.MoveTemporaryIntoPlace(FsPath, Ec); + if (Ec) + { + CreateDirectories(FsPath.parent_path()); + Ec.clear(); + + // Try again + DataFile.MoveTemporaryIntoPlace(FsPath, Ec); + if (Ec) + { + ZEN_WARN("Failed to finalize file '{}', moving from '{}' for put in '{}', reason: '{}', retrying.", + FsPath, + DataFile.GetPath(), + m_BucketDir, + Ec.message()); + Sleep(100); + Ec.clear(); + DataFile.MoveTemporaryIntoPlace(FsPath, Ec); + if (Ec) + { + throw std::system_error( + Ec, + fmt::format("Failed to finalize file '{}', moving from '{}' for put in '{}'", FsPath, DataFile.GetPath(), m_BucketDir)); + } + } + } + + // Once we have called MoveTemporaryIntoPlace automatic clean up the temp file + // will be disabled as the file handle has already been closed + CleanUpTempFile = false; + + uint8_t EntryFlags = DiskLocation::kStandaloneFile; + + if (Value.Value.GetContentType() == ZenContentType::kCbObject) + { + EntryFlags |= DiskLocation::kStructured; + } + else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary) + { + EntryFlags |= DiskLocation::kCompressed; + } + + DiskLocation Loc(NewFileSize, EntryFlags); + + RwLock::ExclusiveLockScope _(m_IndexLock); + if (auto It = m_Index.find(HashKey); It == m_Index.end()) + { + // Previously unknown object + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert_or_assign(HashKey, EntryIndex); + } + else + { + // TODO: should check if write is idempotent and bail out if it is? + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_Payloads[EntryIndex] = BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash}; + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_TotalStandaloneSize.fetch_sub(Loc.Size(), std::memory_order::relaxed); + } + + m_SlogFile.Append({.Key = HashKey, .Location = Loc}); + m_TotalStandaloneSize.fetch_add(NewFileSize, std::memory_order::relaxed); +} + +void +ZenCacheDiskLayer::CacheBucket::PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value) +{ + uint8_t EntryFlags = 0; + + if (Value.Value.GetContentType() == ZenContentType::kCbObject) + { + EntryFlags |= DiskLocation::kStructured; + } + else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary) + { + EntryFlags |= DiskLocation::kCompressed; + } + + m_BlockStore.WriteChunk(Value.Value.Data(), Value.Value.Size(), m_PayloadAlignment, [&](const BlockStoreLocation& BlockStoreLocation) { + DiskLocation Location(BlockStoreLocation, m_PayloadAlignment, EntryFlags); + m_SlogFile.Append({.Key = HashKey, .Location = Location}); + + RwLock::ExclusiveLockScope _(m_IndexLock); + if (auto It = m_Index.find(HashKey); It != m_Index.end()) + { + // TODO: should check if write is idempotent and bail out if it is? + // this would requiring comparing contents on disk unless we add a + // content hash to the index entry + size_t EntryIndex = It.value(); + ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size()); + m_Payloads[EntryIndex] = (BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash}); + m_AccessTimes[EntryIndex] = GcClock::TickCount(); + } + else + { + size_t EntryIndex = m_Payloads.size(); + m_Payloads.emplace_back(BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash}); + m_AccessTimes.emplace_back(GcClock::TickCount()); + m_Index.insert_or_assign(HashKey, EntryIndex); + } + }); +} + +////////////////////////////////////////////////////////////////////////// + +ZenCacheDiskLayer::ZenCacheDiskLayer(const std::filesystem::path& RootDir) : m_RootDir(RootDir) +{ +} + +ZenCacheDiskLayer::~ZenCacheDiskLayer() = default; + +bool +ZenCacheDiskLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue) +{ + const auto BucketName = std::string(InBucket); + CacheBucket* Bucket = nullptr; + + { + RwLock::SharedLockScope _(m_Lock); + + auto It = m_Buckets.find(BucketName); + + if (It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + } + + if (Bucket == nullptr) + { + // Bucket needs to be opened/created + + RwLock::ExclusiveLockScope _(m_Lock); + + if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + else + { + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName)); + Bucket = InsertResult.first->second.get(); + + std::filesystem::path BucketPath = m_RootDir; + BucketPath /= BucketName; + + if (!Bucket->OpenOrCreate(BucketPath)) + { + m_Buckets.erase(InsertResult.first); + return false; + } + } + } + + ZEN_ASSERT(Bucket != nullptr); + return Bucket->Get(HashKey, OutValue); +} + +void +ZenCacheDiskLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + const auto BucketName = std::string(InBucket); + CacheBucket* Bucket = nullptr; + + { + RwLock::SharedLockScope _(m_Lock); + + auto It = m_Buckets.find(BucketName); + + if (It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + } + + if (Bucket == nullptr) + { + // New bucket needs to be created + + RwLock::ExclusiveLockScope _(m_Lock); + + if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end()) + { + Bucket = It->second.get(); + } + else + { + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName)); + Bucket = InsertResult.first->second.get(); + + std::filesystem::path BucketPath = m_RootDir; + BucketPath /= BucketName; + + try + { + if (!Bucket->OpenOrCreate(BucketPath)) + { + ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir); + m_Buckets.erase(InsertResult.first); + return; + } + } + catch (const std::exception& Err) + { + ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what()); + return; + } + } + } + + ZEN_ASSERT(Bucket != nullptr); + + Bucket->Put(HashKey, Value); +} + +void +ZenCacheDiskLayer::DiscoverBuckets() +{ + DirectoryContent DirContent; + GetDirectoryContent(m_RootDir, DirectoryContent::IncludeDirsFlag, DirContent); + + // Initialize buckets + + RwLock::ExclusiveLockScope _(m_Lock); + + for (const std::filesystem::path& BucketPath : DirContent.Directories) + { + const std::string BucketName = PathToUtf8(BucketPath.stem()); + // New bucket needs to be created + if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end()) + { + continue; + } + + auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName)); + CacheBucket& Bucket = *InsertResult.first->second; + + try + { + if (!Bucket.OpenOrCreate(BucketPath, /* AllowCreate */ false)) + { + ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir); + + m_Buckets.erase(InsertResult.first); + continue; + } + } + catch (const std::exception& Err) + { + ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what()); + return; + } + ZEN_INFO("Discovered bucket '{}'", BucketName); + } +} + +bool +ZenCacheDiskLayer::DropBucket(std::string_view InBucket) +{ + RwLock::ExclusiveLockScope _(m_Lock); + + auto It = m_Buckets.find(std::string(InBucket)); + + if (It != m_Buckets.end()) + { + CacheBucket& Bucket = *It->second; + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It); + + return Bucket.Drop(); + } + + // Make sure we remove the folder even if we don't know about the bucket + std::filesystem::path BucketPath = m_RootDir; + BucketPath /= std::string(InBucket); + return MoveAndDeleteDirectory(BucketPath); +} + +bool +ZenCacheDiskLayer::Drop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + std::vector<std::unique_ptr<CacheBucket>> Buckets; + Buckets.reserve(m_Buckets.size()); + while (!m_Buckets.empty()) + { + const auto& It = m_Buckets.begin(); + CacheBucket& Bucket = *It->second; + m_DroppedBuckets.push_back(std::move(It->second)); + m_Buckets.erase(It->first); + if (!Bucket.Drop()) + { + return false; + } + } + return MoveAndDeleteDirectory(m_RootDir); +} + +void +ZenCacheDiskLayer::Flush() +{ + std::vector<CacheBucket*> Buckets; + + { + RwLock::SharedLockScope _(m_Lock); + Buckets.reserve(m_Buckets.size()); + for (auto& Kv : m_Buckets) + { + CacheBucket* Bucket = Kv.second.get(); + Buckets.push_back(Bucket); + } + } + + for (auto& Bucket : Buckets) + { + Bucket->Flush(); + } +} + +void +ZenCacheDiskLayer::Scrub(ScrubContext& Ctx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + CacheBucket& Bucket = *Kv.second; + Bucket.Scrub(Ctx); + } +} + +void +ZenCacheDiskLayer::GatherReferences(GcContext& GcCtx) +{ + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + CacheBucket& Bucket = *Kv.second; + Bucket.GatherReferences(GcCtx); + } +} + +uint64_t +ZenCacheDiskLayer::TotalSize() const +{ + uint64_t TotalSize{}; + RwLock::SharedLockScope _(m_Lock); + + for (auto& Kv : m_Buckets) + { + TotalSize += Kv.second->TotalSize(); + } + + return TotalSize; +} + +ZenCacheDiskLayer::Info +ZenCacheDiskLayer::GetInfo() const +{ + ZenCacheDiskLayer::Info Info = {.Config = {.RootDir = m_RootDir}, .TotalSize = TotalSize()}; + + RwLock::SharedLockScope _(m_Lock); + Info.BucketNames.reserve(m_Buckets.size()); + for (auto& Kv : m_Buckets) + { + Info.BucketNames.push_back(Kv.first); + Info.EntryCount += Kv.second->EntryCount(); + } + return Info; +} + +std::optional<ZenCacheDiskLayer::BucketInfo> +ZenCacheDiskLayer::GetBucketInfo(std::string_view Bucket) const +{ + RwLock::SharedLockScope _(m_Lock); + + if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end()) + { + return ZenCacheDiskLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()}; + } + return {}; +} + +CacheValueDetails::NamespaceDetails +ZenCacheDiskLayer::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const +{ + RwLock::SharedLockScope _(m_Lock); + CacheValueDetails::NamespaceDetails Details; + if (BucketFilter.empty()) + { + Details.Buckets.reserve(BucketFilter.empty() ? m_Buckets.size() : 1); + for (auto& Kv : m_Buckets) + { + Details.Buckets[Kv.first] = Kv.second->GetValueDetails(ValueFilter); + } + } + else if (auto It = m_Buckets.find(std::string(BucketFilter)); It != m_Buckets.end()) + { + Details.Buckets[It->first] = It->second->GetValueDetails(ValueFilter); + } + return Details; +} + +//////////////////////////// ZenCacheStore + +static constexpr std::string_view UE4DDCNamespaceName = "ue4.ddc"; + +ZenCacheStore::ZenCacheStore(GcManager& Gc, const Configuration& Configuration) : m_Gc(Gc), m_Configuration(Configuration) +{ + CreateDirectories(m_Configuration.BasePath); + + DirectoryContent DirContent; + GetDirectoryContent(m_Configuration.BasePath, DirectoryContent::IncludeDirsFlag, DirContent); + + std::vector<std::string> Namespaces; + for (const std::filesystem::path& DirPath : DirContent.Directories) + { + std::string DirName = PathToUtf8(DirPath.filename()); + if (DirName.starts_with(NamespaceDiskPrefix)) + { + Namespaces.push_back(DirName.substr(NamespaceDiskPrefix.length())); + continue; + } + } + + ZEN_INFO("Found {} namespaces in '{}'", Namespaces.size(), m_Configuration.BasePath); + + if (std::find(Namespaces.begin(), Namespaces.end(), UE4DDCNamespaceName) == Namespaces.end()) + { + // default (unspecified) and ue4-ddc namespace points to the same namespace instance + + std::filesystem::path DefaultNamespaceFolder = + m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, UE4DDCNamespaceName); + CreateDirectories(DefaultNamespaceFolder); + Namespaces.push_back(std::string(UE4DDCNamespaceName)); + } + + for (const std::string& NamespaceName : Namespaces) + { + m_Namespaces[NamespaceName] = + std::make_unique<ZenCacheNamespace>(Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, NamespaceName)); + } +} + +ZenCacheStore::~ZenCacheStore() +{ + m_Namespaces.clear(); +} + +bool +ZenCacheStore::Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue) +{ + if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store) + { + return Store->Get(Bucket, HashKey, OutValue); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString()); + + return false; +} + +void +ZenCacheStore::Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value) +{ + if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store) + { + return Store->Put(Bucket, HashKey, Value); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Put, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString()); +} + +bool +ZenCacheStore::DropBucket(std::string_view Namespace, std::string_view Bucket) +{ + if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store) + { + return Store->DropBucket(Bucket); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropBucket, bucket '{}'", Namespace, Bucket); + return false; +} + +bool +ZenCacheStore::DropNamespace(std::string_view InNamespace) +{ + RwLock::SharedLockScope _(m_NamespacesLock); + if (auto It = m_Namespaces.find(std::string(InNamespace)); It != m_Namespaces.end()) + { + ZenCacheNamespace& Namespace = *It->second; + m_DroppedNamespaces.push_back(std::move(It->second)); + m_Namespaces.erase(It); + return Namespace.Drop(); + } + ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropNamespace", InNamespace); + return false; +} + +void +ZenCacheStore::Flush() +{ + IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Flush(); }); +} + +void +ZenCacheStore::Scrub(ScrubContext& Ctx) +{ + IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Scrub(Ctx); }); +} + +CacheValueDetails +ZenCacheStore::GetValueDetails(const std::string_view NamespaceFilter, + const std::string_view BucketFilter, + const std::string_view ValueFilter) const +{ + CacheValueDetails Details; + if (NamespaceFilter.empty()) + { + IterateNamespaces([&](std::string_view Namespace, ZenCacheNamespace& Store) { + Details.Namespaces[std::string(Namespace)] = Store.GetValueDetails(BucketFilter, ValueFilter); + }); + } + else if (const ZenCacheNamespace* Store = FindNamespace(NamespaceFilter); Store != nullptr) + { + Details.Namespaces[std::string(NamespaceFilter)] = Store->GetValueDetails(BucketFilter, ValueFilter); + } + return Details; +} + +ZenCacheNamespace* +ZenCacheStore::GetNamespace(std::string_view Namespace) +{ + RwLock::SharedLockScope _(m_NamespacesLock); + if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end()) + { + return It->second.get(); + } + if (Namespace == DefaultNamespace) + { + if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end()) + { + return It->second.get(); + } + } + _.ReleaseNow(); + + if (!m_Configuration.AllowAutomaticCreationOfNamespaces) + { + return nullptr; + } + + RwLock::ExclusiveLockScope __(m_NamespacesLock); + if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end()) + { + return It->second.get(); + } + + auto NewNamespace = m_Namespaces.insert_or_assign( + std::string(Namespace), + std::make_unique<ZenCacheNamespace>(m_Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, Namespace))); + return NewNamespace.first->second.get(); +} + +const ZenCacheNamespace* +ZenCacheStore::FindNamespace(std::string_view Namespace) const +{ + RwLock::SharedLockScope _(m_NamespacesLock); + if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end()) + { + return It->second.get(); + } + if (Namespace == DefaultNamespace) + { + if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end()) + { + return It->second.get(); + } + } + return nullptr; +} + +void +ZenCacheStore::IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const +{ + std::vector<std::pair<std::string, ZenCacheNamespace&>> Namespaces; + { + RwLock::SharedLockScope _(m_NamespacesLock); + Namespaces.reserve(m_Namespaces.size()); + for (const auto& Entry : m_Namespaces) + { + if (Entry.first == DefaultNamespace) + { + continue; + } + Namespaces.push_back({Entry.first, *Entry.second}); + } + } + for (auto& Entry : Namespaces) + { + Callback(Entry.first, Entry.second); + } +} + +GcStorageSize +ZenCacheStore::StorageSize() const +{ + GcStorageSize Size; + IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { + GcStorageSize StoreSize = Store.StorageSize(); + Size.MemorySize += StoreSize.MemorySize; + Size.DiskSize += StoreSize.DiskSize; + }); + return Size; +} + +ZenCacheStore::Info +ZenCacheStore::GetInfo() const +{ + ZenCacheStore::Info Info = {.Config = m_Configuration, .StorageSize = StorageSize()}; + + IterateNamespaces([&Info](std::string_view NamespaceName, ZenCacheNamespace& Namespace) { + Info.NamespaceNames.push_back(std::string(NamespaceName)); + ZenCacheNamespace::Info NamespaceInfo = Namespace.GetInfo(); + Info.DiskEntryCount += NamespaceInfo.DiskLayerInfo.EntryCount; + Info.MemoryEntryCount += NamespaceInfo.MemoryLayerInfo.EntryCount; + }); + + return Info; +} + +std::optional<ZenCacheNamespace::Info> +ZenCacheStore::GetNamespaceInfo(std::string_view NamespaceName) +{ + if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace) + { + return Namespace->GetInfo(); + } + return {}; +} + +std::optional<ZenCacheNamespace::BucketInfo> +ZenCacheStore::GetBucketInfo(std::string_view NamespaceName, std::string_view BucketName) +{ + if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace) + { + return Namespace->GetBucketInfo(BucketName); + } + return {}; +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +using namespace std::literals; + +namespace testutils { + IoHash CreateKey(size_t KeyValue) { return IoHash::HashBuffer(&KeyValue, sizeof(size_t)); } + + IoBuffer CreateBinaryCacheValue(uint64_t Size) + { + static std::random_device rd; + static std::mt19937 g(rd()); + + std::vector<uint8_t> Values; + Values.resize(Size); + for (size_t Idx = 0; Idx < Size; ++Idx) + { + Values[Idx] = static_cast<uint8_t>(Idx); + } + std::shuffle(Values.begin(), Values.end(), g); + + IoBuffer Buf(IoBuffer::Clone, Values.data(), Values.size()); + Buf.SetContentType(ZenContentType::kBinary); + return Buf; + }; + +} // namespace testutils + +TEST_CASE("z$.store") +{ + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + const int kIterationCount = 100; + + for (int i = 0; i < kIterationCount; ++i) + { + const IoHash Key = IoHash::HashBuffer(&i, sizeof i); + + CbObjectWriter Cbo; + Cbo << "hey" << i; + CbObject Obj = Cbo.Save(); + + ZenCacheValue Value; + Value.Value = Obj.GetBuffer().AsIoBuffer(); + Value.Value.SetContentType(ZenContentType::kCbObject); + + Zcs.Put("test_bucket"sv, Key, Value); + } + + for (int i = 0; i < kIterationCount; ++i) + { + const IoHash Key = IoHash::HashBuffer(&i, sizeof i); + + ZenCacheValue Value; + Zcs.Get("test_bucket"sv, Key, /* out */ Value); + + REQUIRE(Value.Value); + CHECK(Value.Value.GetContentType() == ZenContentType::kCbObject); + CHECK_EQ(ValidateCompactBinary(Value.Value, CbValidateMode::All), CbValidateError::None); + CbObject Obj = LoadCompactBinaryObject(Value.Value); + CHECK_EQ(Obj["hey"].AsInt32(), i); + } +} + +TEST_CASE("z$.size") +{ + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + SUBCASE("mem/disklayer") + { + const size_t Count = 16; + ScopedTemporaryDirectory TempDir; + + GcStorageSize CacheSize; + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + CbObject CacheValue = CreateCacheValue(Zcs.DiskLayerThreshold() - 256); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + for (size_t Key = 0; Key < Count; ++Key) + { + const size_t Bucket = Key % 4; + Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), ZenCacheValue{.Value = Buffer}); + } + + CacheSize = Zcs.StorageSize(); + CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize); + CHECK_LE(CacheValue.GetSize() * Count, CacheSize.MemorySize); + } + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + const GcStorageSize SerializedSize = Zcs.StorageSize(); + CHECK_EQ(SerializedSize.MemorySize, 0); + CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize); + + for (size_t Bucket = 0; Bucket < 4; ++Bucket) + { + Zcs.DropBucket(fmt::format("test_bucket-{}", Bucket)); + } + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } + + SUBCASE("disklayer") + { + const size_t Count = 16; + ScopedTemporaryDirectory TempDir; + + GcStorageSize CacheSize; + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + CbObject CacheValue = CreateCacheValue(Zcs.DiskLayerThreshold() + 64); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + for (size_t Key = 0; Key < Count; ++Key) + { + const size_t Bucket = Key % 4; + Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), {.Value = Buffer}); + } + + CacheSize = Zcs.StorageSize(); + CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize); + CHECK_EQ(0, CacheSize.MemorySize); + } + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + const GcStorageSize SerializedSize = Zcs.StorageSize(); + CHECK_EQ(SerializedSize.MemorySize, 0); + CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize); + + for (size_t Bucket = 0; Bucket < 4; ++Bucket) + { + Zcs.DropBucket(fmt::format("test_bucket-{}", Bucket)); + } + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } +} + +TEST_CASE("z$.gc") +{ + using namespace testutils; + + SUBCASE("gather references does NOT add references for expired cache entries") + { + ScopedTemporaryDirectory TempDir; + std::vector<IoHash> Cids{CreateKey(1), CreateKey(2), CreateKey(3)}; + + const auto CollectAndFilter = [](GcManager& Gc, + GcClock::TimePoint Time, + GcClock::Duration MaxDuration, + std::span<const IoHash> Cids, + std::vector<IoHash>& OutKeep) { + GcContext GcCtx(Time - MaxDuration); + Gc.CollectGarbage(GcCtx); + OutKeep.clear(); + GcCtx.FilterCids(Cids, [&OutKeep](const IoHash& Hash) { OutKeep.push_back(Hash); }); + }; + + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + const auto Bucket = "teardrinker"sv; + + // Create a cache record + const IoHash Key = CreateKey(42); + CbObjectWriter Record; + Record << "Key"sv + << "SomeRecord"sv; + + for (size_t Idx = 0; auto& Cid : Cids) + { + Record.AddBinaryAttachment(fmt::format("attachment-{}", Idx++), Cid); + } + + IoBuffer Buffer = Record.Save().GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + Zcs.Put(Bucket, Key, {.Value = Buffer}); + + std::vector<IoHash> Keep; + + // Collect garbage with 1 hour max cache duration + { + CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep); + CHECK_EQ(Cids.size(), Keep.size()); + } + + // Move forward in time + { + CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), Cids, Keep); + CHECK_EQ(0, Keep.size()); + } + } + + // Expect timestamps to be serialized + { + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + std::vector<IoHash> Keep; + + // Collect garbage with 1 hour max cache duration + { + CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep); + CHECK_EQ(3, Keep.size()); + } + + // Move forward in time + { + CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), Cids, Keep); + CHECK_EQ(0, Keep.size()); + } + } + } + + SUBCASE("gc removes standalone values") + { + ScopedTemporaryDirectory TempDir; + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + const auto Bucket = "fortysixandtwo"sv; + const GcClock::TimePoint CurrentTime = GcClock::Now(); + + std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)}; + + for (const auto& Key : Keys) + { + IoBuffer Value = testutils::CreateBinaryCacheValue(128 << 10); + Zcs.Put(Bucket, Key, {.Value = Value}); + } + + { + GcContext GcCtx(CurrentTime - std::chrono::hours(46)); + + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(Exists); + } + } + + // Move forward in time and collect again + { + GcContext GcCtx(CurrentTime + std::chrono::minutes(2)); + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(!Exists); + } + + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } + + SUBCASE("gc removes small objects") + { + ScopedTemporaryDirectory TempDir; + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + const auto Bucket = "rightintwo"sv; + + std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)}; + + for (const auto& Key : Keys) + { + IoBuffer Value = testutils::CreateBinaryCacheValue(128); + Zcs.Put(Bucket, Key, {.Value = Value}); + } + + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(2)); + GcCtx.CollectSmallObjects(true); + + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(Exists); + } + } + + // Move forward in time and collect again + { + GcContext GcCtx(GcClock::Now() + std::chrono::minutes(2)); + GcCtx.CollectSmallObjects(true); + + Zcs.Flush(); + Gc.CollectGarbage(GcCtx); + + for (const auto& Key : Keys) + { + ZenCacheValue CacheValue; + const bool Exists = Zcs.Get(Bucket, Key, CacheValue); + CHECK(!Exists); + } + + CHECK_EQ(0, Zcs.StorageSize().DiskSize); + } + } +} + +TEST_CASE("z$.threadedinsert") // * doctest::skip(true)) +{ + // for (uint32_t i = 0; i < 100; ++i) + { + ScopedTemporaryDirectory TempDir; + + const uint64_t kChunkSize = 1048; + const int32_t kChunkCount = 8192; + + struct Chunk + { + std::string Bucket; + IoBuffer Buffer; + }; + std::unordered_map<IoHash, Chunk, IoHash::Hasher> Chunks; + Chunks.reserve(kChunkCount); + + const std::string Bucket1 = "rightinone"; + const std::string Bucket2 = "rightintwo"; + + for (int32_t Idx = 0; Idx < kChunkCount; ++Idx) + { + while (true) + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + if (Chunks.contains(Hash)) + { + continue; + } + Chunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk}; + break; + } + while (true) + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + if (Chunks.contains(Hash)) + { + continue; + } + Chunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk}; + break; + } + } + + CreateDirectories(TempDir.Path()); + + WorkerThreadPool ThreadPool(4); + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path()); + + { + std::atomic<size_t> WorkCompleted = 0; + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() { + Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer}); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + } + + const uint64_t TotalSize = Zcs.StorageSize().DiskSize; + CHECK_LE(kChunkSize * Chunks.size(), TotalSize); + + { + std::atomic<size_t> WorkCompleted = 0; + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() { + std::string Bucket = Chunk.second.Bucket; + IoHash ChunkHash = Chunk.first; + ZenCacheValue CacheValue; + + CHECK(Zcs.Get(Bucket, ChunkHash, CacheValue)); + IoHash Hash = IoHash::HashBuffer(CacheValue.Value); + CHECK(ChunkHash == Hash); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + } + std::unordered_map<IoHash, std::string, IoHash::Hasher> GcChunkHashes; + GcChunkHashes.reserve(Chunks.size()); + for (const auto& Chunk : Chunks) + { + GcChunkHashes[Chunk.first] = Chunk.second.Bucket; + } + { + std::unordered_map<IoHash, Chunk, IoHash::Hasher> NewChunks; + + for (int32_t Idx = 0; Idx < kChunkCount; ++Idx) + { + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + NewChunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk}; + } + { + IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + NewChunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk}; + } + } + + std::atomic<size_t> WorkCompleted = 0; + std::atomic_uint32_t AddedChunkCount = 0; + for (const auto& Chunk : NewChunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk, &AddedChunkCount]() { + Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer}); + AddedChunkCount.fetch_add(1); + WorkCompleted.fetch_add(1); + }); + } + + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() { + ZenCacheValue CacheValue; + if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue)) + { + CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value)); + } + WorkCompleted.fetch_add(1); + }); + } + while (AddedChunkCount.load() < NewChunks.size()) + { + // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope + for (const auto& Chunk : NewChunks) + { + ZenCacheValue CacheValue; + if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue)) + { + GcChunkHashes[Chunk.first] = Chunk.second.Bucket; + } + } + std::vector<IoHash> KeepHashes; + KeepHashes.reserve(GcChunkHashes.size()); + for (const auto& Entry : GcChunkHashes) + { + KeepHashes.push_back(Entry.first); + } + size_t C = 0; + while (C < KeepHashes.size()) + { + if (C % 155 == 0) + { + if (C < KeepHashes.size() - 1) + { + KeepHashes[C] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + if (C + 3 < KeepHashes.size() - 1) + { + KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + } + C++; + } + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + GcCtx.AddRetainedCids(KeepHashes); + Zcs.CollectGarbage(GcCtx); + const HashKeySet& Deleted = GcCtx.DeletedCids(); + Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); }); + } + + while (WorkCompleted < NewChunks.size() + Chunks.size()) + { + Sleep(1); + } + + { + // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope + for (const auto& Chunk : NewChunks) + { + ZenCacheValue CacheValue; + if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue)) + { + GcChunkHashes[Chunk.first] = Chunk.second.Bucket; + } + } + std::vector<IoHash> KeepHashes; + KeepHashes.reserve(GcChunkHashes.size()); + for (const auto& Entry : GcChunkHashes) + { + KeepHashes.push_back(Entry.first); + } + size_t C = 0; + while (C < KeepHashes.size()) + { + if (C % 155 == 0) + { + if (C < KeepHashes.size() - 1) + { + KeepHashes[C] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + if (C + 3 < KeepHashes.size() - 1) + { + KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + } + C++; + } + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + GcCtx.AddRetainedCids(KeepHashes); + Zcs.CollectGarbage(GcCtx); + const HashKeySet& Deleted = GcCtx.DeletedCids(); + Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); }); + } + } + { + std::atomic<size_t> WorkCompleted = 0; + for (const auto& Chunk : GcChunkHashes) + { + ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() { + ZenCacheValue CacheValue; + CHECK(Zcs.Get(Chunk.second, Chunk.first, CacheValue)); + CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value)); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < GcChunkHashes.size()) + { + Sleep(1); + } + } + } +} + +TEST_CASE("z$.namespaces") +{ + using namespace testutils; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + ScopedTemporaryDirectory TempDir; + CreateDirectories(TempDir.Path()); + + IoHash Key1; + IoHash Key2; + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = false}); + const auto Bucket = "teardrinker"sv; + const auto CustomNamespace = "mynamespace"sv; + + // Create a cache record + Key1 = CreateKey(42); + CbObject CacheValue = CreateCacheValue(4096); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + ZenCacheValue PutValue = {.Value = Buffer}; + Zcs.Put(ZenCacheStore::DefaultNamespace, Bucket, Key1, PutValue); + + ZenCacheValue GetValue; + CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue)); + CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue)); + + // This should just be dropped as we don't allow creating of namespaces on the fly + Zcs.Put(CustomNamespace, Bucket, Key1, PutValue); + CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue)); + } + + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true}); + const auto Bucket = "teardrinker"sv; + const auto CustomNamespace = "mynamespace"sv; + + Key2 = CreateKey(43); + CbObject CacheValue2 = CreateCacheValue(4096); + + IoBuffer Buffer2 = CacheValue2.GetBuffer().AsIoBuffer(); + Buffer2.SetContentType(ZenContentType::kCbObject); + ZenCacheValue PutValue2 = {.Value = Buffer2}; + Zcs.Put(CustomNamespace, Bucket, Key2, PutValue2); + + ZenCacheValue GetValue; + CHECK(!Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key2, GetValue)); + CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue)); + CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue)); + CHECK(Zcs.Get(CustomNamespace, Bucket, Key2, GetValue)); + } +} + +TEST_CASE("z$.drop.bucket") +{ + using namespace testutils; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + ScopedTemporaryDirectory TempDir; + CreateDirectories(TempDir.Path()); + + IoHash Key1; + IoHash Key2; + + auto PutValue = + [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) { + // Create a cache record + IoHash Key = CreateKey(KeyIndex); + CbObject CacheValue = CreateCacheValue(Size); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + ZenCacheValue PutValue = {.Value = Buffer}; + Zcs.Put(Namespace, Bucket, Key, PutValue); + return Key; + }; + auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) { + ZenCacheValue GetValue; + Zcs.Get(Namespace, Bucket, Key, GetValue); + return GetValue; + }; + WorkerThreadPool Workers(1); + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true}); + const auto Bucket = "teardrinker"sv; + const auto Namespace = "mynamespace"sv; + + Key1 = PutValue(Zcs, Namespace, Bucket, 42, 4096); + Key2 = PutValue(Zcs, Namespace, Bucket, 43, 2048); + + ZenCacheValue Value1 = GetValue(Zcs, Namespace, Bucket, Key1); + CHECK(Value1.Value); + + std::atomic_bool WorkComplete = false; + Workers.ScheduleWork([&]() { + zen::Sleep(100); + Value1.Value = IoBuffer{}; + WorkComplete = true; + }); + // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket + // Our DropBucket execution blocks any incoming request from completing until we are done with the drop + CHECK(Zcs.DropBucket(Namespace, Bucket)); + while (!WorkComplete) + { + zen::Sleep(1); + } + + // Entire bucket should be dropped, but doing a request should will re-create the namespace but it must still be empty + Value1 = GetValue(Zcs, Namespace, Bucket, Key1); + CHECK(!Value1.Value); + ZenCacheValue Value2 = GetValue(Zcs, Namespace, Bucket, Key2); + CHECK(!Value2.Value); + } +} + +TEST_CASE("z$.drop.namespace") +{ + using namespace testutils; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + ScopedTemporaryDirectory TempDir; + CreateDirectories(TempDir.Path()); + + auto PutValue = + [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) { + // Create a cache record + IoHash Key = CreateKey(KeyIndex); + CbObject CacheValue = CreateCacheValue(Size); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + ZenCacheValue PutValue = {.Value = Buffer}; + Zcs.Put(Namespace, Bucket, Key, PutValue); + return Key; + }; + auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) { + ZenCacheValue GetValue; + Zcs.Get(Namespace, Bucket, Key, GetValue); + return GetValue; + }; + WorkerThreadPool Workers(1); + { + GcManager Gc; + ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true}); + const auto Bucket1 = "teardrinker1"sv; + const auto Bucket2 = "teardrinker2"sv; + const auto Namespace1 = "mynamespace1"sv; + const auto Namespace2 = "mynamespace2"sv; + + IoHash Key1 = PutValue(Zcs, Namespace1, Bucket1, 42, 4096); + IoHash Key2 = PutValue(Zcs, Namespace1, Bucket2, 43, 2048); + IoHash Key3 = PutValue(Zcs, Namespace2, Bucket1, 44, 4096); + IoHash Key4 = PutValue(Zcs, Namespace2, Bucket2, 45, 2048); + + ZenCacheValue Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1); + CHECK(Value1.Value); + ZenCacheValue Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2); + CHECK(Value2.Value); + ZenCacheValue Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3); + CHECK(Value3.Value); + ZenCacheValue Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4); + CHECK(Value4.Value); + + std::atomic_bool WorkComplete = false; + Workers.ScheduleWork([&]() { + zen::Sleep(100); + Value1.Value = IoBuffer{}; + Value2.Value = IoBuffer{}; + Value3.Value = IoBuffer{}; + Value4.Value = IoBuffer{}; + WorkComplete = true; + }); + // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket + // Our DropBucket execution blocks any incoming request from completing until we are done with the drop + CHECK(Zcs.DropNamespace(Namespace1)); + while (!WorkComplete) + { + zen::Sleep(1); + } + + // Entire namespace should be dropped, but doing a request should will re-create the namespace but it must still be empty + Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1); + CHECK(!Value1.Value); + Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2); + CHECK(!Value2.Value); + Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3); + CHECK(Value3.Value); + Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4); + CHECK(Value4.Value); + } +} + +TEST_CASE("z$.blocked.disklayer.put") +{ + ScopedTemporaryDirectory TempDir; + + GcStorageSize CacheSize; + + const auto CreateCacheValue = [](size_t Size) -> CbObject { + std::vector<uint8_t> Buf; + Buf.resize(Size, Size & 0xff); + + CbObjectWriter Writer; + Writer.AddBinary("Binary"sv, Buf.data(), Buf.size()); + return Writer.Save(); + }; + + GcManager Gc; + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + + CbObject CacheValue = CreateCacheValue(64 * 1024 + 64); + + IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer(); + Buffer.SetContentType(ZenContentType::kCbObject); + + size_t Key = Buffer.Size(); + IoHash HashKey = IoHash::HashBuffer(&Key, sizeof(uint32_t)); + Zcs.Put("test_bucket", HashKey, {.Value = Buffer}); + + ZenCacheValue BufferGet; + CHECK(Zcs.Get("test_bucket", HashKey, BufferGet)); + + CbObject CacheValue2 = CreateCacheValue(64 * 1024 + 64 + 1); + IoBuffer Buffer2 = CacheValue2.GetBuffer().AsIoBuffer(); + Buffer2.SetContentType(ZenContentType::kCbObject); + + // We should be able to overwrite even if the file is open for read + Zcs.Put("test_bucket", HashKey, {.Value = Buffer2}); + + MemoryView OldView = BufferGet.Value.GetView(); + + ZenCacheValue BufferGet2; + CHECK(Zcs.Get("test_bucket", HashKey, BufferGet2)); + MemoryView NewView = BufferGet2.Value.GetView(); + + // Make sure file openend for read before we wrote it still have old data + CHECK(OldView.GetSize() == Buffer.GetSize()); + CHECK(memcmp(OldView.GetData(), Buffer.GetData(), OldView.GetSize()) == 0); + + // Make sure we get the new data when reading after we write new data + CHECK(NewView.GetSize() == Buffer2.GetSize()); + CHECK(memcmp(NewView.GetData(), Buffer2.GetData(), NewView.GetSize()) == 0); +} + +TEST_CASE("z$.scrub") +{ + ScopedTemporaryDirectory TempDir; + + using namespace testutils; + + struct CacheRecord + { + IoBuffer Record; + std::vector<CompressedBuffer> Attachments; + }; + + auto CreateCacheRecord = [](bool Structured, std::string_view Bucket, const IoHash& Key, const std::vector<size_t>& AttachmentSizes) { + CacheRecord Result; + if (Structured) + { + Result.Attachments.resize(AttachmentSizes.size()); + CbObjectWriter Record; + Record.BeginObject("Key"sv); + { + Record << "Bucket"sv << Bucket; + Record << "Hash"sv << Key; + } + Record.EndObject(); + for (size_t Index = 0; Index < AttachmentSizes.size(); Index++) + { + IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSizes[Index]); + CompressedBuffer CompressedAttachmentData = CompressedBuffer::Compress(SharedBuffer(AttachmentData)); + Record.AddBinaryAttachment(fmt::format("attachment-{}", Index), CompressedAttachmentData.DecodeRawHash()); + Result.Attachments[Index] = CompressedAttachmentData; + } + Result.Record = Record.Save().GetBuffer().AsIoBuffer(); + Result.Record.SetContentType(ZenContentType::kCbObject); + } + else + { + std::string RecordData = fmt::format("{}:{}", Bucket, Key.ToHexString()); + size_t TotalSize = RecordData.length() + 1; + for (size_t AttachmentSize : AttachmentSizes) + { + TotalSize += AttachmentSize; + } + Result.Record = IoBuffer(TotalSize); + char* DataPtr = (char*)Result.Record.MutableData(); + memcpy(DataPtr, RecordData.c_str(), RecordData.length() + 1); + DataPtr += RecordData.length() + 1; + for (size_t AttachmentSize : AttachmentSizes) + { + IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSize); + memcpy(DataPtr, AttachmentData.GetData(), AttachmentData.GetSize()); + DataPtr += AttachmentData.GetSize(); + } + } + return Result; + }; + + GcManager Gc; + CidStore CidStore(Gc); + ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache"); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + auto CreateRecords = + [&](bool IsStructured, std::string_view BucketName, const std::vector<IoHash>& Cids, const std::vector<size_t>& AttachmentSizes) { + for (const IoHash& Cid : Cids) + { + CacheRecord Record = CreateCacheRecord(IsStructured, BucketName, Cid, AttachmentSizes); + Zcs.Put("mybucket", Cid, {.Value = Record.Record}); + for (const CompressedBuffer& Attachment : Record.Attachments) + { + CidStore.AddChunk(Attachment.GetCompressed().Flatten().AsIoBuffer(), Attachment.DecodeRawHash()); + } + } + }; + + std::vector<size_t> AttachmentSizes = {16, 1000, 2000, 4000, 8000, 64000, 80000}; + + std::vector<IoHash> UnstructuredCids{CreateKey(4), CreateKey(5), CreateKey(6)}; + CreateRecords(false, "mybucket"sv, UnstructuredCids, AttachmentSizes); + + std::vector<IoHash> StructuredCids{CreateKey(1), CreateKey(2), CreateKey(3)}; + CreateRecords(true, "mybucket"sv, StructuredCids, AttachmentSizes); + + ScrubContext ScrubCtx; + Zcs.Scrub(ScrubCtx); + CidStore.Scrub(ScrubCtx); + CHECK(ScrubCtx.ScrubbedChunks() == (StructuredCids.size() + StructuredCids.size() * AttachmentSizes.size()) + UnstructuredCids.size()); + CHECK(ScrubCtx.BadCids().GetSize() == 0); +} + +#endif + +void +z$_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenserver/cache/structuredcachestore.h b/src/zenserver/cache/structuredcachestore.h new file mode 100644 index 000000000..3fb4f035d --- /dev/null +++ b/src/zenserver/cache/structuredcachestore.h @@ -0,0 +1,535 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/thread.h> +#include <zencore/uid.h> +#include <zenstore/blockstore.h> +#include <zenstore/caslog.h> +#include <zenstore/gc.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <atomic> +#include <compare> +#include <filesystem> +#include <unordered_map> + +#define ZEN_USE_CACHE_TRACKER 0 + +namespace zen { + +class PathBuilderBase; +class GcManager; +class ZenCacheTracker; +class ScrubContext; + +/****************************************************************************** + + /$$$$$$$$ /$$$$$$ /$$ + |_____ $$ /$$__ $$ | $$ + /$$/ /$$$$$$ /$$$$$$$ | $$ \__/ /$$$$$$ /$$$$$$| $$$$$$$ /$$$$$$ + /$$/ /$$__ $| $$__ $$ | $$ |____ $$/$$_____| $$__ $$/$$__ $$ + /$$/ | $$$$$$$| $$ \ $$ | $$ /$$$$$$| $$ | $$ \ $| $$$$$$$$ + /$$/ | $$_____| $$ | $$ | $$ $$/$$__ $| $$ | $$ | $| $$_____/ + /$$$$$$$| $$$$$$| $$ | $$ | $$$$$$| $$$$$$| $$$$$$| $$ | $| $$$$$$$ + |________/\_______|__/ |__/ \______/ \_______/\_______|__/ |__/\_______/ + + Cache store for UE5. Restricts keys to "{bucket}/{hash}" pairs where the hash + is 40 (hex) chars in size. Values may be opaque blobs or structured objects + which can in turn contain references to other objects (or blobs). + +******************************************************************************/ + +namespace access_tracking { + + struct KeyAccessTime + { + IoHash Key; + GcClock::Tick LastAccess{}; + }; + + struct AccessTimes + { + std::unordered_map<std::string, std::vector<KeyAccessTime>> Buckets; + }; +}; // namespace access_tracking + +struct ZenCacheValue +{ + IoBuffer Value; + uint64_t RawSize = 0; + IoHash RawHash = IoHash::Zero; +}; + +struct CacheValueDetails +{ + struct ValueDetails + { + uint64_t Size; + uint64_t RawSize; + IoHash RawHash; + GcClock::Tick LastAccess{}; + std::vector<IoHash> Attachments; + ZenContentType ContentType; + }; + + struct BucketDetails + { + std::unordered_map<IoHash, ValueDetails, IoHash::Hasher> Values; + }; + + struct NamespaceDetails + { + std::unordered_map<std::string, BucketDetails> Buckets; + }; + + std::unordered_map<std::string, NamespaceDetails> Namespaces; +}; + +////////////////////////////////////////////////////////////////////////// + +#pragma pack(push) +#pragma pack(1) + +struct DiskLocation +{ + inline DiskLocation() = default; + + inline DiskLocation(uint64_t ValueSize, uint8_t Flags) : Flags(Flags | kStandaloneFile) { Location.StandaloneSize = ValueSize; } + + inline DiskLocation(const BlockStoreLocation& Location, uint64_t PayloadAlignment, uint8_t Flags) : Flags(Flags & ~kStandaloneFile) + { + this->Location.BlockLocation = BlockStoreDiskLocation(Location, PayloadAlignment); + } + + inline BlockStoreLocation GetBlockLocation(uint64_t PayloadAlignment) const + { + ZEN_ASSERT(!(Flags & kStandaloneFile)); + return Location.BlockLocation.Get(PayloadAlignment); + } + + inline uint64_t Size() const { return (Flags & kStandaloneFile) ? Location.StandaloneSize : Location.BlockLocation.GetSize(); } + inline uint8_t IsFlagSet(uint64_t Flag) const { return Flags & Flag; } + inline uint8_t GetFlags() const { return Flags; } + inline ZenContentType GetContentType() const + { + ZenContentType ContentType = ZenContentType::kBinary; + + if (IsFlagSet(kStructured)) + { + ContentType = ZenContentType::kCbObject; + } + + if (IsFlagSet(kCompressed)) + { + ContentType = ZenContentType::kCompressedBinary; + } + + return ContentType; + } + + union + { + BlockStoreDiskLocation BlockLocation; // 10 bytes + uint64_t StandaloneSize = 0; // 8 bytes + } Location; + + static const uint8_t kStandaloneFile = 0x80u; // Stored as a separate file + static const uint8_t kStructured = 0x40u; // Serialized as compact binary + static const uint8_t kTombStone = 0x20u; // Represents a deleted key/value + static const uint8_t kCompressed = 0x10u; // Stored in compressed buffer format + + uint8_t Flags = 0; + uint8_t Reserved = 0; +}; + +struct DiskIndexEntry +{ + IoHash Key; // 20 bytes + DiskLocation Location; // 12 bytes +}; + +#pragma pack(pop) + +static_assert(sizeof(DiskIndexEntry) == 32); + +// This store the access time as seconds since epoch internally in a 32-bit value giving is a range of 136 years since epoch +struct AccessTime +{ + explicit AccessTime(GcClock::Tick Tick) noexcept : SecondsSinceEpoch(ToSeconds(Tick)) {} + AccessTime& operator=(GcClock::Tick Tick) noexcept + { + SecondsSinceEpoch.store(ToSeconds(Tick), std::memory_order_relaxed); + return *this; + } + operator GcClock::Tick() const noexcept + { + return std::chrono::duration_cast<GcClock::Duration>(std::chrono::seconds(SecondsSinceEpoch.load(std::memory_order_relaxed))) + .count(); + } + + AccessTime(AccessTime&& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {} + AccessTime(const AccessTime& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {} + AccessTime& operator=(AccessTime&& Rhs) noexcept + { + SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed); + return *this; + } + AccessTime& operator=(const AccessTime& Rhs) noexcept + { + SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed); + return *this; + } + +private: + static uint32_t ToSeconds(GcClock::Tick Tick) + { + return gsl::narrow<uint32_t>(std::chrono::duration_cast<std::chrono::seconds>(GcClock::Duration(Tick)).count()); + } + std::atomic_uint32_t SecondsSinceEpoch; +}; + +/** In-memory cache storage + + Intended for small values which are frequently accessed + + This should have a better memory management policy to maintain reasonable + footprint. + */ +class ZenCacheMemoryLayer +{ +public: + struct Configuration + { + uint64_t TargetFootprintBytes = 16 * 1024 * 1024; + uint64_t ScavengeThreshold = 4 * 1024 * 1024; + }; + + struct BucketInfo + { + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + struct Info + { + Configuration Config; + std::vector<std::string> BucketNames; + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + ZenCacheMemoryLayer(); + ~ZenCacheMemoryLayer(); + + bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + void Drop(); + bool DropBucket(std::string_view Bucket); + void Scrub(ScrubContext& Ctx); + void GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes); + void Reset(); + uint64_t TotalSize() const; + + Info GetInfo() const; + std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const; + + const Configuration& GetConfiguration() const { return m_Configuration; } + void SetConfiguration(const Configuration& NewConfig) { m_Configuration = NewConfig; } + +private: + struct CacheBucket + { +#pragma pack(push) +#pragma pack(1) + struct BucketPayload + { + IoBuffer Payload; // 8 + uint32_t RawSize; // 4 + IoHash RawHash; // 20 + }; +#pragma pack(pop) + static_assert(sizeof(BucketPayload) == 32u); + static_assert(sizeof(AccessTime) == 4u); + + mutable RwLock m_BucketLock; + std::vector<AccessTime> m_AccessTimes; + std::vector<BucketPayload> m_Payloads; + tsl::robin_map<IoHash, uint32_t> m_CacheMap; + + std::atomic_uint64_t m_TotalSize{}; + + bool Get(const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(const IoHash& HashKey, const ZenCacheValue& Value); + void Drop(); + void Scrub(ScrubContext& Ctx); + void GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes); + inline uint64_t TotalSize() const { return m_TotalSize; } + uint64_t EntryCount() const; + }; + + mutable RwLock m_Lock; + std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets; + std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets; + Configuration m_Configuration; + + ZenCacheMemoryLayer(const ZenCacheMemoryLayer&) = delete; + ZenCacheMemoryLayer& operator=(const ZenCacheMemoryLayer&) = delete; +}; + +class ZenCacheDiskLayer +{ +public: + struct Configuration + { + std::filesystem::path RootDir; + }; + + struct BucketInfo + { + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + struct Info + { + Configuration Config; + std::vector<std::string> BucketNames; + uint64_t EntryCount = 0; + uint64_t TotalSize = 0; + }; + + explicit ZenCacheDiskLayer(const std::filesystem::path& RootDir); + ~ZenCacheDiskLayer(); + + bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + bool Drop(); + bool DropBucket(std::string_view Bucket); + void Flush(); + void Scrub(ScrubContext& Ctx); + void GatherReferences(GcContext& GcCtx); + void CollectGarbage(GcContext& GcCtx); + void UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes); + + void DiscoverBuckets(); + uint64_t TotalSize() const; + + Info GetInfo() const; + std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const; + + CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const; + +private: + /** A cache bucket manages a single directory containing + metadata and data for that bucket + */ + struct CacheBucket + { + CacheBucket(std::string BucketName); + ~CacheBucket(); + + bool OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate = true); + bool Get(const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(const IoHash& HashKey, const ZenCacheValue& Value); + bool Drop(); + void Flush(); + void Scrub(ScrubContext& Ctx); + void GatherReferences(GcContext& GcCtx); + void CollectGarbage(GcContext& GcCtx); + void UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes); + + inline uint64_t TotalSize() const { return m_TotalStandaloneSize.load(std::memory_order::relaxed) + m_BlockStore.TotalSize(); } + uint64_t EntryCount() const; + + CacheValueDetails::BucketDetails GetValueDetails(const std::string_view ValueFilter) const; + + private: + const uint64_t MaxBlockSize = 1ull << 30; + uint64_t m_PayloadAlignment = 1ull << 4; + + std::string m_BucketName; + std::filesystem::path m_BucketDir; + std::filesystem::path m_BlocksBasePath; + BlockStore m_BlockStore; + Oid m_BucketId; + uint64_t m_LargeObjectThreshold = 128 * 1024; + + // These files are used to manage storage of small objects for this bucket + + TCasLogFile<DiskIndexEntry> m_SlogFile; + uint64_t m_LogFlushPosition = 0; + +#pragma pack(push) +#pragma pack(1) + struct BucketPayload + { + DiskLocation Location; // 12 + uint64_t RawSize; // 8 + IoHash RawHash; // 20 + }; +#pragma pack(pop) + static_assert(sizeof(BucketPayload) == 40u); + static_assert(sizeof(AccessTime) == 4u); + + using IndexMap = tsl::robin_map<IoHash, size_t, IoHash::Hasher>; + + mutable RwLock m_IndexLock; + std::vector<AccessTime> m_AccessTimes; + std::vector<BucketPayload> m_Payloads; + IndexMap m_Index; + + std::atomic_uint64_t m_TotalStandaloneSize{}; + + void BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const; + void PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value); + IoBuffer GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const; + void PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value); + IoBuffer GetInlineCacheValue(const DiskLocation& Loc) const; + void MakeIndexSnapshot(); + uint64_t ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion); + uint64_t ReadLog(const std::filesystem::path& LogPath, uint64_t LogPosition); + void OpenLog(const bool IsNew); + void SaveManifest(); + CacheValueDetails::ValueDetails GetValueDetails(const IoHash& Key, size_t Index) const; + // These locks are here to avoid contention on file creation, therefore it's sufficient + // that we take the same lock for the same hash + // + // These locks are small and should really be spaced out so they don't share cache lines, + // but we don't currently access them at particularly high frequency so it should not be + // an issue in practice + + mutable RwLock m_ShardedLocks[256]; + inline RwLock& LockForHash(const IoHash& Hash) const { return m_ShardedLocks[Hash.Hash[19]]; } + }; + + std::filesystem::path m_RootDir; + mutable RwLock m_Lock; + std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets; // TODO: make this case insensitive + std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets; + + ZenCacheDiskLayer(const ZenCacheDiskLayer&) = delete; + ZenCacheDiskLayer& operator=(const ZenCacheDiskLayer&) = delete; +}; + +class ZenCacheNamespace final : public RefCounted, public GcStorage, public GcContributor +{ +public: + struct Configuration + { + std::filesystem::path RootDir; + uint64_t DiskLayerThreshold = 0; + }; + struct BucketInfo + { + ZenCacheDiskLayer::BucketInfo DiskLayerInfo; + ZenCacheMemoryLayer::BucketInfo MemoryLayerInfo; + }; + struct Info + { + Configuration Config; + std::vector<std::string> BucketNames; + ZenCacheDiskLayer::Info DiskLayerInfo; + ZenCacheMemoryLayer::Info MemoryLayerInfo; + }; + + ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir); + ~ZenCacheNamespace(); + + bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + bool Drop(); + bool DropBucket(std::string_view Bucket); + void Flush(); + void Scrub(ScrubContext& Ctx); + uint64_t DiskLayerThreshold() const { return m_DiskLayerSizeThreshold; } + virtual void GatherReferences(GcContext& GcCtx) override; + virtual void CollectGarbage(GcContext& GcCtx) override; + virtual GcStorageSize StorageSize() const override; + Info GetInfo() const; + std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const; + + CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const; + +private: + std::filesystem::path m_RootDir; + ZenCacheMemoryLayer m_MemLayer; + ZenCacheDiskLayer m_DiskLayer; + uint64_t m_DiskLayerSizeThreshold = 1 * 1024; + uint64_t m_LastScrubTime = 0; + +#if ZEN_USE_CACHE_TRACKER + std::unique_ptr<ZenCacheTracker> m_AccessTracker; +#endif + + ZenCacheNamespace(const ZenCacheNamespace&) = delete; + ZenCacheNamespace& operator=(const ZenCacheNamespace&) = delete; +}; + +class ZenCacheStore final +{ +public: + static constexpr std::string_view DefaultNamespace = + "!default!"; // This is intentionally not a valid namespace name and will only be used for mapping when no namespace is given + static constexpr std::string_view NamespaceDiskPrefix = "ns_"; + + struct Configuration + { + std::filesystem::path BasePath; + bool AllowAutomaticCreationOfNamespaces = false; + }; + + struct Info + { + Configuration Config; + std::vector<std::string> NamespaceNames; + uint64_t DiskEntryCount = 0; + uint64_t MemoryEntryCount = 0; + GcStorageSize StorageSize; + }; + + ZenCacheStore(GcManager& Gc, const Configuration& Configuration); + ~ZenCacheStore(); + + bool Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue); + void Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value); + bool DropBucket(std::string_view Namespace, std::string_view Bucket); + bool DropNamespace(std::string_view Namespace); + void Flush(); + void Scrub(ScrubContext& Ctx); + + CacheValueDetails GetValueDetails(const std::string_view NamespaceFilter, + const std::string_view BucketFilter, + const std::string_view ValueFilter) const; + + GcStorageSize StorageSize() const; + // const Configuration& GetConfiguration() const { return m_Configuration; } + + Info GetInfo() const; + std::optional<ZenCacheNamespace::Info> GetNamespaceInfo(std::string_view Namespace); + std::optional<ZenCacheNamespace::BucketInfo> GetBucketInfo(std::string_view Namespace, std::string_view Bucket); + +private: + const ZenCacheNamespace* FindNamespace(std::string_view Namespace) const; + ZenCacheNamespace* GetNamespace(std::string_view Namespace); + void IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const; + + typedef std::unordered_map<std::string, std::unique_ptr<ZenCacheNamespace>> NamespaceMap; + + mutable RwLock m_NamespacesLock; + NamespaceMap m_Namespaces; + std::vector<std::unique_ptr<ZenCacheNamespace>> m_DroppedNamespaces; + + GcManager& m_Gc; + Configuration m_Configuration; +}; + +void z$_forcelink(); + +} // namespace zen diff --git a/src/zenserver/cidstore.cpp b/src/zenserver/cidstore.cpp new file mode 100644 index 000000000..bce4f1dfb --- /dev/null +++ b/src/zenserver/cidstore.cpp @@ -0,0 +1,124 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "cidstore.h" + +#include <zencore/compress.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zenstore/cidstore.h> + +#include <gsl/gsl-lite.hpp> + +namespace zen { + +HttpCidService::HttpCidService(CidStore& Store) : m_CidStore(Store) +{ + m_Router.AddPattern("cid", "([0-9A-Fa-f]{40})"); + + m_Router.RegisterRoute( + "{cid}", + [this](HttpRouterRequest& Req) { + IoHash Hash = IoHash::FromHexString(Req.GetCapture(1)); + ZEN_DEBUG("CID request for {}", Hash); + + HttpServerRequest& ServerRequest = Req.ServerRequest(); + + switch (ServerRequest.RequestVerb()) + { + case HttpVerb::kGet: + case HttpVerb::kHead: + { + if (IoBuffer Value = m_CidStore.FindChunkByCid(Hash)) + { + return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value); + } + + return ServerRequest.WriteResponse(HttpResponseCode::NotFound); + } + break; + + case HttpVerb::kPut: + { + IoBuffer Payload = ServerRequest.ReadPayload(); + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize)) + { + return ServerRequest.WriteResponse(HttpResponseCode::UnsupportedMediaType); + } + + // URI hash must match content hash + if (RawHash != Hash) + { + return ServerRequest.WriteResponse(HttpResponseCode::BadRequest); + } + + m_CidStore.AddChunk(Payload, RawHash); + + return ServerRequest.WriteResponse(HttpResponseCode::OK); + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPut | HttpVerb::kHead); +} + +const char* +HttpCidService::BaseUri() const +{ + return "/cid/"; +} + +void +HttpCidService::HandleRequest(zen::HttpServerRequest& Request) +{ + if (Request.RelativeUri().empty()) + { + // Root URI request + + switch (Request.RequestVerb()) + { + case HttpVerb::kPut: + case HttpVerb::kPost: + { + IoBuffer Payload = Request.ReadPayload(); + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize)) + { + return Request.WriteResponse(HttpResponseCode::UnsupportedMediaType); + } + + ZEN_DEBUG("CID POST request for {} ({} bytes)", RawHash, Payload.Size()); + + auto InsertResult = m_CidStore.AddChunk(Payload, RawHash); + + if (InsertResult.New) + { + return Request.WriteResponse(HttpResponseCode::Created); + } + else + { + return Request.WriteResponse(HttpResponseCode::OK); + } + } + break; + + case HttpVerb::kGet: + case HttpVerb::kHead: + break; + + default: + break; + } + } + else + { + m_Router.HandleRequest(Request); + } +} + +} // namespace zen diff --git a/src/zenserver/cidstore.h b/src/zenserver/cidstore.h new file mode 100644 index 000000000..8e7832b35 --- /dev/null +++ b/src/zenserver/cidstore.h @@ -0,0 +1,35 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +/** + * Simple CID store HTTP endpoint + * + * Note that since this does not end up pinning any of the chunks it's only really useful for a small subset of use cases where you know a + * chunk exists in the underlying CID store. Thus it's mainly useful for internal use when communicating between Zen store instances + * + * Using this interface for adding CID chunks makes little sense except for testing purposes as garbage collection may reap anything you add + * before anything ever gets to access it + */ + +class CidStore; + +class HttpCidService : public HttpService +{ +public: + explicit HttpCidService(CidStore& Store); + ~HttpCidService() = default; + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + CidStore& m_CidStore; + HttpRequestRouter m_Router; +}; + +} // namespace zen diff --git a/src/zenserver/compute/function.cpp b/src/zenserver/compute/function.cpp new file mode 100644 index 000000000..493e2666e --- /dev/null +++ b/src/zenserver/compute/function.cpp @@ -0,0 +1,629 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "function.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <upstream/jupiter.h> +# include <upstream/upstreamapply.h> +# include <upstream/upstreamcache.h> +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compress.h> +# include <zencore/except.h> +# include <zencore/filesystem.h> +# include <zencore/fmtutils.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/scopeguard.h> +# include <zenstore/cidstore.h> + +# include <span> + +using namespace std::literals; + +namespace zen { + +HttpFunctionService::HttpFunctionService(CidStore& InCidStore, + const CloudCacheClientOptions& ComputeOptions, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const UpstreamAuthConfig& StorageAuthConfig, + AuthMgr& Mgr) +: m_Log(logging::Get("apply")) +, m_CidStore(InCidStore) +{ + m_UpstreamApply = UpstreamApply::Create({}, m_CidStore); + + InitializeThread = std::thread{[this, ComputeOptions, StorageOptions, ComputeAuthConfig, StorageAuthConfig, &Mgr] { + auto HordeUpstreamEndpoint = UpstreamApplyEndpoint::CreateHordeEndpoint(ComputeOptions, + ComputeAuthConfig, + StorageOptions, + StorageAuthConfig, + m_CidStore, + Mgr); + m_UpstreamApply->RegisterEndpoint(std::move(HordeUpstreamEndpoint)); + m_UpstreamApply->Initialize(); + }}; + + m_Router.AddPattern("job", "([[:digit:]]+)"); + m_Router.AddPattern("worker", "([[:xdigit:]]{40})"); + m_Router.AddPattern("action", "([[:xdigit:]]{40})"); + + m_Router.RegisterRoute( + "ready", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + return HttpReq.WriteResponse(m_UpstreamApply->IsHealthy() ? HttpResponseCode::OK : HttpResponseCode::ServiceUnavailable); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "workers/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + else + { + const WorkerDesc& Desc = It->second; + return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor); + } + } + break; + + case HttpVerb::kPost: + { + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + CbObject FunctionSpec = HttpReq.ReadPayloadObject(); + + // Determine which pieces are missing and need to be transmitted to populate CAS + + HashKeySet ChunkSet; + + FunctionSpec.IterateAttachments([&](CbFieldView Field) { + const IoHash Hash = Field.AsHash(); + ChunkSet.AddHashToSet(Hash); + }); + + // Note that we store executables uncompressed to make it + // more straightforward and efficient to materialize them, hence + // the CAS lookup here instead of CID for the input payloads + + m_CidStore.FilterChunks(ChunkSet); + + if (ChunkSet.IsEmpty()) + { + RwLock::ExclusiveLockScope _(m_WorkerLock); + + m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{FunctionSpec}); + + ZEN_DEBUG("worker {}: all attachments already available", WorkerId); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + else + { + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("need"); + + ChunkSet.IterateHashes([&](const IoHash& Hash) { + ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash); + + ResponseWriter.AddHash(Hash); + }); + + ResponseWriter.EndArray(); + + ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize()); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save()); + } + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage FunctionSpec = HttpReq.ReadPayloadPackage(); + + CbObject Obj = FunctionSpec.GetObject(); + + std::span<const CbAttachment> Attachments = FunctionSpec.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer Buffer = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + TotalAttachmentBytes += Buffer.GetCompressedSize(); + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += Buffer.GetCompressedSize(); + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments", + WorkerId, + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + RwLock::ExclusiveLockScope _(m_WorkerLock); + + m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{.Descriptor = Obj}); + + return HttpReq.WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{job}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + break; + + case HttpVerb::kPost: + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}/{action}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2)); + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbPackage Output; + HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, ActionId, Output); + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "simple/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker; + + { + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + else + { + Worker = It->second; + } + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + { + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, Output); + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + + { + RwLock::SharedLockScope _(m_WorkerLock); + m_WorkerMap.erase(WorkerId); + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + + case HttpVerb::kPost: + { + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstream(Worker, Output); + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "jobs/{worker}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1)); + + WorkerDesc Worker; + + { + RwLock::SharedLockScope _(m_WorkerLock); + + if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + else + { + Worker = It->second; + } + } + + switch (HttpReq.RequestVerb()) + { + case HttpVerb::kGet: + // TODO: return status of all pending or executing jobs + break; + + case HttpVerb::kPost: + switch (HttpReq.RequestContentType()) + { + case HttpContentType::kCbObject: + { + // This operation takes the proposed job spec and identifies which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject RequestObject = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + RequestObject.IterateAttachments([&](CbFieldView Field) { + const IoHash FileHash = Field.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + NeedList.push_back(FileHash); + } + }); + + if (NeedList.empty()) + { + // We already have everything + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstream(Worker, RequestObject, Output); + + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response); + } + break; + + case HttpContentType::kCbPackage: + { + CbPackage Action = HttpReq.ReadPayloadPackage(); + CbObject ActionObj = Action.GetObject(); + + std::span<const CbAttachment> Attachments = Action.GetAttachments(); + + int AttachmentCount = 0; + int NewAttachmentCount = 0; + uint64_t TotalAttachmentBytes = 0; + uint64_t TotalNewBytes = 0; + + for (const CbAttachment& Attachment : Attachments) + { + ZEN_ASSERT(Attachment.IsCompressedBinary()); + + const IoHash DataHash = Attachment.GetHash(); + CompressedBuffer DataView = Attachment.AsCompressedBinary(); + + ZEN_UNUSED(DataHash); + + const uint64_t CompressedSize = DataView.GetCompressedSize(); + + TotalAttachmentBytes += CompressedSize; + ++AttachmentCount; + + const CidStore::InsertResult InsertResult = + m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash); + + if (InsertResult.New) + { + TotalNewBytes += CompressedSize; + ++NewAttachmentCount; + } + } + + ZEN_DEBUG("new action: {} in {} attachments. {} new ({} attachments)", + zen::NiceBytes(TotalAttachmentBytes), + AttachmentCount, + zen::NiceBytes(TotalNewBytes), + NewAttachmentCount); + + CbObject Output; + HttpResponseCode ResponseCode = ExecActionUpstream(Worker, ActionObj, Output); + + if (ResponseCode != HttpResponseCode::OK) + { + return HttpReq.WriteResponse(ResponseCode); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, Output); + } + break; + + default: + break; + } + break; + + default: + break; + } + }, + HttpVerb::kPost); +} + +HttpFunctionService::~HttpFunctionService() +{ +} + +const char* +HttpFunctionService::BaseUri() const +{ + return "/apply/"; +} + +void +HttpFunctionService::HandleRequest(HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object) +{ + const IoHash WorkerId = Worker.Descriptor.GetHash(); + + ZEN_INFO("Action {} being processed...", WorkerId.ToHexString()); + + auto EnqueueResult = m_UpstreamApply->EnqueueUpstream({.WorkerDescriptor = Worker.Descriptor, .Type = UpstreamApplyType::Simple}); + if (!EnqueueResult.Success) + { + ZEN_ERROR("Error enqueuing upstream Action {}", WorkerId.ToHexString()); + return HttpResponseCode::InternalServerError; + } + + CbObjectWriter Writer; + Writer.AddHash("worker", WorkerId); + + Object = Writer.Save(); + return HttpResponseCode::OK; +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object) +{ + const static IoHash Empty = CbObject().GetHash(); + auto Status = m_UpstreamApply->GetStatus(WorkerId, Empty); + if (!Status.Success) + { + return HttpResponseCode::NotFound; + } + + if (Status.Status.State != UpstreamApplyState::Complete) + { + return HttpResponseCode::Accepted; + } + + GetUpstreamApplyResult& Completed = Status.Status.Result; + + if (!Completed.Success) + { + ZEN_ERROR("Action {} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}", + WorkerId.ToHexString(), + Completed.StdOut, + Completed.StdErr, + Completed.Error.Reason, + Completed.Error.ErrorCode); + + if (Completed.Error.ErrorCode == 0) + { + Completed.Error.ErrorCode = -1; + } + if (Completed.StdErr.empty() && !Completed.Error.Reason.empty()) + { + Completed.StdErr = Completed.Error.Reason; + } + } + else + { + ZEN_INFO("Action {} completed with {} files ExitCode={}", + WorkerId.ToHexString(), + Completed.OutputFiles.size(), + Completed.Error.ErrorCode); + } + + CbObjectWriter ResultObject; + + ResultObject.AddString("agent"sv, Completed.Agent); + ResultObject.AddString("detail"sv, Completed.Detail); + ResultObject.AddString("stdout"sv, Completed.StdOut); + ResultObject.AddString("stderr"sv, Completed.StdErr); + ResultObject.AddInteger("exitcode"sv, Completed.Error.ErrorCode); + ResultObject.BeginArray("stats"sv); + for (const auto& Timepoint : Completed.Timepoints) + { + ResultObject.BeginObject(); + ResultObject.AddString("name"sv, Timepoint.first); + ResultObject.AddDateTimeTicks("time"sv, Timepoint.second); + ResultObject.EndObject(); + } + ResultObject.EndArray(); + + ResultObject.BeginArray("files"sv); + for (const auto& File : Completed.OutputFiles) + { + ResultObject.BeginObject(); + ResultObject.AddString("name"sv, File.first.string()); + ResultObject.AddBinary("data"sv, Completed.FileData[File.second]); + ResultObject.EndObject(); + } + ResultObject.EndArray(); + + Object = ResultObject.Save(); + return HttpResponseCode::OK; +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject Action, CbObject& Object) +{ + const IoHash WorkerId = Worker.Descriptor.GetHash(); + const IoHash ActionId = Action.GetHash(); + + Action.MakeOwned(); + + ZEN_INFO("Action {}/{} being processed...", WorkerId.ToHexString(), ActionId.ToHexString()); + + auto EnqueueResult = m_UpstreamApply->EnqueueUpstream( + {.WorkerDescriptor = Worker.Descriptor, .Action = std::move(Action), .Type = UpstreamApplyType::Asset}); + + if (!EnqueueResult.Success) + { + ZEN_ERROR("Error enqueuing upstream Action {}/{}", WorkerId.ToHexString(), ActionId.ToHexString()); + return HttpResponseCode::InternalServerError; + } + + CbObjectWriter Writer; + Writer.AddHash("worker", WorkerId); + Writer.AddHash("action", ActionId); + + Object = Writer.Save(); + return HttpResponseCode::OK; +} + +HttpResponseCode +HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, const IoHash& ActionId, CbPackage& Package) +{ + auto Status = m_UpstreamApply->GetStatus(WorkerId, ActionId); + if (!Status.Success) + { + return HttpResponseCode::NotFound; + } + + if (Status.Status.State != UpstreamApplyState::Complete) + { + return HttpResponseCode::Accepted; + } + + GetUpstreamApplyResult& Completed = Status.Status.Result; + if (!Completed.Success || Completed.Error.ErrorCode != 0) + { + ZEN_ERROR("Action {}/{} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}", + WorkerId.ToHexString(), + ActionId.ToHexString(), + Completed.StdOut, + Completed.StdErr, + Completed.Error.Reason, + Completed.Error.ErrorCode); + + return HttpResponseCode::InternalServerError; + } + + ZEN_INFO("Action {}/{} completed with {} attachments ({} compressed, {} uncompressed)", + WorkerId.ToHexString(), + ActionId.ToHexString(), + Completed.OutputPackage.GetAttachments().size(), + NiceBytes(Completed.TotalAttachmentBytes), + NiceBytes(Completed.TotalRawAttachmentBytes)); + + Package = std::move(Completed.OutputPackage); + return HttpResponseCode::OK; +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/compute/function.h b/src/zenserver/compute/function.h new file mode 100644 index 000000000..650cee757 --- /dev/null +++ b/src/zenserver/compute/function.h @@ -0,0 +1,73 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#if !defined(ZEN_WITH_COMPUTE_SERVICES) +# define ZEN_WITH_COMPUTE_SERVICES 1 +#endif + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/iohash.h> +# include <zencore/logging.h> +# include <zenhttp/httpserver.h> + +# include <filesystem> +# include <unordered_map> + +namespace zen { + +class CidStore; +class UpstreamApply; +class CloudCacheClient; +class AuthMgr; + +struct UpstreamAuthConfig; +struct CloudCacheClientOptions; + +/** + * Lambda style compute function service + */ +class HttpFunctionService : public HttpService +{ +public: + HttpFunctionService(CidStore& InCidStore, + const CloudCacheClientOptions& ComputeOptions, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const UpstreamAuthConfig& StorageAuthConfig, + AuthMgr& Mgr); + ~HttpFunctionService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + +private: + std::thread InitializeThread; + spdlog::logger& Log() { return m_Log; } + spdlog::logger& m_Log; + HttpRequestRouter m_Router; + CidStore& m_CidStore; + std::unique_ptr<UpstreamApply> m_UpstreamApply; + + struct WorkerDesc + { + CbObject Descriptor; + }; + + [[nodiscard]] HttpResponseCode ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object); + [[nodiscard]] HttpResponseCode ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object); + + [[nodiscard]] HttpResponseCode ExecActionUpstream(const WorkerDesc& Worker, CbObject Action, CbObject& Object); + [[nodiscard]] HttpResponseCode ExecActionUpstreamResult(const IoHash& WorkerId, const IoHash& ActionId, CbPackage& Package); + + RwLock m_WorkerLock; + std::unordered_map<IoHash, WorkerDesc> m_WorkerMap; +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/config.cpp b/src/zenserver/config.cpp new file mode 100644 index 000000000..cff93d67b --- /dev/null +++ b/src/zenserver/config.cpp @@ -0,0 +1,902 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "config.h" + +#include "diag/logging.h" + +#include <zencore/crypto.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/string.h> +#include <zenhttp/zenhttp.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <zencore/logging.h> +#include <cxxopts.hpp> +#include <sol/sol.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# include <conio.h> +#else +# include <pwd.h> +#endif + +#if ZEN_PLATFORM_WINDOWS + +// Used for getting My Documents for default data directory +# include <ShlObj.h> +# pragma comment(lib, "shell32.lib") + +std::filesystem::path +PickDefaultStateDirectory() +{ + // Pick sensible default + PWSTR programDataDir = nullptr; + HRESULT hRes = SHGetKnownFolderPath(FOLDERID_ProgramData, 0, NULL, &programDataDir); + + if (SUCCEEDED(hRes)) + { + std::filesystem::path finalPath(programDataDir); + finalPath /= L"Epic\\Zen\\Data"; + ::CoTaskMemFree(programDataDir); + + return finalPath; + } + + return L""; +} + +#else + +std::filesystem::path +PickDefaultStateDirectory() +{ + int UserId = getuid(); + const passwd* Passwd = getpwuid(UserId); + return std::filesystem::path(Passwd->pw_dir) / ".zen"; +} + +#endif + +void +ValidateOptions(ZenServerOptions& ServerOptions) +{ + if (ServerOptions.EncryptionKey.empty() == false) + { + const auto Key = zen::AesKey256Bit::FromString(ServerOptions.EncryptionKey); + + if (Key.IsValid() == false) + { + throw cxxopts::OptionParseException("Invalid AES encryption key"); + } + } + + if (ServerOptions.EncryptionIV.empty() == false) + { + const auto IV = zen::AesIV128Bit::FromString(ServerOptions.EncryptionIV); + + if (IV.IsValid() == false) + { + throw cxxopts::OptionParseException("Invalid AES initialization vector"); + } + } +} + +UpstreamCachePolicy +ParseUpstreamCachePolicy(std::string_view Options) +{ + if (Options == "readonly") + { + return UpstreamCachePolicy::Read; + } + else if (Options == "writeonly") + { + return UpstreamCachePolicy::Write; + } + else if (Options == "disabled") + { + return UpstreamCachePolicy::Disabled; + } + else + { + return UpstreamCachePolicy::ReadWrite; + } +} + +ZenObjectStoreConfig +ParseBucketConfigs(std::span<std::string> Buckets) +{ + using namespace std::literals; + + ZenObjectStoreConfig Cfg; + + // split bucket args in the form of "{BucketName};{LocalPath}" + for (std::string_view Bucket : Buckets) + { + ZenObjectStoreConfig::BucketConfig NewBucket; + + if (auto Idx = Bucket.find_first_of(";"); Idx != std::string_view::npos) + { + NewBucket.Name = Bucket.substr(0, Idx); + NewBucket.Directory = Bucket.substr(Idx + 1); + } + else + { + NewBucket.Name = Bucket; + } + + Cfg.Buckets.push_back(std::move(NewBucket)); + } + + return Cfg; +} + +void ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions); + +void +ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions) +{ +#if ZEN_WITH_HTTPSYS + const char* DefaultHttp = "httpsys"; +#else + const char* DefaultHttp = "asio"; +#endif + + // Note to those adding future options; std::filesystem::path-type options + // must be read into a std::string first. As of cxxopts-3.0.0 it uses a >> + // stream operator to convert argv value into the options type. std::fs::path + // expects paths in streams to be quoted but argv paths are unquoted. By + // going into a std::string first, paths with whitespace parse correctly. + std::string DataDir; + std::string ContentDir; + std::string AbsLogFile; + std::string ConfigFile; + + cxxopts::Options options("zenserver", "Zen Server"); + options.add_options()("dedicated", + "Enable dedicated server mode", + cxxopts::value<bool>(ServerOptions.IsDedicated)->default_value("false")); + options.add_options()("d, debug", "Enable debugging", cxxopts::value<bool>(ServerOptions.IsDebug)->default_value("false")); + options.add_options()("help", "Show command line help"); + options.add_options()("t, test", "Enable test mode", cxxopts::value<bool>(ServerOptions.IsTest)->default_value("false")); + options.add_options()("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(ServerOptions.LogId)); + options.add_options()("data-dir", "Specify persistence root", cxxopts::value<std::string>(DataDir)); + options.add_options()("content-dir", "Frontend content directory", cxxopts::value<std::string>(ContentDir)); + options.add_options()("abslog", "Path to log file", cxxopts::value<std::string>(AbsLogFile)); + options.add_options()("config", "Path to Lua config file", cxxopts::value<std::string>(ConfigFile)); + options.add_options()("no-sentry", + "Disable Sentry crash handler", + cxxopts::value<bool>(ServerOptions.NoSentry)->default_value("false")); + + options.add_option("security", + "", + "encryption-aes-key", + "256 bit AES encryption key", + cxxopts::value<std::string>(ServerOptions.EncryptionKey), + ""); + + options.add_option("security", + "", + "encryption-aes-iv", + "128 bit AES encryption initialization vector", + cxxopts::value<std::string>(ServerOptions.EncryptionIV), + ""); + + std::string OpenIdProviderName; + options.add_option("security", + "", + "openid-provider-name", + "Open ID provider name", + cxxopts::value<std::string>(OpenIdProviderName), + "Default"); + + std::string OpenIdProviderUrl; + options.add_option("security", "", "openid-provider-url", "Open ID provider URL", cxxopts::value<std::string>(OpenIdProviderUrl), ""); + + std::string OpenIdClientId; + options.add_option("security", "", "openid-client-id", "Open ID client ID", cxxopts::value<std::string>(OpenIdClientId), ""); + + options + .add_option("lifetime", "", "owner-pid", "Specify owning process id", cxxopts::value<int>(ServerOptions.OwnerPid), "<identifier>"); + options.add_option("lifetime", + "", + "child-id", + "Specify id which can be used to signal parent", + cxxopts::value<std::string>(ServerOptions.ChildId), + "<identifier>"); + +#if ZEN_PLATFORM_WINDOWS + options.add_option("lifetime", + "", + "install", + "Install zenserver as a Windows service", + cxxopts::value<bool>(ServerOptions.InstallService), + ""); + options.add_option("lifetime", + "", + "uninstall", + "Uninstall zenserver as a Windows service", + cxxopts::value<bool>(ServerOptions.UninstallService), + ""); +#endif + + options.add_option("network", + "", + "http", + "Select HTTP server implementation (asio|httpsys|null)", + cxxopts::value<std::string>(ServerOptions.HttpServerClass)->default_value(DefaultHttp), + "<http class>"); + + options.add_option("network", + "p", + "port", + "Select HTTP port", + cxxopts::value<int>(ServerOptions.BasePort)->default_value("1337"), + "<port number>"); + + options.add_option("network", + "", + "websocket-port", + "Websocket server port", + cxxopts::value<int>(ServerOptions.WebSocketPort)->default_value("0"), + "<port number>"); + + options.add_option("network", + "", + "websocket-threads", + "Number of websocket I/O thread(s) (0 == hardware concurrency)", + cxxopts::value<int>(ServerOptions.WebSocketThreads)->default_value("0"), + ""); + +#if ZEN_WITH_TRACE + options.add_option("ue-trace", + "", + "tracehost", + "Hostname to send the trace to", + cxxopts::value<std::string>(ServerOptions.TraceHost)->default_value(""), + ""); + + options.add_option("ue-trace", + "", + "tracefile", + "Path to write a trace to", + cxxopts::value<std::string>(ServerOptions.TraceFile)->default_value(""), + ""); +#endif // ZEN_WITH_TRACE + + options.add_option("diagnostics", + "", + "crash", + "Simulate a crash", + cxxopts::value<bool>(ServerOptions.ShouldCrash)->default_value("false"), + ""); + + std::string UpstreamCachePolicyOptions; + options.add_option("cache", + "", + "upstream-cache-policy", + "", + cxxopts::value<std::string>(UpstreamCachePolicyOptions)->default_value(""), + "Upstream cache policy (readwrite|readonly|writeonly|disabled)"); + + options.add_option("cache", + "", + "upstream-jupiter-url", + "URL to a Jupiter instance", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Url)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-oauth-url", + "URL to the OAuth provier", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-oauth-clientid", + "The OAuth client ID", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-oauth-clientsecret", + "The OAuth client secret", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-openid-provider", + "Name of a registered Open ID provider", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-token", + "A static authentication token", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-namespace", + "The Common Blob Store API namespace", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-jupiter-namespace-ddc", + "The lecacy DDC namespace", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace)->default_value(""), + ""); + + options.add_option("cache", + "", + "upstream-zen-url", + "URL to remote Zen server. Use a comma separated list to choose the one with the best latency.", + cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Urls), + ""); + + options.add_option("cache", + "", + "upstream-zen-dns", + "DNS that resolves to one or more Zen server instance(s)", + cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Dns), + ""); + + options.add_option("cache", + "", + "upstream-thread-count", + "Number of threads used for upstream procsssing", + cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.UpstreamThreadCount)->default_value("4"), + ""); + + options.add_option("cache", + "", + "upstream-connect-timeout-ms", + "Connect timeout in millisecond(s). Default 5000 ms.", + cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.ConnectTimeoutMilliseconds)->default_value("5000"), + ""); + + options.add_option("cache", + "", + "upstream-timeout-ms", + "Timeout in millisecond(s). Default 0 ms", + cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.TimeoutMilliseconds)->default_value("0"), + ""); + + options.add_option("compute", + "", + "upstream-horde-url", + "URL to a Horde instance.", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Url)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-oauth-url", + "URL to the OAuth provier", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-oauth-clientid", + "The OAuth client ID", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-oauth-clientsecret", + "The OAuth client secret", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-openid-provider", + "Name of a registered Open ID provider", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-token", + "A static authentication token", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-url", + "URL to a Horde Storage instance.", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-oauth-url", + "URL to the OAuth provier", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-oauth-clientid", + "The OAuth client ID", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId)->default_value(""), + ""); + + options.add_option( + "compute", + "", + "upstream-horde-storage-oauth-clientsecret", + "The OAuth client secret", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-openid-provider", + "Name of a registered Open ID provider", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-storage-token", + "A static authentication token", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-cluster", + "The Horde compute cluster id", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster)->default_value(""), + ""); + + options.add_option("compute", + "", + "upstream-horde-namespace", + "The Jupiter namespace to use with Horde compute", + cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace)->default_value(""), + ""); + + options.add_option("gc", + "", + "gc-enabled", + "Whether garbage collection is enabled or not.", + cxxopts::value<bool>(ServerOptions.GcConfig.Enabled)->default_value("true"), + ""); + + options.add_option("gc", + "", + "gc-small-objects", + "Whether garbage collection of small objects is enabled or not.", + cxxopts::value<bool>(ServerOptions.GcConfig.CollectSmallObjects)->default_value("true"), + ""); + + options.add_option("gc", + "", + "gc-interval-seconds", + "Garbage collection interval in seconds. Default set to 3600 (1 hour).", + cxxopts::value<int32_t>(ServerOptions.GcConfig.IntervalSeconds)->default_value("3600"), + ""); + + options.add_option("gc", + "", + "gc-cache-duration-seconds", + "Max duration in seconds before Z$ entries get evicted. Default set to 1209600 (2 weeks)", + cxxopts::value<int32_t>(ServerOptions.GcConfig.Cache.MaxDurationSeconds)->default_value("1209600"), + ""); + + options.add_option("gc", + "", + "disk-reserve-size", + "Size of gc disk reserve in bytes. Default set to 268435456 (256 Mb).", + cxxopts::value<uint64_t>(ServerOptions.GcConfig.DiskReserveSize)->default_value("268435456"), + ""); + + options.add_option("gc", + "", + "gc-monitor-interval-seconds", + "Garbage collection monitoring interval in seconds. Default set to 30 (30 seconds)", + cxxopts::value<int32_t>(ServerOptions.GcConfig.MonitorIntervalSeconds)->default_value("30"), + ""); + + options.add_option("gc", + "", + "gc-disksize-softlimit", + "Garbage collection disk usage soft limit. Default set to 0 (Off).", + cxxopts::value<uint64_t>(ServerOptions.GcConfig.Cache.DiskSizeSoftLimit)->default_value("0"), + ""); + + options.add_option("objectstore", + "", + "objectstore-enabled", + "Whether the object store is enabled or not.", + cxxopts::value<bool>(ServerOptions.ObjectStoreEnabled)->default_value("false"), + ""); + + std::vector<std::string> BucketConfigs; + options.add_option("objectstore", + "", + "objectstore-bucket", + "Object store bucket mappings.", + cxxopts::value<std::vector<std::string>>(BucketConfigs), + ""); + + try + { + auto result = options.parse(argc, argv); + + if (result.count("help")) + { + zen::logging::ConsoleLog().info("{}", options.help()); +#if ZEN_PLATFORM_WINDOWS + zen::logging::ConsoleLog().info("Press any key to exit!"); + _getch(); +#else + // Assume the user's in a terminal on all other platforms and that + // they'll use less/more/etc. if need be. +#endif + exit(0); + } + + auto MakeSafePath = [](const std::string& Path) { +#if ZEN_PLATFORM_WINDOWS + if (Path.empty()) + { + return Path; + } + + std::string FixedPath = Path; + std::replace(FixedPath.begin(), FixedPath.end(), '/', '\\'); + if (!FixedPath.starts_with("\\\\?\\")) + { + FixedPath.insert(0, "\\\\?\\"); + } + return FixedPath; +#else + return Path; +#endif + }; + + ServerOptions.DataDir = MakeSafePath(DataDir); + ServerOptions.ContentDir = MakeSafePath(ContentDir); + ServerOptions.AbsLogFile = MakeSafePath(AbsLogFile); + ServerOptions.ConfigFile = MakeSafePath(ConfigFile); + ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(UpstreamCachePolicyOptions); + + if (OpenIdProviderUrl.empty() == false) + { + if (OpenIdClientId.empty()) + { + throw cxxopts::OptionParseException("Invalid OpenID client ID"); + } + + ServerOptions.AuthConfig.OpenIdProviders.push_back( + {.Name = OpenIdProviderName, .Url = OpenIdProviderUrl, .ClientId = OpenIdClientId}); + } + + ServerOptions.ObjectStoreConfig = ParseBucketConfigs(BucketConfigs); + + if (!ServerOptions.ConfigFile.empty()) + { + ParseConfigFile(ServerOptions.ConfigFile, ServerOptions); + } + else + { + ParseConfigFile(ServerOptions.DataDir / "zen_cfg.lua", ServerOptions); + } + + ValidateOptions(ServerOptions); + } + catch (cxxopts::OptionParseException& e) + { + zen::logging::ConsoleLog().error("Error parsing zenserver arguments: {}\n\n{}", e.what(), options.help()); + + throw; + } + + if (ServerOptions.DataDir.empty()) + { + ServerOptions.DataDir = PickDefaultStateDirectory(); + } + + if (ServerOptions.AbsLogFile.empty()) + { + ServerOptions.AbsLogFile = ServerOptions.DataDir / "logs" / "zenserver.log"; + } +} + +void +ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions) +{ + zen::IoBuffer LuaScript = zen::IoBufferBuilder::MakeFromFile(Path); + + if (LuaScript) + { + sol::state lua; + + lua.open_libraries(sol::lib::base); + + lua.set_function("getenv", [&](const std::string env) -> sol::object { +#if ZEN_PLATFORM_WINDOWS + std::wstring EnvVarValue; + size_t RequiredSize = 0; + std::wstring EnvWide = zen::Utf8ToWide(env); + _wgetenv_s(&RequiredSize, nullptr, 0, EnvWide.c_str()); + + if (RequiredSize == 0) + return sol::make_object(lua, sol::lua_nil); + + EnvVarValue.resize(RequiredSize); + _wgetenv_s(&RequiredSize, EnvVarValue.data(), RequiredSize, EnvWide.c_str()); + return sol::make_object(lua, zen::WideToUtf8(EnvVarValue.c_str())); +#else + ZEN_UNUSED(env); + return sol::make_object(lua, sol::lua_nil); +#endif + }); + + try + { + sol::load_result config = lua.load(std::string_view((const char*)LuaScript.Data(), LuaScript.Size()), "zen_cfg"); + + if (!config.valid()) + { + sol::error err = config; + + std::string ErrorString = sol::to_string(config.status()); + + throw std::runtime_error(fmt::format("{} error: {}", ErrorString, err.what())); + } + + config(); + } + catch (std::exception& e) + { + throw std::runtime_error(fmt::format("failed to load config script ('{}'): {}", Path, e.what()).c_str()); + } + + if (sol::optional<sol::table> ServerConfig = lua["server"]) + { + if (ServerOptions.DataDir.empty()) + { + if (sol::optional<std::string> Opt = ServerConfig.value()["datadir"]) + { + ServerOptions.DataDir = Opt.value(); + } + } + + if (ServerOptions.ContentDir.empty()) + { + if (sol::optional<std::string> Opt = ServerConfig.value()["contentdir"]) + { + ServerOptions.ContentDir = Opt.value(); + } + } + + if (ServerOptions.AbsLogFile.empty()) + { + if (sol::optional<std::string> Opt = ServerConfig.value()["abslog"]) + { + ServerOptions.AbsLogFile = Opt.value(); + } + } + + ServerOptions.IsDebug = ServerConfig->get_or("debug", ServerOptions.IsDebug); + } + + if (sol::optional<sol::table> NetworkConfig = lua["network"]) + { + if (sol::optional<std::string> Opt = NetworkConfig.value()["httpserverclass"]) + { + ServerOptions.HttpServerClass = Opt.value(); + } + + ServerOptions.BasePort = NetworkConfig->get_or<int>("port", ServerOptions.BasePort); + } + + auto UpdateStringValueFromConfig = [](const sol::table& Table, std::string_view Key, std::string& OutValue) { + // Update the specified config value unless it has been set, i.e. from command line + if (auto MaybeValue = Table.get<sol::optional<std::string>>(Key); MaybeValue.has_value() && OutValue.empty()) + { + OutValue = MaybeValue.value(); + } + }; + + if (sol::optional<sol::table> StructuredCacheConfig = lua["cache"]) + { + ServerOptions.StructuredCacheEnabled = StructuredCacheConfig->get_or("enable", ServerOptions.StructuredCacheEnabled); + + if (auto UpstreamConfig = StructuredCacheConfig->get<sol::optional<sol::table>>("upstream")) + { + std::string Policy = UpstreamConfig->get_or("policy", std::string()); + ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(Policy); + ServerOptions.UpstreamCacheConfig.UpstreamThreadCount = + UpstreamConfig->get_or("upstreamthreadcount", ServerOptions.UpstreamCacheConfig.UpstreamThreadCount); + + if (auto JupiterConfig = UpstreamConfig->get<sol::optional<sol::table>>("jupiter")) + { + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("name"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.Name); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("url"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.Url); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("oauthprovider"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("oauthclientid"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("oauthclientsecret"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("openidprovider"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("token"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("namespace"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace); + UpdateStringValueFromConfig(JupiterConfig.value(), + std::string_view("ddcnamespace"), + ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace); + }; + + if (auto ZenConfig = UpstreamConfig->get<sol::optional<sol::table>>("zen")) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Name = ZenConfig.value().get_or("name", std::string("Zen")); + + if (auto Url = ZenConfig.value().get<sol::optional<std::string>>("url")) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Url.value()); + } + else if (auto Urls = ZenConfig.value().get<sol::optional<sol::table>>("url")) + { + for (const auto& Kv : Urls.value()) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Kv.second.as<std::string>()); + } + } + + if (auto Dns = ZenConfig.value().get<sol::optional<std::string>>("dns")) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Dns.value()); + } + else if (auto DnsArray = ZenConfig.value().get<sol::optional<sol::table>>("dns")) + { + for (const auto& Kv : DnsArray.value()) + { + ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Kv.second.as<std::string>()); + } + } + } + } + } + + if (sol::optional<sol::table> ExecConfig = lua["exec"]) + { + ServerOptions.ExecServiceEnabled = ExecConfig->get_or("enable", ServerOptions.ExecServiceEnabled); + } + + if (sol::optional<sol::table> ComputeConfig = lua["compute"]) + { + ServerOptions.ComputeServiceEnabled = ComputeConfig->get_or("enable", ServerOptions.ComputeServiceEnabled); + + if (auto UpstreamConfig = ComputeConfig->get<sol::optional<sol::table>>("upstream")) + { + if (auto HordeConfig = UpstreamConfig->get<sol::optional<sol::table>>("horde")) + { + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("name"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Name); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("url"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Url); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("oauthprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("oauthclientid"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("oauthclientsecret"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("openidprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("token"), + ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("cluster"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster); + UpdateStringValueFromConfig(HordeConfig.value(), + std::string_view("namespace"), + ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace); + }; + + if (auto StorageConfig = UpstreamConfig->get<sol::optional<sol::table>>("storage")) + { + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("url"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("oauthprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("oauthclientid"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("oauthclientsecret"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("openidprovider"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider); + UpdateStringValueFromConfig(StorageConfig.value(), + std::string_view("token"), + ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken); + }; + } + } + + if (sol::optional<sol::table> GcConfig = lua["gc"]) + { + ServerOptions.GcConfig.MonitorIntervalSeconds = GcConfig.value().get_or("monitorintervalseconds", 30); + ServerOptions.GcConfig.IntervalSeconds = GcConfig.value().get_or("intervalseconds", 0); + ServerOptions.GcConfig.DiskReserveSize = GcConfig.value().get_or("diskreservesize", uint64_t(1u << 28)); + + if (sol::optional<sol::table> CacheGcConfig = GcConfig.value()["cache"]) + { + ServerOptions.GcConfig.Cache.MaxDurationSeconds = CacheGcConfig.value().get_or("maxdurationseconds", int32_t(0)); + ServerOptions.GcConfig.Cache.DiskSizeLimit = CacheGcConfig.value().get_or("disksizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cache.MemorySizeLimit = CacheGcConfig.value().get_or("memorysizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cache.DiskSizeSoftLimit = CacheGcConfig.value().get_or("disksizesoftlimit", 0); + } + + if (sol::optional<sol::table> CasGcConfig = GcConfig.value()["cas"]) + { + ServerOptions.GcConfig.Cas.LargeStrategySizeLimit = CasGcConfig.value().get_or("largestrategysizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cas.SmallStrategySizeLimit = CasGcConfig.value().get_or("smallstrategysizelimit", ~uint64_t(0)); + ServerOptions.GcConfig.Cas.TinyStrategySizeLimit = CasGcConfig.value().get_or("tinystrategysizelimit", ~uint64_t(0)); + } + } + + if (sol::optional<sol::table> SecurityConfig = lua["security"]) + { + if (sol::optional<sol::table> OpenIdProviders = SecurityConfig.value()["openidproviders"]) + { + for (const auto& Kv : OpenIdProviders.value()) + { + if (sol::optional<sol::table> OpenIdProvider = Kv.second.as<sol::table>()) + { + std::string Name = OpenIdProvider.value().get_or("name", std::string("Default")); + std::string Url = OpenIdProvider.value().get_or("url", std::string()); + std::string ClientId = OpenIdProvider.value().get_or("clientid", std::string()); + + ServerOptions.AuthConfig.OpenIdProviders.push_back( + {.Name = std::move(Name), .Url = std::move(Url), .ClientId = std::move(ClientId)}); + } + } + } + + ServerOptions.EncryptionKey = SecurityConfig.value().get_or("encryptionaeskey", std::string()); + ServerOptions.EncryptionIV = SecurityConfig.value().get_or("encryptionaesiv", std::string()); + } + } +} diff --git a/src/zenserver/config.h b/src/zenserver/config.h new file mode 100644 index 000000000..8a5c6de4e --- /dev/null +++ b/src/zenserver/config.h @@ -0,0 +1,158 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> +#include <filesystem> +#include <string> +#include <vector> + +struct ZenUpstreamJupiterConfig +{ + std::string Name; + std::string Url; + std::string OAuthUrl; + std::string OAuthClientId; + std::string OAuthClientSecret; + std::string OpenIdProvider; + std::string AccessToken; + std::string Namespace; + std::string DdcNamespace; +}; + +struct ZenUpstreamHordeConfig +{ + std::string Name; + std::string Url; + std::string OAuthUrl; + std::string OAuthClientId; + std::string OAuthClientSecret; + std::string OpenIdProvider; + std::string AccessToken; + + std::string StorageUrl; + std::string StorageOAuthUrl; + std::string StorageOAuthClientId; + std::string StorageOAuthClientSecret; + std::string StorageOpenIdProvider; + std::string StorageAccessToken; + + std::string Cluster; + std::string Namespace; +}; + +struct ZenUpstreamZenConfig +{ + std::string Name; + std::vector<std::string> Urls; + std::vector<std::string> Dns; +}; + +enum class UpstreamCachePolicy : uint8_t +{ + Disabled = 0, + Read = 1 << 0, + Write = 1 << 1, + ReadWrite = Read | Write +}; + +struct ZenUpstreamCacheConfig +{ + ZenUpstreamJupiterConfig JupiterConfig; + ZenUpstreamHordeConfig HordeConfig; + ZenUpstreamZenConfig ZenConfig; + int32_t UpstreamThreadCount = 4; + int32_t ConnectTimeoutMilliseconds = 5000; + int32_t TimeoutMilliseconds = 0; + UpstreamCachePolicy CachePolicy = UpstreamCachePolicy::ReadWrite; +}; + +struct ZenCacheEvictionPolicy +{ + uint64_t DiskSizeLimit = ~uint64_t(0); + uint64_t MemorySizeLimit = 1024 * 1024 * 1024; + int32_t MaxDurationSeconds = 24 * 60 * 60; + uint64_t DiskSizeSoftLimit = 0; + bool Enabled = true; +}; + +struct ZenCasEvictionPolicy +{ + uint64_t LargeStrategySizeLimit = ~uint64_t(0); + uint64_t SmallStrategySizeLimit = ~uint64_t(0); + uint64_t TinyStrategySizeLimit = ~uint64_t(0); + bool Enabled = true; +}; + +struct ZenGcConfig +{ + ZenCasEvictionPolicy Cas; + ZenCacheEvictionPolicy Cache; + int32_t MonitorIntervalSeconds = 30; + int32_t IntervalSeconds = 0; + bool CollectSmallObjects = true; + bool Enabled = true; + uint64_t DiskReserveSize = 1ul << 28; +}; + +struct ZenOpenIdProviderConfig +{ + std::string Name; + std::string Url; + std::string ClientId; +}; + +struct ZenAuthConfig +{ + std::vector<ZenOpenIdProviderConfig> OpenIdProviders; +}; + +struct ZenObjectStoreConfig +{ + struct BucketConfig + { + std::string Name; + std::filesystem::path Directory; + }; + + std::vector<BucketConfig> Buckets; +}; + +struct ZenServerOptions +{ + ZenUpstreamCacheConfig UpstreamCacheConfig; + ZenGcConfig GcConfig; + ZenAuthConfig AuthConfig; + ZenObjectStoreConfig ObjectStoreConfig; + std::filesystem::path DataDir; // Root directory for state (used for testing) + std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental) + std::filesystem::path AbsLogFile; // Absolute path to main log file + std::filesystem::path ConfigFile; // Path to Lua config file + std::string ChildId; // Id assigned by parent process (used for lifetime management) + std::string LogId; // Id for tagging log output + std::string HttpServerClass; // Choice of HTTP server implementation + std::string EncryptionKey; // 256 bit AES encryption key + std::string EncryptionIV; // 128 bit AES initialization vector + int BasePort = 1337; // Service listen port (used for both UDP and TCP) + int OwnerPid = 0; // Parent process id (zero for standalone) + int WebSocketPort = 0; // Web socket port (Zero = disabled) + int WebSocketThreads = 0; + bool InstallService = false; // Flag used to initiate service install (temporary) + bool UninstallService = false; // Flag used to initiate service uninstall (temporary) + bool IsDebug = false; + bool IsTest = false; + bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements + bool StructuredCacheEnabled = true; + bool ExecServiceEnabled = true; + bool ComputeServiceEnabled = true; + bool ShouldCrash = false; // Option for testing crash handling + bool IsFirstRun = false; + bool NoSentry = false; + bool ObjectStoreEnabled = false; +#if ZEN_WITH_TRACE + std::string TraceHost; // Host name or IP address to send trace data to + std::string TraceFile; // Path of a file to write a trace +#endif +}; + +void ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions); diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp new file mode 100644 index 000000000..29ad5c3dd --- /dev/null +++ b/src/zenserver/diag/diagsvcs.cpp @@ -0,0 +1,127 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "diagsvcs.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/config.h> +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <fstream> +#include <sstream> + +#include <json11.hpp> + +namespace zen { + +using namespace std::literals; + +bool +ReadFile(const std::string& Path, StringBuilderBase& Out) +{ + try + { + constexpr auto ReadSize = std::size_t{4096}; + auto FileStream = std::ifstream{Path}; + + std::string Buf(ReadSize, '\0'); + while (FileStream.read(&Buf[0], ReadSize)) + { + Out.Append(std::string_view(&Buf[0], FileStream.gcount())); + } + Out.Append(std::string_view(&Buf[0], FileStream.gcount())); + + return true; + } + catch (std::exception&) + { + Out.Reset(); + return false; + } +} + +HttpHealthService::HttpHealthService() +{ + m_Router.RegisterRoute( + "", + [](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "info", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + + CbObjectWriter Writer; + Writer << "DataRoot"sv << m_HealthInfo.DataRoot.string(); + Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string(); + Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion; + Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass; + + HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "log", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + + zen::Log().flush(); + + std::filesystem::path Path = + m_HealthInfo.AbsLogPath.empty() ? m_HealthInfo.DataRoot / "logs/zenserver.log" : m_HealthInfo.AbsLogPath; + + ExtendableStringBuilder<4096> Sb; + if (ReadFile(Path.string(), Sb) && Sb.Size() > 0) + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Sb.ToView()); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + }, + HttpVerb::kGet); + m_Router.RegisterRoute( + "version", + [this](HttpRouterRequest& RoutedReq) { + HttpServerRequest& HttpReq = RoutedReq.ServerRequest(); + if (HttpReq.GetQueryParams().GetValue("detailed") == "true") + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION_BUILD_STRING_FULL); + } + else + { + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION); + } + }, + HttpVerb::kGet); +} + +void +HttpHealthService::SetHealthInfo(HealthServiceInfo&& Info) +{ + m_HealthInfo = std::move(Info); +} + +const char* +HttpHealthService::BaseUri() const +{ + return "/health/"; +} + +void +HttpHealthService::HandleRequest(HttpServerRequest& Request) +{ + if (!m_Router.HandleRequest(Request)) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv); + } +} + +} // namespace zen diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h new file mode 100644 index 000000000..bd03f8023 --- /dev/null +++ b/src/zenserver/diag/diagsvcs.h @@ -0,0 +1,111 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> +#include <zenhttp/httpserver.h> + +#include <filesystem> + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +class HttpTestService : public HttpService +{ + uint32_t LogPoint = 0; + +public: + HttpTestService() {} + ~HttpTestService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + + virtual void HandleRequest(HttpServerRequest& Request) override + { + using namespace std::literals; + + auto Uri = Request.RelativeUri(); + + if (Uri == "hello"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"hello world!"sv); + + // OutputLogMessageInternal(&LogPoint, 0, 0); + } + else if (Uri == "1K"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1k); + } + else if (Uri == "1M"sv) + { + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1m); + } + else if (Uri == "1M_1k"sv) + { + std::vector<IoBuffer> Buffers; + Buffers.reserve(1024); + + for (int i = 0; i < 1024; ++i) + { + Buffers.push_back(m_1k); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + else if (Uri == "1G"sv) + { + std::vector<IoBuffer> Buffers; + Buffers.reserve(1024); + + for (int i = 0; i < 1024; ++i) + { + Buffers.push_back(m_1m); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + else if (Uri == "1G_1k"sv) + { + std::vector<IoBuffer> Buffers; + Buffers.reserve(1024 * 1024); + + for (int i = 0; i < 1024 * 1024; ++i) + { + Buffers.push_back(m_1k); + } + + Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers); + } + } + +private: + IoBuffer m_1m{1024 * 1024}; + IoBuffer m_1k{m_1m, 0u, 1024}; +}; + +struct HealthServiceInfo +{ + std::filesystem::path DataRoot; + std::filesystem::path AbsLogPath; + std::string HttpServerClass; + std::string BuildVersion; +}; + +class HttpHealthService : public HttpService +{ +public: + HttpHealthService(); + ~HttpHealthService() = default; + + void SetHealthInfo(HealthServiceInfo&& Info); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override final; + +private: + HttpRequestRouter m_Router; + HealthServiceInfo m_HealthInfo; +}; + +} // namespace zen diff --git a/src/zenserver/diag/formatters.h b/src/zenserver/diag/formatters.h new file mode 100644 index 000000000..759df58d3 --- /dev/null +++ b/src/zenserver/diag/formatters.h @@ -0,0 +1,71 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/iobuffer.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +template<> +struct fmt::formatter<cpr::Response> +{ + constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); } + + template<typename FormatContext> + auto format(const cpr::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out()) + { + using namespace std::literals; + + if (Response.status_code == 200 || Response.status_code == 201) + { + return fmt::format_to(Ctx.out(), + "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s", + Response.url.str(), + Response.status_code, + Response.uploaded_bytes, + Response.downloaded_bytes, + Response.elapsed); + } + else + { + const auto It = Response.header.find("Content-Type"); + const std::string_view ContentType = It != Response.header.end() ? It->second : "<None>"sv; + + if (ContentType == "application/x-ue-cb"sv) + { + zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + zen::CbObjectView Obj(Body.Data()); + zen::ExtendableStringBuilder<256> Sb; + std::string_view Json = Obj.ToJson(Sb).ToView(); + + return fmt::format_to(Ctx.out(), + "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Response: '{}', Reason: '{}'", + Response.url.str(), + Response.status_code, + Response.uploaded_bytes, + Response.downloaded_bytes, + Response.elapsed, + Json, + Response.reason); + } + else + { + return fmt::format_to(Ctx.out(), + "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Reponse: '{}', Reason: '{}'", + Response.url.str(), + Response.status_code, + Response.uploaded_bytes, + Response.downloaded_bytes, + Response.elapsed, + Response.text, + Response.reason); + } + } + } +}; diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp new file mode 100644 index 000000000..24c7572f4 --- /dev/null +++ b/src/zenserver/diag/logging.cpp @@ -0,0 +1,467 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "logging.h" + +#include "config.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <spdlog/async.h> +#include <spdlog/async_logger.h> +#include <spdlog/pattern_formatter.h> +#include <spdlog/sinks/ansicolor_sink.h> +#include <spdlog/sinks/basic_file_sink.h> +#include <spdlog/sinks/daily_file_sink.h> +#include <spdlog/sinks/msvc_sink.h> +#include <spdlog/sinks/rotating_file_sink.h> +#include <spdlog/sinks/stdout_color_sinks.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <zencore/compactbinary.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> + +#include <chrono> +#include <memory> + +// Custom logging -- test code, this should be tweaked + +namespace logging { + +using namespace spdlog; +using namespace spdlog::details; +using namespace std::literals; + +class full_formatter final : public spdlog::formatter +{ +public: + full_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId) {} + + virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<full_formatter>(m_LogId, m_Epoch); } + + static constexpr bool UseDate = false; + + virtual void format(const details::log_msg& msg, memory_buf_t& dest) override + { + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::seconds; + + if constexpr (UseDate) + { + auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch()); + if (secs != m_LastLogSecs) + { + m_CachedTm = os::localtime(log_clock::to_time_t(msg.time)); + m_LastLogSecs = secs; + } + } + + const auto& tm_time = m_CachedTm; + + // cache the date/time part for the next second. + auto duration = msg.time - m_Epoch; + auto secs = duration_cast<seconds>(duration); + + if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0) + { + m_CachedDatetime.clear(); + m_CachedDatetime.push_back('['); + + if constexpr (UseDate) + { + fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); + m_CachedDatetime.push_back(' '); + + fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); + } + else + { + int Count = int(secs.count()); + + const int LogSecs = Count % 60; + Count /= 60; + + const int LogMins = Count % 60; + Count /= 60; + + const int LogHours = Count; + + fmt_helper::pad2(LogHours, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + fmt_helper::pad2(LogMins, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + fmt_helper::pad2(LogSecs, m_CachedDatetime); + } + + m_CachedDatetime.push_back('.'); + + m_CacheTimestamp = secs; + } + + dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); + + auto millis = fmt_helper::time_fraction<milliseconds>(msg.time); + fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest); + dest.push_back(']'); + dest.push_back(' '); + + if (!m_LogId.empty()) + { + dest.push_back('['); + fmt_helper::append_string_view(m_LogId, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + // append logger name if exists + if (msg.logger_name.size() > 0) + { + dest.push_back('['); + fmt_helper::append_string_view(msg.logger_name, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + dest.push_back('['); + // wrap the level name with color + msg.color_range_start = dest.size(); + fmt_helper::append_string_view(level::to_string_view(msg.level), dest); + msg.color_range_end = dest.size(); + dest.push_back(']'); + dest.push_back(' '); + + // add source location if present + if (!msg.source.empty()) + { + dest.push_back('['); + const char* filename = details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename); + fmt_helper::append_string_view(filename, dest); + dest.push_back(':'); + fmt_helper::append_int(msg.source.line, dest); + dest.push_back(']'); + dest.push_back(' '); + } + + fmt_helper::append_string_view(msg.payload, dest); + fmt_helper::append_string_view("\n"sv, dest); + } + +private: + std::chrono::time_point<std::chrono::system_clock> m_Epoch; + std::tm m_CachedTm; + std::chrono::seconds m_LastLogSecs; + std::chrono::seconds m_CacheTimestamp{0}; + memory_buf_t m_CachedDatetime; + std::string m_LogId; +}; + +class json_formatter final : public spdlog::formatter +{ +public: + json_formatter(std::string_view LogId) : m_LogId(LogId) {} + + virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<json_formatter>(m_LogId); } + + virtual void format(const details::log_msg& msg, memory_buf_t& dest) override + { + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::seconds; + + auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch()); + if (secs != m_LastLogSecs) + { + m_CachedTm = os::localtime(log_clock::to_time_t(msg.time)); + m_LastLogSecs = secs; + } + + const auto& tm_time = m_CachedTm; + + // cache the date/time part for the next second. + + if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0) + { + m_CachedDatetime.clear(); + + fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime); + m_CachedDatetime.push_back('-'); + + fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime); + m_CachedDatetime.push_back(' '); + + fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime); + m_CachedDatetime.push_back(':'); + + fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime); + + m_CachedDatetime.push_back('.'); + + m_CacheTimestamp = secs; + } + dest.append("{"sv); + dest.append("\"time\": \""sv); + dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end()); + auto millis = fmt_helper::time_fraction<milliseconds>(msg.time); + fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest); + dest.append("\", "sv); + + dest.append("\"status\": \""sv); + dest.append(level::to_string_view(msg.level)); + dest.append("\", "sv); + + dest.append("\"source\": \""sv); + dest.append("zenserver"sv); + dest.append("\", "sv); + + dest.append("\"service\": \""sv); + dest.append("zencache"sv); + dest.append("\", "sv); + + if (!m_LogId.empty()) + { + dest.append("\"id\": \""sv); + dest.append(m_LogId); + dest.append("\", "sv); + } + + if (msg.logger_name.size() > 0) + { + dest.append("\"logger.name\": \""sv); + dest.append(msg.logger_name); + dest.append("\", "sv); + } + + if (msg.thread_id != 0) + { + dest.append("\"logger.thread_name\": \""sv); + fmt_helper::pad_uint(msg.thread_id, 0, dest); + dest.append("\", "sv); + } + + if (!msg.source.empty()) + { + dest.append("\"file\": \""sv); + WriteEscapedString(dest, details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename)); + dest.append("\","sv); + + dest.append("\"line\": \""sv); + dest.append(fmt::format("{}", msg.source.line)); + dest.append("\","sv); + + dest.append("\"logger.method_name\": \""sv); + WriteEscapedString(dest, msg.source.funcname); + dest.append("\", "sv); + } + + dest.append("\"message\": \""sv); + WriteEscapedString(dest, msg.payload); + dest.append("\""sv); + + dest.append("}\n"sv); + } + +private: + static inline const std::unordered_map<char, std::string_view> SpecialCharacterMap{{'\b', "\\b"sv}, + {'\f', "\\f"sv}, + {'\n', "\\n"sv}, + {'\r', "\\r"sv}, + {'\t', "\\t"sv}, + {'"', "\\\""sv}, + {'\\', "\\\\"sv}}; + + static void WriteEscapedString(memory_buf_t& dest, const spdlog::string_view_t& payload) + { + const char* RangeStart = payload.begin(); + for (const char* It = RangeStart; It != payload.end(); ++It) + { + if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end()) + { + if (RangeStart != It) + { + dest.append(RangeStart, It); + } + dest.append(SpecialIt->second); + RangeStart = It + 1; + } + } + if (RangeStart != payload.end()) + { + dest.append(RangeStart, payload.end()); + } + }; + + std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::chrono::seconds m_LastLogSecs{0}; + std::chrono::seconds m_CacheTimestamp{0}; + memory_buf_t m_CachedDatetime; + std::string m_LogId; +}; + +bool +EnableVTMode() +{ +#if ZEN_PLATFORM_WINDOWS + // Set output mode to handle virtual terminal sequences + HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); + if (hOut == INVALID_HANDLE_VALUE) + { + return false; + } + + DWORD dwMode = 0; + if (!GetConsoleMode(hOut, &dwMode)) + { + return false; + } + + dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if (!SetConsoleMode(hOut, dwMode)) + { + return false; + } +#endif + + return true; +} + +} // namespace logging + +void +InitializeLogging(const ZenServerOptions& GlobalOptions) +{ + zen::logging::InitializeLogging(); + logging::EnableVTMode(); + + bool IsAsync = true; + spdlog::level::level_enum LogLevel = spdlog::level::info; + + if (GlobalOptions.IsDebug) + { + LogLevel = spdlog::level::debug; + IsAsync = false; + } + + if (GlobalOptions.IsTest) + { + LogLevel = spdlog::level::trace; + IsAsync = false; + } + + if (IsAsync) + { + const int QueueSize = 8192; + const int ThreadCount = 1; + spdlog::init_thread_pool(QueueSize, ThreadCount); + + auto AsyncLogger = spdlog::create_async<spdlog::sinks::ansicolor_stdout_sink_mt>("main"); + zen::logging::SetDefault(AsyncLogger); + } + + // Sinks + + auto ConsoleSink = std::make_shared<spdlog::sinks::ansicolor_stdout_sink_mt>(); + + // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance + zen::CreateDirectories(GlobalOptions.AbsLogFile.parent_path()); + +#if 0 + auto FileSink = std::make_shared<spdlog::sinks::daily_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile), + 0, + 0, + /* truncate */ false, + uint16_t(/* max files */ 14)); +#else + auto FileSink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile), + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true); +#endif + + std::set_terminate([]() { ZEN_CRITICAL("Program exited abnormally via std::terminate()"); }); + + // Default + + auto& DefaultLogger = zen::logging::Default(); + auto& Sinks = DefaultLogger.sinks(); + + Sinks.clear(); + Sinks.push_back(ConsoleSink); + Sinks.push_back(FileSink); + +#if ZEN_PLATFORM_WINDOWS + if (zen::IsDebuggerPresent() && GlobalOptions.IsDebug) + { + auto DebugSink = std::make_shared<spdlog::sinks::msvc_sink_mt>(); + DebugSink->set_level(spdlog::level::debug); + Sinks.push_back(DebugSink); + } +#endif + + // HTTP server request logging + + std::filesystem::path HttpLogPath = GlobalOptions.DataDir / "logs" / "http.log"; + + // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance + zen::CreateDirectories(HttpLogPath.parent_path()); + + auto HttpSink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(HttpLogPath), + /* max size */ 128 * 1024 * 1024, + /* max files */ 16, + /* rotate on open */ true); + + auto HttpLogger = std::make_shared<spdlog::logger>("http_requests", HttpSink); + spdlog::register_logger(HttpLogger); + + // Jupiter - only log upstream HTTP traffic to file + + auto JupiterLogger = std::make_shared<spdlog::logger>("jupiter", FileSink); + spdlog::register_logger(JupiterLogger); + + // Zen - only log upstream HTTP traffic to file + + auto ZenClientLogger = std::make_shared<spdlog::logger>("zenclient", FileSink); + spdlog::register_logger(ZenClientLogger); + + // Configure all registered loggers according to settings + + spdlog::set_level(LogLevel); + spdlog::flush_on(spdlog::level::err); + spdlog::flush_every(std::chrono::seconds{2}); + spdlog::set_formatter(std::make_unique<logging::full_formatter>(GlobalOptions.LogId, std::chrono::system_clock::now())); + + if (GlobalOptions.AbsLogFile.extension() == ".json") + { + FileSink->set_formatter(std::make_unique<logging::json_formatter>(GlobalOptions.LogId)); + } + else + { + FileSink->set_pattern("[%C-%m-%d.%e %T] [%n] [%l] %v"); + } + DefaultLogger.info("log starting at {}", zen::DateTime::Now().ToIso8601()); +} + +void +ShutdownLogging() +{ + auto& DefaultLogger = zen::logging::Default(); + DefaultLogger.info("log ending at {}", zen::DateTime::Now().ToIso8601()); + zen::logging::ShutdownLogging(); +} diff --git a/src/zenserver/diag/logging.h b/src/zenserver/diag/logging.h new file mode 100644 index 000000000..8df49f842 --- /dev/null +++ b/src/zenserver/diag/logging.h @@ -0,0 +1,10 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging.h> +struct ZenServerOptions; + +void InitializeLogging(const ZenServerOptions& GlobalOptions); + +void ShutdownLogging(); diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp new file mode 100644 index 000000000..149d97924 --- /dev/null +++ b/src/zenserver/frontend/frontend.cpp @@ -0,0 +1,128 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "frontend.h" + +#include <zencore/endian.h> +#include <zencore/filesystem.h> +#include <zencore/string.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#if ZEN_PLATFORM_WINDOWS +# include <Windows.h> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +//////////////////////////////////////////////////////////////////////////////// +HttpFrontendService::HttpFrontendService(std::filesystem::path Directory) : m_Directory(Directory) +{ + std::filesystem::path SelfPath = GetRunningExecutablePath(); + + // Locate a .zip file appended onto the end of this binary + IoBuffer SelfBuffer = IoBufferBuilder::MakeFromFile(SelfPath); + m_ZipFs = ZipFs(std::move(SelfBuffer)); + +#if ZEN_BUILD_DEBUG + if (!Directory.empty()) + { + return; + } + + std::error_code ErrorCode; + auto Path = SelfPath; + while (Path.has_parent_path()) + { + auto ParentPath = Path.parent_path(); + if (ParentPath == Path) + { + break; + } + if (std::filesystem::is_regular_file(ParentPath / "xmake.lua", ErrorCode)) + { + if (ErrorCode) + { + break; + } + + auto HtmlDir = ParentPath / "zenserver" / "frontend" / "html"; + if (std::filesystem::is_directory(HtmlDir, ErrorCode)) + { + m_Directory = HtmlDir; + } + break; + } + Path = ParentPath; + }; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// +HttpFrontendService::~HttpFrontendService() +{ +} + +//////////////////////////////////////////////////////////////////////////////// +const char* +HttpFrontendService::BaseUri() const +{ + return "/dashboard"; // in order to use the root path we need to remove HttpAddUrlToUrlGroup in HttpSys.cpp +} + +//////////////////////////////////////////////////////////////////////////////// +void +HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request) +{ + using namespace std::literals; + + std::string_view Uri = Request.RelativeUriWithExtension(); + for (; Uri[0] == '/'; Uri = Uri.substr(1)) + ; + if (Uri.empty()) + { + Uri = "index.html"sv; + } + + // Dismiss if the URI contains .. anywhere to prevent arbitrary file reads + if (Uri.find("..") != Uri.npos) + { + return Request.WriteResponse(HttpResponseCode::Forbidden); + } + + // Map the file extension to a MIME type. To keep things constrained, only a + // small subset of file extensions is allowed + + HttpContentType ContentType = HttpContentType::kUnknownContentType; + + if (const size_t DotIndex = Uri.rfind("."); DotIndex != Uri.npos) + { + const std::string_view DotExt = Uri.substr(DotIndex + 1); + + ContentType = ParseContentType(DotExt); + } + + if (ContentType == HttpContentType::kUnknownContentType) + { + return Request.WriteResponse(HttpResponseCode::Forbidden); + } + + // The given content directory overrides any zip-fs discovered in the binary + if (!m_Directory.empty()) + { + FileContents File = ReadFile(m_Directory / Uri); + + if (!File.ErrorCode) + { + return Request.WriteResponse(HttpResponseCode::OK, ContentType, File.Data[0]); + } + } + + if (IoBuffer FileBuffer = m_ZipFs.GetFile(Uri)) + { + return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer); + } + + Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); +} + +} // namespace zen diff --git a/src/zenserver/frontend/frontend.h b/src/zenserver/frontend/frontend.h new file mode 100644 index 000000000..6eac20620 --- /dev/null +++ b/src/zenserver/frontend/frontend.h @@ -0,0 +1,25 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> +#include "zipfs.h" + +#include <filesystem> + +namespace zen { + +class HttpFrontendService final : public zen::HttpService +{ +public: + HttpFrontendService(std::filesystem::path Directory); + virtual ~HttpFrontendService(); + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + ZipFs m_ZipFs; + std::filesystem::path m_Directory; +}; + +} // namespace zen diff --git a/src/zenserver/frontend/html/index.html b/src/zenserver/frontend/html/index.html new file mode 100644 index 000000000..252ee621e --- /dev/null +++ b/src/zenserver/frontend/html/index.html @@ -0,0 +1,59 @@ +<!DOCTYPE html> +<html> +<head> + <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-F3w7mX95PdgyTmZZMECAngseQB83DfGTowi0iMjiWaeVhAn4FJkqJByhZMI3AhiU" crossorigin="anonymous"> + <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.min.js" integrity="sha384-skAcpIdS7UcVUC05LJ9Dxay8AXcDYfBJqt1CJ85S/CFujBsIzCIv+l9liuYLaMQ/" crossorigin="anonymous"></script> + <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/font/bootstrap-icons.css"> + <style type="text/css"> + body { + background-color: #fafafa; + } + </style> + <script type="text/javascript"> + const getCacheStats = () => { + const opts = { headers: { "Accept": "application/json" } }; + fetch("/stats/z$", opts) + .then(response => { + if (!response.ok) { + throw Error(response.statusText); + } + return response.json(); + }) + .then(json => { + document.getElementById("status").innerHTML = "connected" + document.getElementById("stats").innerHTML = JSON.stringify(json, null, 4); + }) + .catch(error => { + document.getElementById("status").innerHTML = "disconnected" + document.getElementById("stats").innerHTML = "" + console.log(error); + }) + .finally(() => { + window.setTimeout(getCacheStats, 1000); + }); + }; + getCacheStats(); + </script> +</head> +<body> + <div class="container"> + <div class="row"> + <div class="text-center mt-5"> + <pre> +__________ _________ __ +\____ / ____ ____ / _____/_/ |_ ____ _______ ____ + / / _/ __ \ / \ \_____ \ \ __\ / _ \ \_ __ \_/ __ \ + / /_ \ ___/ | | \ / \ | | ( <_> ) | | \/\ ___/ +/_______ \ \___ >|___| //_______ / |__| \____/ |__| \___ > + \/ \/ \/ \/ \/ + </pre> + <pre id="status"/> + </div> + </div> + <div class="row"> + <pre class="mb-0">Z$:</pre> + <pre id="stats"></pre> + <div> + </div> +</body> +</html> diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp new file mode 100644 index 000000000..f9c2bc8ff --- /dev/null +++ b/src/zenserver/frontend/zipfs.cpp @@ -0,0 +1,169 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zipfs.h" + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +namespace { + +#if ZEN_COMPILER_MSC +# pragma warning(push) +# pragma warning(disable : 4200) +#endif + + using ZipInt16 = uint16_t; + + struct ZipInt32 + { + operator uint32_t() const { return *(uint32_t*)Parts; } + uint16_t Parts[2]; + }; + + struct EocdRecord + { + enum : uint32_t + { + Magic = 0x0605'4b50, + }; + ZipInt32 Signature; + ZipInt16 ThisDiskIndex; + ZipInt16 CdStartDiskIndex; + ZipInt16 CdRecordThisDiskCount; + ZipInt16 CdRecordCount; + ZipInt32 CdSize; + ZipInt32 CdOffset; + ZipInt16 CommentSize; + char Comment[]; + }; + + struct CentralDirectoryRecord + { + enum : uint32_t + { + Magic = 0x0201'4b50, + }; + + ZipInt32 Signature; + ZipInt16 VersionMadeBy; + ZipInt16 VersionRequired; + ZipInt16 Flags; + ZipInt16 CompressionMethod; + ZipInt16 LastModTime; + ZipInt16 LastModDate; + ZipInt32 Crc32; + ZipInt32 CompressedSize; + ZipInt32 OriginalSize; + ZipInt16 FileNameLength; + ZipInt16 ExtraFieldLength; + ZipInt16 CommentLength; + ZipInt16 DiskIndex; + ZipInt16 InternalFileAttr; + ZipInt32 ExternalFileAttr; + ZipInt32 Offset; + char FileName[]; + }; + + struct LocalFileHeader + { + enum : uint32_t + { + Magic = 0x0403'4b50, + }; + + ZipInt32 Signature; + ZipInt16 VersionRequired; + ZipInt16 Flags; + ZipInt16 CompressionMethod; + ZipInt16 LastModTime; + ZipInt16 LastModDate; + ZipInt32 Crc32; + ZipInt32 CompressedSize; + ZipInt32 OriginalSize; + ZipInt16 FileNameLength; + ZipInt16 ExtraFieldLength; + char FileName[]; + }; + +#if ZEN_COMPILER_MSC +# pragma warning(pop) +#endif + +} // namespace + +////////////////////////////////////////////////////////////////////////// +ZipFs::ZipFs(IoBuffer&& Buffer) +{ + MemoryView View = Buffer.GetView(); + + uint8_t* Cursor = (uint8_t*)(View.GetData()) + View.GetSize(); + if (View.GetSize() < sizeof(EocdRecord)) + { + return; + } + + const auto* EocdCursor = (EocdRecord*)(Cursor - sizeof(EocdRecord)); + + // It is more correct to search backwards for EocdRecord::Magic as the + // comment can be of a variable length. But here we're not going to support + // zip files with comments. + if (EocdCursor->Signature != EocdRecord::Magic) + { + return; + } + + // Zip64 isn't supported either + if (EocdCursor->ThisDiskIndex == 0xffff) + { + return; + } + + Cursor = (uint8_t*)EocdCursor - uint32_t(EocdCursor->CdOffset) - uint32_t(EocdCursor->CdSize); + + const auto* CdCursor = (CentralDirectoryRecord*)(Cursor + EocdCursor->CdOffset); + for (int i = 0, n = EocdCursor->CdRecordCount; i < n; ++i) + { + const CentralDirectoryRecord& Cd = *CdCursor; + + bool Acceptable = true; + Acceptable &= (Cd.OriginalSize > 0); // has some content + Acceptable &= (Cd.CompressionMethod == 0); // is stored uncomrpessed + if (Acceptable) + { + const uint8_t* Lfh = Cursor + Cd.Offset; + if (uintptr_t(Lfh - Cursor) < View.GetSize()) + { + std::string_view FileName(Cd.FileName, Cd.FileNameLength); + m_Files.insert(std::make_pair(FileName, FileItem{Lfh, size_t(0)})); + } + } + + uint32_t ExtraBytes = Cd.FileNameLength + Cd.ExtraFieldLength + Cd.CommentLength; + CdCursor = (CentralDirectoryRecord*)(Cd.FileName + ExtraBytes); + } + + m_Buffer = std::move(Buffer); +} + +////////////////////////////////////////////////////////////////////////// +IoBuffer +ZipFs::GetFile(const std::string_view& FileName) const +{ + FileMap::iterator Iter = m_Files.find(FileName); + if (Iter == m_Files.end()) + { + return {}; + } + + FileItem& Item = Iter->second; + if (Item.GetSize() > 0) + { + return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); + } + + const auto* Lfh = (LocalFileHeader*)(Item.GetData()); + Item = MemoryView(Lfh->FileName + Lfh->FileNameLength + Lfh->ExtraFieldLength, Lfh->OriginalSize); + return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize()); +} + +} // namespace zen diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h new file mode 100644 index 000000000..e1fa4457c --- /dev/null +++ b/src/zenserver/frontend/zipfs.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> + +#include <unordered_map> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// +class ZipFs +{ +public: + ZipFs() = default; + ZipFs(IoBuffer&& Buffer); + IoBuffer GetFile(const std::string_view& FileName) const; + +private: + using FileItem = MemoryView; + using FileMap = std::unordered_map<std::string_view, FileItem>; + FileMap mutable m_Files; + IoBuffer m_Buffer; +}; + +} // namespace zen diff --git a/src/zenserver/monitoring/httpstats.cpp b/src/zenserver/monitoring/httpstats.cpp new file mode 100644 index 000000000..4d985f8c2 --- /dev/null +++ b/src/zenserver/monitoring/httpstats.cpp @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpstats.h" + +namespace zen { + +HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats")) +{ +} + +HttpStatsService::~HttpStatsService() +{ +} + +const char* +HttpStatsService::BaseUri() const +{ + return "/stats/"; +} + +void +HttpStatsService::RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.insert_or_assign(std::string(Id), &Provider); +} + +void +HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider) +{ + ZEN_UNUSED(Provider); + + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.erase(std::string(Id)); +} + +void +HttpStatsService::HandleRequest(HttpServerRequest& Request) +{ + using namespace std::literals; + + std::string_view Key = Request.RelativeUri(); + + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + RwLock::SharedLockScope _(m_Lock); + if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers)) + { + return It->second->HandleStatsRequest(Request); + } + } + + [[fallthrough]]; + default: + return; + } +} + +} // namespace zen diff --git a/src/zenserver/monitoring/httpstats.h b/src/zenserver/monitoring/httpstats.h new file mode 100644 index 000000000..732815a9a --- /dev/null +++ b/src/zenserver/monitoring/httpstats.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +#include <map> + +namespace zen { + +struct IHttpStatsProvider +{ + virtual void HandleStatsRequest(HttpServerRequest& Request) = 0; +}; + +class HttpStatsService : public HttpService +{ +public: + HttpStatsService(); + ~HttpStatsService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider); + void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider); + +private: + spdlog::logger& m_Log; + HttpRequestRouter m_Router; + + inline spdlog::logger& Log() { return m_Log; } + + RwLock m_Lock; + std::map<std::string, IHttpStatsProvider*> m_Providers; +}; + +} // namespace zen
\ No newline at end of file diff --git a/src/zenserver/monitoring/httpstatus.cpp b/src/zenserver/monitoring/httpstatus.cpp new file mode 100644 index 000000000..8b10601dd --- /dev/null +++ b/src/zenserver/monitoring/httpstatus.cpp @@ -0,0 +1,62 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httpstatus.h" + +namespace zen { + +HttpStatusService::HttpStatusService() : m_Log(logging::Get("status")) +{ +} + +HttpStatusService::~HttpStatusService() +{ +} + +const char* +HttpStatusService::BaseUri() const +{ + return "/status/"; +} + +void +HttpStatusService::RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.insert_or_assign(std::string(Id), &Provider); +} + +void +HttpStatusService::UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider) +{ + ZEN_UNUSED(Provider); + + RwLock::ExclusiveLockScope _(m_Lock); + m_Providers.erase(std::string(Id)); +} + +void +HttpStatusService::HandleRequest(HttpServerRequest& Request) +{ + using namespace std::literals; + + std::string_view Key = Request.RelativeUri(); + + switch (Request.RequestVerb()) + { + case HttpVerb::kHead: + case HttpVerb::kGet: + { + RwLock::SharedLockScope _(m_Lock); + if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers)) + { + return It->second->HandleStatusRequest(Request); + } + } + + [[fallthrough]]; + default: + return; + } +} + +} // namespace zen diff --git a/src/zenserver/monitoring/httpstatus.h b/src/zenserver/monitoring/httpstatus.h new file mode 100644 index 000000000..b04e45324 --- /dev/null +++ b/src/zenserver/monitoring/httpstatus.h @@ -0,0 +1,38 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging.h> +#include <zenhttp/httpserver.h> + +#include <map> + +namespace zen { + +struct IHttpStatusProvider +{ + virtual void HandleStatusRequest(HttpServerRequest& Request) = 0; +}; + +class HttpStatusService : public HttpService +{ +public: + HttpStatusService(); + ~HttpStatusService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + void RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider); + void UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider); + +private: + spdlog::logger& m_Log; + HttpRequestRouter m_Router; + + RwLock m_Lock; + std::map<std::string, IHttpStatusProvider*> m_Providers; + + inline spdlog::logger& Log() { return m_Log; } +}; + +} // namespace zen
\ No newline at end of file diff --git a/src/zenserver/objectstore/objectstore.cpp b/src/zenserver/objectstore/objectstore.cpp new file mode 100644 index 000000000..e5739418e --- /dev/null +++ b/src/zenserver/objectstore/objectstore.cpp @@ -0,0 +1,232 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <objectstore/objectstore.h> + +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include "zencore/compactbinarybuilder.h" +#include "zenhttp/httpcommon.h" +#include "zenhttp/httpserver.h" + +#include <thread> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <json11.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +ZEN_DEFINE_LOG_CATEGORY_STATIC(LogObj, "obj"sv); + +HttpObjectStoreService::HttpObjectStoreService(ObjectStoreConfig Cfg) : m_Cfg(std::move(Cfg)) +{ + Inititalize(); +} + +HttpObjectStoreService::~HttpObjectStoreService() +{ +} + +const char* +HttpObjectStoreService::BaseUri() const +{ + return "/obj/"; +} + +void +HttpObjectStoreService::HandleRequest(zen::HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_LOG_WARN(LogObj, "No route found for {0}", Request.RelativeUri()); + return Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv); + } +} + +void +HttpObjectStoreService::Inititalize() +{ + ZEN_LOG_INFO(LogObj, "Initialzing Object Store in '{}'", m_Cfg.RootDirectory); + for (const auto& Bucket : m_Cfg.Buckets) + { + ZEN_LOG_INFO(LogObj, " - bucket '{}' -> '{}'", Bucket.Name, Bucket.Directory); + } + + m_Router.RegisterRoute( + "distributionpoints/{bucket}", + [this](zen::HttpRouterRequest& Request) { + const std::string BucketName = Request.GetCapture(1); + + StringBuilder<1024> Json; + { + CbObjectWriter Writer; + Writer.BeginArray("distributions"); + Writer << fmt::format("http://localhost:{}/obj/{}", m_Cfg.ServerPort, BucketName); + Writer.EndArray(); + Writer.Save().ToJson(Json); + } + + Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, Json.ToString()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{bucket}/{path}", + [this](zen::HttpRouterRequest& Request) { GetBlob(Request); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{bucket}/{path}", + [this](zen::HttpRouterRequest& Request) { PutBlob(Request); }, + HttpVerb::kPost | HttpVerb::kPut); +} + +std::filesystem::path +HttpObjectStoreService::GetBucketDirectory(std::string_view BucketName) +{ + std::lock_guard _(BucketsMutex); + + if (const auto It = std::find_if(std::begin(m_Cfg.Buckets), + std::end(m_Cfg.Buckets), + [&BucketName](const auto& Bucket) -> bool { return Bucket.Name == BucketName; }); + It != std::end(m_Cfg.Buckets)) + { + return It->Directory; + } + + return std::filesystem::path(); +} + +void +HttpObjectStoreService::GetBlob(zen::HttpRouterRequest& Request) +{ + namespace fs = std::filesystem; + + const std::string& BucketName = Request.GetCapture(1); + const fs::path BucketDir = GetBucketDirectory(BucketName); + + if (BucketDir.empty()) + { + ZEN_LOG_DEBUG(LogObj, "GET - [FAILED], unknown bucket '{}'", BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); + } + + const fs::path RelativeBucketPath = Request.GetCapture(2); + + if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with("..")) + { + ZEN_LOG_DEBUG(LogObj, "GET - from bucket '{}' [FAILED], invalid file path", BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden); + } + + fs::path FilePath = BucketDir / RelativeBucketPath; + if (fs::exists(FilePath) == false) + { + ZEN_LOG_DEBUG(LogObj, "GET - '{}/{}' [FAILED], doesn't exist", BucketName, FilePath); + return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); + } + + zen::HttpRanges Ranges; + if (Request.ServerRequest().TryGetRanges(Ranges); Ranges.size() > 1) + { + // Only a single range is supported + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + FileContents File = ReadFile(FilePath); + if (File.ErrorCode) + { + ZEN_LOG_WARN(LogObj, + "GET - '{}/{}' [FAILED] ('{}': {})", + BucketName, + FilePath, + File.ErrorCode.category().name(), + File.ErrorCode.value()); + + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + const IoBuffer& FileBuf = File.Data[0]; + + if (Ranges.empty()) + { + const uint64_t TotalServed = TotalBytesServed.fetch_add(FileBuf.Size()) + FileBuf.Size(); + + ZEN_LOG_DEBUG(LogObj, + "GET - '{}/{}' ({}) [OK] (Served: {})", + BucketName, + RelativeBucketPath, + NiceBytes(FileBuf.Size()), + NiceBytes(TotalServed)); + + Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, FileBuf); + } + else + { + const auto Range = Ranges[0]; + const uint64_t RangeSize = Range.End - Range.Start; + const uint64_t TotalServed = TotalBytesServed.fetch_add(RangeSize) + RangeSize; + + ZEN_LOG_DEBUG(LogObj, + "GET - '{}/{}' (Range: {}-{}) ({}/{}) [OK] (Served: {})", + BucketName, + RelativeBucketPath, + Range.Start, + Range.End, + NiceBytes(RangeSize), + NiceBytes(FileBuf.Size()), + NiceBytes(TotalServed)); + + MemoryView RangeView = FileBuf.GetView().Mid(Range.Start, RangeSize); + if (RangeView.GetSize() != RangeSize) + { + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + IoBuffer RangeBuf = IoBuffer(IoBuffer::Wrap, RangeView.GetData(), RangeView.GetSize()); + Request.ServerRequest().WriteResponse(HttpResponseCode::PartialContent, HttpContentType::kBinary, RangeBuf); + } +} + +void +HttpObjectStoreService::PutBlob(zen::HttpRouterRequest& Request) +{ + namespace fs = std::filesystem; + + const std::string& BucketName = Request.GetCapture(1); + const fs::path BucketDir = GetBucketDirectory(BucketName); + + if (BucketDir.empty()) + { + ZEN_LOG_DEBUG(LogObj, "PUT - [FAILED], unknown bucket '{}'", BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound); + } + + const fs::path RelativeBucketPath = Request.GetCapture(2); + + if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with("..")) + { + ZEN_LOG_DEBUG(LogObj, "PUT - bucket '{}' [FAILED], invalid file path", BucketName); + return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden); + } + + fs::path FilePath = BucketDir / RelativeBucketPath; + const IoBuffer FileBuf = Request.ServerRequest().ReadPayload(); + + if (FileBuf.Size() == 0) + { + ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [FAILED], empty file", BucketName, FilePath); + return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + + WriteFile(FilePath, FileBuf); + ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [OK] ({})", BucketName, RelativeBucketPath, NiceBytes(FileBuf.Size())); + Request.ServerRequest().WriteResponse(HttpResponseCode::OK); +} + +} // namespace zen diff --git a/src/zenserver/objectstore/objectstore.h b/src/zenserver/objectstore/objectstore.h new file mode 100644 index 000000000..eaab57794 --- /dev/null +++ b/src/zenserver/objectstore/objectstore.h @@ -0,0 +1,48 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> +#include <atomic> +#include <filesystem> +#include <mutex> + +namespace zen { + +class HttpRouterRequest; + +struct ObjectStoreConfig +{ + struct BucketConfig + { + std::string Name; + std::filesystem::path Directory; + }; + + std::filesystem::path RootDirectory; + std::vector<BucketConfig> Buckets; + uint16_t ServerPort{1337}; +}; + +class HttpObjectStoreService final : public zen::HttpService +{ +public: + HttpObjectStoreService(ObjectStoreConfig Cfg); + virtual ~HttpObjectStoreService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + void Inititalize(); + std::filesystem::path GetBucketDirectory(std::string_view BucketName); + void GetBlob(zen::HttpRouterRequest& Request); + void PutBlob(zen::HttpRouterRequest& Request); + + ObjectStoreConfig m_Cfg; + std::mutex BucketsMutex; + HttpRequestRouter m_Router; + std::atomic_uint64_t TotalBytesServed{0}; +}; + +} // namespace zen diff --git a/src/zenserver/projectstore/fileremoteprojectstore.cpp b/src/zenserver/projectstore/fileremoteprojectstore.cpp new file mode 100644 index 000000000..d7a34a6c2 --- /dev/null +++ b/src/zenserver/projectstore/fileremoteprojectstore.cpp @@ -0,0 +1,235 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "fileremoteprojectstore.h" + +#include <zencore/compress.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/timer.h> + +namespace zen { + +using namespace std::literals; + +class LocalExportProjectStore : public RemoteProjectStore +{ +public: + LocalExportProjectStore(std::string_view Name, + const std::filesystem::path& FolderPath, + bool ForceDisableBlocks, + bool ForceEnableTempBlocks) + : m_Name(Name) + , m_OutputPath(FolderPath) + { + if (ForceDisableBlocks) + { + m_EnableBlocks = false; + } + if (ForceEnableTempBlocks) + { + m_UseTempBlocks = true; + } + } + + virtual RemoteStoreInfo GetInfo() const override + { + return {.CreateBlocks = m_EnableBlocks, + .UseTempBlockFiles = m_UseTempBlocks, + .Description = fmt::format("[file] {}"sv, m_OutputPath)}; + } + + virtual SaveResult SaveContainer(const IoBuffer& Payload) override + { + Stopwatch Timer; + SaveResult Result; + + { + CbObject ContainerObject = LoadCompactBinaryObject(Payload); + + ContainerObject.IterateAttachments([&](CbFieldView FieldView) { + IoHash AttachmentHash = FieldView.AsBinaryAttachment(); + std::filesystem::path AttachmentPath = GetAttachmentPath(AttachmentHash); + if (!std::filesystem::exists(AttachmentPath)) + { + Result.Needs.insert(AttachmentHash); + } + }); + } + + std::filesystem::path ContainerPath = m_OutputPath; + ContainerPath.append(m_Name); + + CreateDirectories(m_OutputPath); + BasicFile ContainerFile; + ContainerFile.Open(ContainerPath, BasicFile::Mode::kTruncate); + std::error_code Ec; + ContainerFile.WriteAll(Payload, Ec); + if (Ec) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = Ec.message(); + } + Result.RawHash = IoHash::HashBuffer(Payload); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override + { + Stopwatch Timer; + SaveAttachmentResult Result; + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!std::filesystem::exists(ChunkPath)) + { + try + { + CreateDirectories(ChunkPath.parent_path()); + + BasicFile ChunkFile; + ChunkFile.Open(ChunkPath, BasicFile::Mode::kTruncate); + size_t Offset = 0; + for (const SharedBuffer& Segment : Payload.GetSegments()) + { + ChunkFile.Write(Segment.GetView(), Offset); + Offset += Segment.GetSize(); + } + } + catch (std::exception& Ex) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = Ex.what(); + } + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override + { + Stopwatch Timer; + + for (const SharedBuffer& Chunk : Chunks) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer()); + SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash()); + if (ChunkResult.ErrorCode) + { + ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return SaveAttachmentsResult{ChunkResult}; + } + } + SaveAttachmentsResult Result; + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual Result FinalizeContainer(const IoHash&) override { return {}; } + + virtual LoadContainerResult LoadContainer() override + { + Stopwatch Timer; + LoadContainerResult Result; + std::filesystem::path ContainerPath = m_OutputPath; + ContainerPath.append(m_Name); + if (!std::filesystem::is_regular_file(ContainerPath)) + { + Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound); + Result.Reason = fmt::format("The file {} does not exist"sv, ContainerPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + IoBuffer ContainerPayload; + { + BasicFile ContainerFile; + ContainerFile.Open(ContainerPath, BasicFile::Mode::kRead); + ContainerPayload = ContainerFile.ReadAll(); + } + Result.ContainerObject = LoadCompactBinaryObject(ContainerPayload); + if (!Result.ContainerObject) + { + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.Reason = fmt::format("The file {} is not formatted as a compact binary object"sv, ContainerPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + { + Stopwatch Timer; + LoadAttachmentResult Result; + std::filesystem::path ChunkPath = GetAttachmentPath(RawHash); + if (!std::filesystem::is_regular_file(ChunkPath)) + { + Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound); + Result.Reason = fmt::format("The file {} does not exist"sv, ChunkPath.string()); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + { + BasicFile ChunkFile; + ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead); + Result.Bytes = ChunkFile.ReadAll(); + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override + { + Stopwatch Timer; + LoadAttachmentsResult Result; + for (const IoHash& Hash : RawHashes) + { + LoadAttachmentResult ChunkResult = LoadAttachment(Hash); + if (ChunkResult.ErrorCode) + { + ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return LoadAttachmentsResult{ChunkResult}; + } + ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000))); + Result.Chunks.emplace_back( + std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))}); + } + return Result; + } + +private: + std::filesystem::path GetAttachmentPath(const IoHash& RawHash) const + { + ExtendablePathBuilder<128> ShardedPath; + ShardedPath.Append(m_OutputPath.c_str()); + ExtendableStringBuilder<64> HashString; + RawHash.ToHexString(HashString); + const char* str = HashString.c_str(); + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str, str + 3); + + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str + 3, str + 5); + + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str + 5, str + 40); + + return ShardedPath.ToPath(); + } + + const std::string m_Name; + const std::filesystem::path m_OutputPath; + bool m_EnableBlocks = true; + bool m_UseTempBlocks = false; +}; + +std::unique_ptr<RemoteProjectStore> +CreateFileRemoteStore(const FileRemoteStoreOptions& Options) +{ + std::unique_ptr<RemoteProjectStore> RemoteStore = std::make_unique<LocalExportProjectStore>(Options.Name, + std::filesystem::path(Options.FolderPath), + Options.ForceDisableBlocks, + Options.ForceEnableTempBlocks); + return RemoteStore; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/fileremoteprojectstore.h b/src/zenserver/projectstore/fileremoteprojectstore.h new file mode 100644 index 000000000..68d1eb71e --- /dev/null +++ b/src/zenserver/projectstore/fileremoteprojectstore.h @@ -0,0 +1,19 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "remoteprojectstore.h" + +namespace zen { + +struct FileRemoteStoreOptions : RemoteStoreOptions +{ + std::filesystem::path FolderPath; + std::string Name; + bool ForceDisableBlocks; + bool ForceEnableTempBlocks; +}; + +std::unique_ptr<RemoteProjectStore> CreateFileRemoteStore(const FileRemoteStoreOptions& Options); + +} // namespace zen diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.cpp b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp new file mode 100644 index 000000000..66cf3c4f8 --- /dev/null +++ b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp @@ -0,0 +1,244 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "jupiterremoteprojectstore.h" + +#include <zencore/compress.h> +#include <zencore/fmtutils.h> + +#include <auth/authmgr.h> +#include <upstream/jupiter.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +class JupiterRemoteStore : public RemoteProjectStore +{ +public: + JupiterRemoteStore(Ref<CloudCacheClient>&& CloudClient, + std::string_view Namespace, + std::string_view Bucket, + const IoHash& Key, + bool ForceDisableBlocks, + bool ForceDisableTempBlocks) + : m_CloudClient(CloudClient) + , m_Namespace(Namespace) + , m_Bucket(Bucket) + , m_Key(Key) + { + if (ForceDisableBlocks) + { + m_EnableBlocks = false; + } + if (ForceDisableTempBlocks) + { + m_UseTempBlocks = false; + } + } + + virtual RemoteStoreInfo GetInfo() const override + { + return {.CreateBlocks = m_EnableBlocks, + .UseTempBlockFiles = m_UseTempBlocks, + .Description = fmt::format("[cloud] {} as {}/{}/{}"sv, m_CloudClient->ServiceUrl(), m_Namespace, m_Bucket, m_Key)}; + } + + virtual SaveResult SaveContainer(const IoBuffer& Payload) override + { + const int32_t MaxAttempts = 3; + PutRefResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutRef(m_Namespace, m_Bucket, m_Key, Payload, ZenContentType::kCbObject); + } + } + + return SaveResult{ConvertResult(Result), {Result.Needs.begin(), Result.Needs.end()} /*, {}*/, IoHash::HashBuffer(Payload)}; + } + + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCompressedBlob(m_Namespace, RawHash, Payload); + } + } + + return SaveAttachmentResult{ConvertResult(Result)}; + } + + virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override + { + SaveAttachmentsResult Result; + for (const SharedBuffer& Chunk : Chunks) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer()); + SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash()); + if (ChunkResult.ErrorCode) + { + return SaveAttachmentsResult{ChunkResult}; + } + } + return Result; + } + + virtual Result FinalizeContainer(const IoHash& RawHash) override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.FinalizeRef(m_Namespace, m_Bucket, m_Key, RawHash); + } + } + return ConvertResult(Result); + } + + virtual LoadContainerResult LoadContainer() override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.GetRef(m_Namespace, m_Bucket, m_Key, ZenContentType::kCbObject); + } + } + + if (Result.ErrorCode || !Result.Success) + { + return LoadContainerResult{ConvertResult(Result)}; + } + + CbObject ContainerObject = LoadCompactBinaryObject(Result.Response); + if (!ContainerObject) + { + return LoadContainerResult{ + RemoteProjectStore::Result{ + .ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), + .ElapsedSeconds = Result.ElapsedSeconds, + .Reason = fmt::format("The ref {}/{}/{} is not formatted as a compact binary object"sv, m_Namespace, m_Bucket, m_Key)}, + std::move(ContainerObject)}; + } + + return LoadContainerResult{ConvertResult(Result), std::move(ContainerObject)}; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + { + const int32_t MaxAttempts = 3; + CloudCacheResult Result; + { + CloudCacheSession Session(m_CloudClient.Get()); + for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.GetCompressedBlob(m_Namespace, RawHash); + } + } + return LoadAttachmentResult{ConvertResult(Result), std::move(Result.Response)}; + } + + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override + { + LoadAttachmentsResult Result; + for (const IoHash& Hash : RawHashes) + { + LoadAttachmentResult ChunkResult = LoadAttachment(Hash); + if (ChunkResult.ErrorCode) + { + return LoadAttachmentsResult{ChunkResult}; + } + ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000))); + Result.Chunks.emplace_back( + std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))}); + } + return Result; + } + +private: + static Result ConvertResult(const CloudCacheResult& Response) + { + std::string Text; + int32_t ErrorCode = 0; + if (Response.ErrorCode != 0) + { + ErrorCode = Response.ErrorCode; + } + else if (!Response.Success) + { + ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + if (Response.Response.GetContentType() == ZenContentType::kText) + { + Text = + std::string(reinterpret_cast<const std::string::value_type*>(Response.Response.GetData()), Response.Response.GetSize()); + } + } + return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.ElapsedSeconds, .Reason = Response.Reason, .Text = Text}; + } + + Ref<CloudCacheClient> m_CloudClient; + const std::string m_Namespace; + const std::string m_Bucket; + const IoHash m_Key; + bool m_EnableBlocks = true; + bool m_UseTempBlocks = true; +}; + +std::unique_ptr<RemoteProjectStore> +CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options) +{ + std::string Url = Options.Url; + if (Url.find("://"sv) == std::string::npos) + { + // Assume https URL + Url = fmt::format("https://{}"sv, Url); + } + CloudCacheClientOptions ClientOptions{.Name = "Remote store"sv, + .ServiceUrl = Url, + .ConnectTimeout = std::chrono::milliseconds(2000), + .Timeout = std::chrono::milliseconds(60000)}; + // 1) Access token as parameter in request + // 2) Environment variable (different win vs linux/mac) + // 3) openid-provider (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider + + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + if (!Options.AccessToken.empty()) + { + TokenProvider = CloudCacheTokenProvider::CreateFromCallback([AccessToken = Options.AccessToken]() { + return CloudCacheAccessToken{.Value = AccessToken, .ExpireTime = GcClock::TimePoint::max()}; + }); + } + else + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([&AuthManager = Options.AuthManager, OpenIdProvider = Options.OpenIdProvider]() { + AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider.empty() ? "Default" : OpenIdProvider); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + + Ref<CloudCacheClient> CloudClient(new CloudCacheClient(ClientOptions, std::move(TokenProvider))); + + std::unique_ptr<RemoteProjectStore> RemoteStore = std::make_unique<JupiterRemoteStore>(std::move(CloudClient), + Options.Namespace, + Options.Bucket, + Options.Key, + Options.ForceDisableBlocks, + Options.ForceDisableTempBlocks); + return RemoteStore; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.h b/src/zenserver/projectstore/jupiterremoteprojectstore.h new file mode 100644 index 000000000..31548af22 --- /dev/null +++ b/src/zenserver/projectstore/jupiterremoteprojectstore.h @@ -0,0 +1,26 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "remoteprojectstore.h" + +namespace zen { + +class AuthMgr; + +struct JupiterRemoteStoreOptions : RemoteStoreOptions +{ + std::string Url; + std::string Namespace; + std::string Bucket; + IoHash Key; + std::string OpenIdProvider; + std::string AccessToken; + AuthMgr& AuthManager; + bool ForceDisableBlocks; + bool ForceDisableTempBlocks; +}; + +std::unique_ptr<RemoteProjectStore> CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options); + +} // namespace zen diff --git a/src/zenserver/projectstore/projectstore.cpp b/src/zenserver/projectstore/projectstore.cpp new file mode 100644 index 000000000..847a79a1d --- /dev/null +++ b/src/zenserver/projectstore/projectstore.cpp @@ -0,0 +1,4082 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "projectstore.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenhttp/httpshared.h> +#include <zenstore/caslog.h> +#include <zenstore/cidstore.h> +#include <zenstore/scrubcontext.h> +#include <zenutil/cache/rpcrecording.h> + +#include "fileremoteprojectstore.h" +#include "jupiterremoteprojectstore.h" +#include "remoteprojectstore.h" +#include "zenremoteprojectstore.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <xxh3.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +# include <zencore/testutils.h> +#endif // ZEN_WITH_TESTS + +namespace zen { + +namespace { + bool PrepareDirectoryDelete(const std::filesystem::path& Dir, std::filesystem::path& OutDeleteDir) + { + int DropIndex = 0; + do + { + if (!std::filesystem::exists(Dir)) + { + return true; + } + + std::string DroppedName = fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex); + std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName; + if (std::filesystem::exists(DroppedBucketPath)) + { + DropIndex++; + continue; + } + + std::error_code Ec; + std::filesystem::rename(Dir, DroppedBucketPath, Ec); + if (!Ec) + { + OutDeleteDir = DroppedBucketPath; + return true; + } + if (Ec && !std::filesystem::exists(DroppedBucketPath)) + { + // We can't move our folder, probably because it is busy, bail.. + return false; + } + Sleep(100); + } while (true); + } + + std::pair<std::unique_ptr<RemoteProjectStore>, std::string> CreateRemoteStore(CbObjectView Params, + AuthMgr& AuthManager, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize) + { + using namespace std::literals; + + std::unique_ptr<RemoteProjectStore> RemoteStore; + + if (CbObjectView File = Params["file"sv].AsObjectView(); File) + { + std::filesystem::path FolderPath(File["path"sv].AsString()); + if (FolderPath.empty()) + { + return {nullptr, "Missing file path"}; + } + std::string_view Name(File["name"sv].AsString()); + if (Name.empty()) + { + return {nullptr, "Missing file name"}; + } + bool ForceDisableBlocks = File["disableblocks"sv].AsBool(false); + bool ForceEnableTempBlocks = File["enabletempblocks"sv].AsBool(false); + + FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + FolderPath, + std::string(Name), + ForceDisableBlocks, + ForceEnableTempBlocks}; + RemoteStore = CreateFileRemoteStore(Options); + } + + if (CbObjectView Cloud = Params["cloud"sv].AsObjectView(); Cloud) + { + std::string_view CloudServiceUrl = Cloud["url"sv].AsString(); + if (CloudServiceUrl.empty()) + { + return {nullptr, "Missing service url"}; + } + + std::string Url = cpr::util::urlDecode(std::string(CloudServiceUrl)); + std::string_view Namespace = Cloud["namespace"sv].AsString(); + if (Namespace.empty()) + { + return {nullptr, "Missing namespace"}; + } + std::string_view Bucket = Cloud["bucket"sv].AsString(); + if (Bucket.empty()) + { + return {nullptr, "Missing bucket"}; + } + std::string_view OpenIdProvider = Cloud["openid-provider"sv].AsString(); + std::string AccessToken = std::string(Cloud["access-token"sv].AsString()); + if (AccessToken.empty()) + { + std::string_view AccessTokenEnvVariable = Cloud["access-token-env"].AsString(); + if (!AccessTokenEnvVariable.empty()) + { + AccessToken = GetEnvVariable(AccessTokenEnvVariable); + } + } + std::string_view KeyParam = Cloud["key"sv].AsString(); + if (KeyParam.empty()) + { + return {nullptr, "Missing key"}; + } + if (KeyParam.length() != IoHash::StringLength) + { + return {nullptr, "Invalid key"}; + } + IoHash Key = IoHash::FromHexString(KeyParam); + if (Key == IoHash::Zero) + { + return {nullptr, "Invalid key string"}; + } + bool ForceDisableBlocks = Cloud["disableblocks"sv].AsBool(false); + bool ForceDisableTempBlocks = Cloud["disabletempblocks"sv].AsBool(false); + + JupiterRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + Url, + std::string(Namespace), + std::string(Bucket), + Key, + std::string(OpenIdProvider), + AccessToken, + AuthManager, + ForceDisableBlocks, + ForceDisableTempBlocks}; + RemoteStore = CreateJupiterRemoteStore(Options); + } + + if (CbObjectView Zen = Params["zen"sv].AsObjectView(); Zen) + { + std::string_view Url = Zen["url"sv].AsString(); + std::string_view Project = Zen["project"sv].AsString(); + if (Project.empty()) + { + return {nullptr, "Missing project"}; + } + std::string_view Oplog = Zen["oplog"sv].AsString(); + if (Oplog.empty()) + { + return {nullptr, "Missing oplog"}; + } + ZenRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize}, + std::string(Url), + std::string(Project), + std::string(Oplog)}; + RemoteStore = CreateZenRemoteStore(Options); + } + + if (!RemoteStore) + { + return {nullptr, "Unknown remote store type"}; + } + + return {std::move(RemoteStore), ""}; + } + + std::pair<HttpResponseCode, std::string> ConvertResult(const RemoteProjectStore::Result& Result) + { + if (Result.ErrorCode == 0) + { + return {HttpResponseCode::OK, Result.Text}; + } + return {static_cast<HttpResponseCode>(Result.ErrorCode), + Result.Reason.empty() ? Result.Text + : Result.Text.empty() ? Result.Reason + : fmt::format("{}. Reason: '{}'", Result.Text, Result.Reason)}; + } + + void CSVHeader(bool Details, bool AttachmentDetails, StringBuilderBase& CSVWriter) + { + if (AttachmentDetails) + { + CSVWriter << "Project, Oplog, LSN, Key, Cid, Size"; + } + else if (Details) + { + CSVWriter << "Project, Oplog, LSN, Key, Size, AttachmentCount, AttachmentsSize"; + } + else + { + CSVWriter << "Project, Oplog, Key"; + } + } + + void CSVWriteOp(CidStore& CidStore, + std::string_view ProjectId, + std::string_view OplogId, + bool Details, + bool AttachmentDetails, + int LSN, + const Oid& Key, + CbObject Op, + StringBuilderBase& CSVWriter) + { + StringBuilder<32> KeyStringBuilder; + Key.ToString(KeyStringBuilder); + const std::string_view KeyString = KeyStringBuilder.ToView(); + + SharedBuffer Buffer = Op.GetBuffer(); + if (AttachmentDetails) + { + Op.IterateAttachments([&CidStore, &CSVWriter, &ProjectId, &OplogId, LSN, &KeyString](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + CSVWriter << "\r\n" + << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << AttachmentHash.ToHexString() + << ", " << gsl::narrow<uint64_t>(Attachment.GetSize()); + }); + } + else if (Details) + { + uint64_t AttachmentCount = 0; + size_t AttachmentsSize = 0; + Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + AttachmentCount++; + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + AttachmentsSize += Attachment.GetSize(); + }); + CSVWriter << "\r\n" + << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << gsl::narrow<uint64_t>(Buffer.GetSize()) + << ", " << AttachmentCount << ", " << gsl::narrow<uint64_t>(AttachmentsSize); + } + else + { + CSVWriter << "\r\n" << ProjectId << ", " << OplogId << ", " << KeyString; + } + }; + + void CbWriteOp(CidStore& CidStore, + bool Details, + bool OpDetails, + bool AttachmentDetails, + int LSN, + const Oid& Key, + CbObject Op, + CbObjectWriter& CbWriter) + { + CbWriter.BeginObject(); + { + SharedBuffer Buffer = Op.GetBuffer(); + CbWriter.AddObjectId("key", Key); + if (Details) + { + CbWriter.AddInteger("lsn", LSN); + CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Buffer.GetSize())); + } + if (AttachmentDetails) + { + CbWriter.BeginArray("attachments"); + Op.IterateAttachments([&CidStore, &CbWriter](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + CbWriter.BeginObject(); + { + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + CbWriter.AddString("cid", AttachmentHash.ToHexString()); + CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Attachment.GetSize())); + } + CbWriter.EndObject(); + }); + CbWriter.EndArray(); + } + else if (Details) + { + uint64_t AttachmentCount = 0; + size_t AttachmentsSize = 0; + Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + AttachmentCount++; + IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash); + AttachmentsSize += Attachment.GetSize(); + }); + if (AttachmentCount > 0) + { + CbWriter.AddInteger("attachments", AttachmentCount); + CbWriter.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize)); + } + } + if (OpDetails) + { + CbWriter.BeginObject("op"); + for (const CbFieldView& Field : Op) + { + if (!Field.HasName()) + { + CbWriter.AddField(Field); + continue; + } + std::string_view FieldName = Field.GetName(); + CbWriter.AddField(FieldName, Field); + } + CbWriter.EndObject(); + } + } + CbWriter.EndObject(); + }; + + void CbWriteOplogOps(CidStore& CidStore, + ProjectStore::Oplog& Oplog, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginArray("ops"); + { + Oplog.IterateOplogWithKey([&Cbo, &CidStore, Details, OpDetails, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) { + CbWriteOp(CidStore, Details, OpDetails, AttachmentDetails, LSN, Key, Op, Cbo); + }); + } + Cbo.EndArray(); + } + + void CbWriteOplog(CidStore& CidStore, + ProjectStore::Oplog& Oplog, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginObject(); + { + Cbo.AddString("name", Oplog.OplogId()); + CbWriteOplogOps(CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndObject(); + } + + void CbWriteOplogs(CidStore& CidStore, + ProjectStore::Project& Project, + std::vector<std::string> OpLogs, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginArray("oplogs"); + { + for (const std::string& OpLogId : OpLogs) + { + ProjectStore::Oplog* Oplog = Project.OpenOplog(OpLogId); + if (Oplog != nullptr) + { + CbWriteOplog(CidStore, *Oplog, Details, OpDetails, AttachmentDetails, Cbo); + } + } + } + Cbo.EndArray(); + } + + void CbWriteProject(CidStore& CidStore, + ProjectStore::Project& Project, + std::vector<std::string> OpLogs, + bool Details, + bool OpDetails, + bool AttachmentDetails, + CbObjectWriter& Cbo) + { + Cbo.BeginObject(); + { + Cbo.AddString("name", Project.Identifier); + CbWriteOplogs(CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndObject(); + } + +} // namespace + +////////////////////////////////////////////////////////////////////////// + +Oid +OpKeyStringAsOId(std::string_view OpKey) +{ + using namespace std::literals; + + CbObjectWriter Writer; + Writer << "key"sv << OpKey; + + XXH3_128Stream KeyHasher; + Writer.Save()["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); }); + XXH3_128 KeyHash = KeyHasher.GetHash(); + + Oid OpId; + memcpy(OpId.OidBits, &KeyHash, sizeof(OpId.OidBits)); + + return OpId; +} + +////////////////////////////////////////////////////////////////////////// + +struct ProjectStore::OplogStorage : public RefCounted +{ + OplogStorage(ProjectStore::Oplog* OwnerOplog, std::filesystem::path BasePath) : m_OwnerOplog(OwnerOplog), m_OplogStoragePath(BasePath) + { + } + + ~OplogStorage() + { + ZEN_INFO("closing oplog storage at {}", m_OplogStoragePath); + Flush(); + } + + [[nodiscard]] bool Exists() { return Exists(m_OplogStoragePath); } + [[nodiscard]] static bool Exists(std::filesystem::path BasePath) + { + return std::filesystem::exists(BasePath / "ops.zlog") && std::filesystem::exists(BasePath / "ops.zops"); + } + + static bool Delete(std::filesystem::path BasePath) { return DeleteDirectories(BasePath); } + + uint64_t OpBlobsSize() const + { + RwLock::SharedLockScope _(m_RwLock); + return m_NextOpsOffset; + } + + void Open(bool IsCreate) + { + using namespace std::literals; + + ZEN_INFO("initializing oplog storage at '{}'", m_OplogStoragePath); + + if (IsCreate) + { + DeleteDirectories(m_OplogStoragePath); + CreateDirectories(m_OplogStoragePath); + } + + m_Oplog.Open(m_OplogStoragePath / "ops.zlog"sv, IsCreate ? CasLogFile::Mode::kTruncate : CasLogFile::Mode::kWrite); + m_Oplog.Initialize(); + + m_OpBlobs.Open(m_OplogStoragePath / "ops.zops"sv, IsCreate ? BasicFile::Mode::kTruncate : BasicFile::Mode::kWrite); + + ZEN_ASSERT(IsPow2(m_OpsAlign)); + ZEN_ASSERT(!(m_NextOpsOffset & (m_OpsAlign - 1))); + } + + void ReplayLog(std::function<void(CbObject, const OplogEntry&)>&& Handler) + { + ZEN_TRACE_CPU("ProjectStore::OplogStorage::ReplayLog"); + + // This could use memory mapping or do something clever but for now it just reads the file sequentially + + ZEN_INFO("replaying log for '{}'", m_OplogStoragePath); + + Stopwatch Timer; + + uint64_t InvalidEntries = 0; + + IoBuffer OpBuffer; + m_Oplog.Replay( + [&](const OplogEntry& LogEntry) { + if (LogEntry.OpCoreSize == 0) + { + ++InvalidEntries; + + return; + } + + if (OpBuffer.GetSize() < LogEntry.OpCoreSize) + { + OpBuffer = IoBuffer(LogEntry.OpCoreSize); + } + + const uint64_t OpFileOffset = LogEntry.OpCoreOffset * m_OpsAlign; + + m_OpBlobs.Read((void*)OpBuffer.Data(), LogEntry.OpCoreSize, OpFileOffset); + + // Verify checksum, ignore op data if incorrect + const auto OpCoreHash = uint32_t(XXH3_64bits(OpBuffer.Data(), LogEntry.OpCoreSize) & 0xffffFFFF); + + if (OpCoreHash != LogEntry.OpCoreHash) + { + ZEN_WARN("skipping oplog entry with bad checksum!"); + return; + } + + CbObject Op(SharedBuffer::MakeView(OpBuffer.Data(), LogEntry.OpCoreSize)); + + m_NextOpsOffset = + Max(m_NextOpsOffset.load(std::memory_order_relaxed), RoundUp(OpFileOffset + LogEntry.OpCoreSize, m_OpsAlign)); + m_MaxLsn = Max(m_MaxLsn.load(std::memory_order_relaxed), LogEntry.OpLsn); + + Handler(Op, LogEntry); + }, + 0); + + if (InvalidEntries) + { + ZEN_WARN("ignored {} zero-sized oplog entries", InvalidEntries); + } + + ZEN_INFO("Oplog replay completed in {} - Max LSN# {}, Next offset: {}", + NiceTimeSpanMs(Timer.GetElapsedTimeMs()), + m_MaxLsn, + m_NextOpsOffset); + } + + void ReplayLog(const std::vector<OplogEntryAddress>& Entries, std::function<void(CbObject)>&& Handler) + { + for (const OplogEntryAddress& Entry : Entries) + { + CbObject Op = GetOp(Entry); + Handler(Op); + } + } + + CbObject GetOp(const OplogEntryAddress& Entry) + { + IoBuffer OpBuffer(Entry.Size); + + const uint64_t OpFileOffset = Entry.Offset * m_OpsAlign; + m_OpBlobs.Read((void*)OpBuffer.Data(), Entry.Size, OpFileOffset); + + return CbObject(SharedBuffer(std::move(OpBuffer))); + } + + OplogEntry AppendOp(SharedBuffer Buffer, uint32_t OpCoreHash, XXH3_128 KeyHash) + { + ZEN_TRACE_CPU("ProjectStore::OplogStorage::AppendOp"); + + using namespace std::literals; + + uint64_t WriteSize = Buffer.GetSize(); + + RwLock::ExclusiveLockScope Lock(m_RwLock); + const uint64_t WriteOffset = m_NextOpsOffset; + const uint32_t OpLsn = ++m_MaxLsn; + m_NextOpsOffset = RoundUp(WriteOffset + WriteSize, m_OpsAlign); + Lock.ReleaseNow(); + + ZEN_ASSERT(IsMultipleOf(WriteOffset, m_OpsAlign)); + + OplogEntry Entry = {.OpLsn = OpLsn, + .OpCoreOffset = gsl::narrow_cast<uint32_t>(WriteOffset / m_OpsAlign), + .OpCoreSize = uint32_t(Buffer.GetSize()), + .OpCoreHash = OpCoreHash, + .OpKeyHash = KeyHash}; + + m_Oplog.Append(Entry); + m_OpBlobs.Write(Buffer.GetData(), WriteSize, WriteOffset); + + return Entry; + } + + void Flush() + { + m_Oplog.Flush(); + m_OpBlobs.Flush(); + } + + spdlog::logger& Log() { return m_OwnerOplog->Log(); } + +private: + ProjectStore::Oplog* m_OwnerOplog; + std::filesystem::path m_OplogStoragePath; + mutable RwLock m_RwLock; + TCasLogFile<OplogEntry> m_Oplog; + BasicFile m_OpBlobs; + std::atomic<uint64_t> m_NextOpsOffset{0}; + uint64_t m_OpsAlign = 32; + std::atomic<uint32_t> m_MaxLsn{0}; +}; + +////////////////////////////////////////////////////////////////////////// + +ProjectStore::Oplog::Oplog(std::string_view Id, + Project* Project, + CidStore& Store, + std::filesystem::path BasePath, + const std::filesystem::path& MarkerPath) +: m_OuterProject(Project) +, m_CidStore(Store) +, m_BasePath(BasePath) +, m_MarkerPath(MarkerPath) +, m_OplogId(Id) +{ + using namespace std::literals; + + m_Storage = new OplogStorage(this, m_BasePath); + const bool StoreExists = m_Storage->Exists(); + m_Storage->Open(/* IsCreate */ !StoreExists); + + m_TempPath = m_BasePath / "temp"sv; + + CleanDirectory(m_TempPath); +} + +ProjectStore::Oplog::~Oplog() +{ + if (m_Storage) + { + Flush(); + } +} + +void +ProjectStore::Oplog::Flush() +{ + ZEN_ASSERT(m_Storage); + m_Storage->Flush(); +} + +void +ProjectStore::Oplog::Scrub(ScrubContext& Ctx) const +{ + ZEN_UNUSED(Ctx); +} + +void +ProjectStore::Oplog::GatherReferences(GcContext& GcCtx) +{ + RwLock::SharedLockScope _(m_OplogLock); + + std::vector<IoHash> Hashes; + Hashes.reserve(Max(m_ChunkMap.size(), m_MetaMap.size())); + + for (const auto& Kv : m_ChunkMap) + { + Hashes.push_back(Kv.second); + } + + GcCtx.AddRetainedCids(Hashes); + + Hashes.clear(); + + for (const auto& Kv : m_MetaMap) + { + Hashes.push_back(Kv.second); + } + + GcCtx.AddRetainedCids(Hashes); +} + +uint64_t +ProjectStore::Oplog::TotalSize() const +{ + RwLock::SharedLockScope _(m_OplogLock); + if (m_Storage) + { + return m_Storage->OpBlobsSize(); + } + return 0; +} + +bool +ProjectStore::Oplog::IsExpired() const +{ + if (m_MarkerPath.empty()) + { + return false; + } + return !std::filesystem::exists(m_MarkerPath); +} + +std::filesystem::path +ProjectStore::Oplog::PrepareForDelete(bool MoveFolder) +{ + RwLock::ExclusiveLockScope _(m_OplogLock); + m_ChunkMap.clear(); + m_MetaMap.clear(); + m_FileMap.clear(); + m_OpAddressMap.clear(); + m_LatestOpMap.clear(); + m_Storage = {}; + if (!MoveFolder) + { + return {}; + } + std::filesystem::path MovedDir; + if (PrepareDirectoryDelete(m_BasePath, MovedDir)) + { + return MovedDir; + } + return {}; +} + +bool +ProjectStore::Oplog::ExistsAt(std::filesystem::path BasePath) +{ + using namespace std::literals; + + std::filesystem::path StateFilePath = BasePath / "oplog.zcb"sv; + return std::filesystem::is_regular_file(StateFilePath); +} + +void +ProjectStore::Oplog::Read() +{ + using namespace std::literals; + + std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv; + if (std::filesystem::is_regular_file(StateFilePath)) + { + ZEN_INFO("reading config for oplog '{}' in project '{}' from {}", m_OplogId, m_OuterProject->Identifier, StateFilePath); + + BasicFile Blob; + Blob.Open(StateFilePath, BasicFile::Mode::kRead); + + IoBuffer Obj = Blob.ReadAll(); + CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All); + + if (ValidationError != CbValidateError::None) + { + ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), StateFilePath); + return; + } + + CbObject Cfg = LoadCompactBinaryObject(Obj); + + m_MarkerPath = Cfg["gcpath"sv].AsString(); + } + else + { + ZEN_INFO("config for oplog '{}' in project '{}' not found at {}. Assuming legacy store", + m_OplogId, + m_OuterProject->Identifier, + StateFilePath); + } + ReplayLog(); +} + +void +ProjectStore::Oplog::Write() +{ + using namespace std::literals; + + BinaryWriter Mem; + + CbObjectWriter Cfg; + + Cfg << "gcpath"sv << PathToUtf8(m_MarkerPath); + + Cfg.Save(Mem); + + std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv; + + ZEN_INFO("persisting config for oplog '{}' in project '{}' to {}", m_OplogId, m_OuterProject->Identifier, StateFilePath); + + BasicFile Blob; + Blob.Open(StateFilePath, BasicFile::Mode::kTruncate); + Blob.Write(Mem.Data(), Mem.Size(), 0); + Blob.Flush(); +} + +void +ProjectStore::Oplog::ReplayLog() +{ + RwLock::ExclusiveLockScope OplogLock(m_OplogLock); + if (!m_Storage) + { + return; + } + m_Storage->ReplayLog( + [&](CbObject Op, const OplogEntry& OpEntry) { RegisterOplogEntry(OplogLock, GetMapping(Op), OpEntry, kUpdateReplay); }); +} + +IoBuffer +ProjectStore::Oplog::FindChunk(Oid ChunkId) +{ + RwLock::SharedLockScope OplogLock(m_OplogLock); + if (!m_Storage) + { + return IoBuffer{}; + } + + if (auto ChunkIt = m_ChunkMap.find(ChunkId); ChunkIt != m_ChunkMap.end()) + { + IoHash ChunkHash = ChunkIt->second; + OplogLock.ReleaseNow(); + + IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkHash); + Chunk.SetContentType(ZenContentType::kCompressedBinary); + + return Chunk; + } + + if (auto FileIt = m_FileMap.find(ChunkId); FileIt != m_FileMap.end()) + { + std::filesystem::path FilePath = m_OuterProject->RootDir / FileIt->second.ServerPath; + + OplogLock.ReleaseNow(); + + IoBuffer FileChunk = IoBufferBuilder::MakeFromFile(FilePath); + FileChunk.SetContentType(ZenContentType::kBinary); + + return FileChunk; + } + + if (auto MetaIt = m_MetaMap.find(ChunkId); MetaIt != m_MetaMap.end()) + { + IoHash ChunkHash = MetaIt->second; + OplogLock.ReleaseNow(); + + IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkHash); + Chunk.SetContentType(ZenContentType::kCompressedBinary); + + return Chunk; + } + + return {}; +} + +void +ProjectStore::Oplog::IterateFileMap( + std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return; + } + + for (const auto& Kv : m_FileMap) + { + Fn(Kv.first, Kv.second.ServerPath, Kv.second.ClientPath); + } +} + +void +ProjectStore::Oplog::IterateOplog(std::function<void(CbObject)>&& Handler) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return; + } + + std::vector<OplogEntryAddress> Entries; + Entries.reserve(m_LatestOpMap.size()); + + for (const auto& Kv : m_LatestOpMap) + { + const auto AddressEntry = m_OpAddressMap.find(Kv.second); + ZEN_ASSERT(AddressEntry != m_OpAddressMap.end()); + + Entries.push_back(AddressEntry->second); + } + + std::sort(Entries.begin(), Entries.end(), [](const OplogEntryAddress& Lhs, const OplogEntryAddress& Rhs) { + return Lhs.Offset < Rhs.Offset; + }); + + m_Storage->ReplayLog(Entries, [&](CbObject Op) { Handler(Op); }); +} + +void +ProjectStore::Oplog::IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Handler) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return; + } + + std::vector<size_t> EntryIndexes; + std::vector<OplogEntryAddress> Entries; + std::vector<Oid> Keys; + std::vector<int> LSNs; + Entries.reserve(m_LatestOpMap.size()); + EntryIndexes.reserve(m_LatestOpMap.size()); + Keys.reserve(m_LatestOpMap.size()); + LSNs.reserve(m_LatestOpMap.size()); + + for (const auto& Kv : m_LatestOpMap) + { + const auto AddressEntry = m_OpAddressMap.find(Kv.second); + ZEN_ASSERT(AddressEntry != m_OpAddressMap.end()); + + Entries.push_back(AddressEntry->second); + Keys.push_back(Kv.first); + LSNs.push_back(Kv.second); + EntryIndexes.push_back(EntryIndexes.size()); + } + + std::sort(EntryIndexes.begin(), EntryIndexes.end(), [&Entries](const size_t& Lhs, const size_t& Rhs) { + const OplogEntryAddress& LhsEntry = Entries[Lhs]; + const OplogEntryAddress& RhsEntry = Entries[Rhs]; + return LhsEntry.Offset < RhsEntry.Offset; + }); + std::vector<OplogEntryAddress> SortedEntries; + SortedEntries.reserve(EntryIndexes.size()); + for (size_t Index : EntryIndexes) + { + SortedEntries.push_back(Entries[Index]); + } + + size_t EntryIndex = 0; + m_Storage->ReplayLog(SortedEntries, [&](CbObject Op) { + Handler(LSNs[EntryIndex], Keys[EntryIndex], Op); + EntryIndex++; + }); +} + +int +ProjectStore::Oplog::GetOpIndexByKey(const Oid& Key) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return {}; + } + if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end()) + { + return LatestOp->second; + } + return -1; +} + +std::optional<CbObject> +ProjectStore::Oplog::GetOpByKey(const Oid& Key) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return {}; + } + + if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end()) + { + const auto AddressEntry = m_OpAddressMap.find(LatestOp->second); + ZEN_ASSERT(AddressEntry != m_OpAddressMap.end()); + + return m_Storage->GetOp(AddressEntry->second); + } + + return {}; +} + +std::optional<CbObject> +ProjectStore::Oplog::GetOpByIndex(int Index) +{ + RwLock::SharedLockScope _(m_OplogLock); + if (!m_Storage) + { + return {}; + } + + if (const auto AddressEntryIt = m_OpAddressMap.find(Index); AddressEntryIt != m_OpAddressMap.end()) + { + return m_Storage->GetOp(AddressEntryIt->second); + } + + return {}; +} + +void +ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&, + Oid FileId, + IoHash Hash, + std::string_view ServerPath, + std::string_view ClientPath) +{ + if (Hash != IoHash::Zero) + { + m_ChunkMap.insert_or_assign(FileId, Hash); + } + + FileMapEntry Entry; + Entry.ServerPath = ServerPath; + Entry.ClientPath = ClientPath; + + m_FileMap[FileId] = std::move(Entry); + + if (Hash != IoHash::Zero) + { + m_ChunkMap.insert_or_assign(FileId, Hash); + } +} + +void +ProjectStore::Oplog::AddChunkMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash) +{ + m_ChunkMap.insert_or_assign(ChunkId, Hash); +} + +void +ProjectStore::Oplog::AddMetaMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash) +{ + m_MetaMap.insert_or_assign(ChunkId, Hash); +} + +ProjectStore::Oplog::OplogEntryMapping +ProjectStore::Oplog::GetMapping(CbObject Core) +{ + using namespace std::literals; + + OplogEntryMapping Result; + + // Update chunk id maps + CbObjectView PackageObj = Core["package"sv].AsObjectView(); + CbArrayView BulkDataArray = Core["bulkdata"sv].AsArrayView(); + CbArrayView PackageDataArray = Core["packagedata"sv].AsArrayView(); + Result.Chunks.reserve(PackageObj ? 1 : 0 + BulkDataArray.Num() + PackageDataArray.Num()); + + if (PackageObj) + { + Oid Id = PackageObj["id"sv].AsObjectId(); + IoHash Hash = PackageObj["data"sv].AsBinaryAttachment(); + Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + ZEN_DEBUG("package data {} -> {}", Id, Hash); + } + + for (CbFieldView& Entry : PackageDataArray) + { + CbObjectView PackageDataObj = Entry.AsObjectView(); + Oid Id = PackageDataObj["id"sv].AsObjectId(); + IoHash Hash = PackageDataObj["data"sv].AsBinaryAttachment(); + Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + ZEN_DEBUG("package {} -> {}", Id, Hash); + } + + for (CbFieldView& Entry : BulkDataArray) + { + CbObjectView BulkObj = Entry.AsObjectView(); + Oid Id = BulkObj["id"sv].AsObjectId(); + IoHash Hash = BulkObj["data"sv].AsBinaryAttachment(); + Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + ZEN_DEBUG("bulkdata {} -> {}", Id, Hash); + } + + CbArrayView FilesArray = Core["files"sv].AsArrayView(); + Result.Files.reserve(FilesArray.Num()); + for (CbFieldView& Entry : FilesArray) + { + CbObjectView FileObj = Entry.AsObjectView(); + + std::string_view ServerPath = FileObj["serverpath"sv].AsString(); + std::string_view ClientPath = FileObj["clientpath"sv].AsString(); + if (ServerPath.empty() || ClientPath.empty()) + { + ZEN_WARN("invalid file"); + continue; + } + + Oid Id = FileObj["id"sv].AsObjectId(); + IoHash Hash = FileObj["data"sv].AsBinaryAttachment(); + Result.Files.emplace_back( + OplogEntryMapping::FileMapping{OplogEntryMapping::Mapping{Id, Hash}, std::string(ServerPath), std::string(ClientPath)}); + ZEN_DEBUG("file {} -> {}, ServerPath: {}, ClientPath: {}", Id, Hash, ServerPath, ClientPath); + } + + CbArrayView MetaArray = Core["meta"sv].AsArrayView(); + Result.Meta.reserve(MetaArray.Num()); + for (CbFieldView& Entry : MetaArray) + { + CbObjectView MetaObj = Entry.AsObjectView(); + Oid Id = MetaObj["id"sv].AsObjectId(); + IoHash Hash = MetaObj["data"sv].AsBinaryAttachment(); + Result.Meta.emplace_back(OplogEntryMapping::Mapping{Id, Hash}); + auto NameString = MetaObj["name"sv].AsString(); + ZEN_DEBUG("meta data ({}) {} -> {}", NameString, Id, Hash); + } + + return Result; +} + +uint32_t +ProjectStore::Oplog::RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock, + const OplogEntryMapping& OpMapping, + const OplogEntry& OpEntry, + UpdateType TypeOfUpdate) +{ + ZEN_TRACE_CPU("ProjectStore::Oplog::RegisterOplogEntry"); + + ZEN_UNUSED(TypeOfUpdate); + + // For now we're assuming the update is all in-memory so we can hold an exclusive lock without causing + // too many problems. Longer term we'll probably want to ensure we can do concurrent updates however + + using namespace std::literals; + + // Update chunk id maps + for (const OplogEntryMapping::Mapping& Chunk : OpMapping.Chunks) + { + AddChunkMapping(OplogLock, Chunk.Id, Chunk.Hash); + } + + for (const OplogEntryMapping::FileMapping& File : OpMapping.Files) + { + AddFileMapping(OplogLock, File.Id, File.Hash, File.ServerPath, File.ClientPath); + } + + for (const OplogEntryMapping::Mapping& Meta : OpMapping.Meta) + { + AddMetaMapping(OplogLock, Meta.Id, Meta.Hash); + } + + m_OpAddressMap.emplace(OpEntry.OpLsn, OplogEntryAddress{.Offset = OpEntry.OpCoreOffset, .Size = OpEntry.OpCoreSize}); + m_LatestOpMap[OpEntry.OpKeyAsOId()] = OpEntry.OpLsn; + + return OpEntry.OpLsn; +} + +uint32_t +ProjectStore::Oplog::AppendNewOplogEntry(CbPackage OpPackage) +{ + ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry"); + + const CbObject& Core = OpPackage.GetObject(); + const uint32_t EntryId = AppendNewOplogEntry(Core); + if (EntryId == 0xffffffffu) + { + // The oplog has been deleted so just drop this + return EntryId; + } + + // Persist attachments after oplog entry so GC won't find attachments without references + + uint64_t AttachmentBytes = 0; + uint64_t NewAttachmentBytes = 0; + + auto Attachments = OpPackage.GetAttachments(); + + for (const auto& Attach : Attachments) + { + ZEN_ASSERT(Attach.IsCompressedBinary()); + + CompressedBuffer AttachmentData = Attach.AsCompressedBinary(); + const uint64_t AttachmentSize = AttachmentData.DecodeRawSize(); + CidStore::InsertResult InsertResult = m_CidStore.AddChunk(AttachmentData.GetCompressed().Flatten().AsIoBuffer(), Attach.GetHash()); + + if (InsertResult.New) + { + NewAttachmentBytes += AttachmentSize; + } + AttachmentBytes += AttachmentSize; + } + + ZEN_DEBUG("oplog entry #{} attachments: {} new, {} total", EntryId, NiceBytes(NewAttachmentBytes), NiceBytes(AttachmentBytes)); + + return EntryId; +} + +uint32_t +ProjectStore::Oplog::AppendNewOplogEntry(CbObject Core) +{ + ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry"); + + using namespace std::literals; + + OplogEntryMapping Mapping = GetMapping(Core); + + SharedBuffer Buffer = Core.GetBuffer(); + const uint64_t WriteSize = Buffer.GetSize(); + const auto OpCoreHash = uint32_t(XXH3_64bits(Buffer.GetData(), WriteSize) & 0xffffFFFF); + + ZEN_ASSERT(WriteSize != 0); + + XXH3_128Stream KeyHasher; + Core["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); }); + XXH3_128 KeyHash = KeyHasher.GetHash(); + + RefPtr<OplogStorage> Storage; + { + RwLock::SharedLockScope _(m_OplogLock); + Storage = m_Storage; + } + if (!m_Storage) + { + return 0xffffffffu; + } + const OplogEntry OpEntry = m_Storage->AppendOp(Buffer, OpCoreHash, KeyHash); + + RwLock::ExclusiveLockScope OplogLock(m_OplogLock); + const uint32_t EntryId = RegisterOplogEntry(OplogLock, Mapping, OpEntry, kUpdateNewEntry); + + return EntryId; +} + +////////////////////////////////////////////////////////////////////////// + +ProjectStore::Project::Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath) +: m_ProjectStore(PrjStore) +, m_CidStore(Store) +, m_OplogStoragePath(BasePath) +{ +} + +ProjectStore::Project::~Project() +{ +} + +bool +ProjectStore::Project::Exists(std::filesystem::path BasePath) +{ + return std::filesystem::exists(BasePath / "Project.zcb"); +} + +void +ProjectStore::Project::Read() +{ + using namespace std::literals; + + std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv; + + ZEN_INFO("reading config for project '{}' from {}", Identifier, ProjectStateFilePath); + + BasicFile Blob; + Blob.Open(ProjectStateFilePath, BasicFile::Mode::kRead); + + IoBuffer Obj = Blob.ReadAll(); + CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All); + + if (ValidationError == CbValidateError::None) + { + CbObject Cfg = LoadCompactBinaryObject(Obj); + + Identifier = Cfg["id"sv].AsString(); + RootDir = Cfg["root"sv].AsString(); + ProjectRootDir = Cfg["project"sv].AsString(); + EngineRootDir = Cfg["engine"sv].AsString(); + ProjectFilePath = Cfg["projectfile"sv].AsString(); + } + else + { + ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), ProjectStateFilePath); + } +} + +void +ProjectStore::Project::Write() +{ + using namespace std::literals; + + BinaryWriter Mem; + + CbObjectWriter Cfg; + Cfg << "id"sv << Identifier; + Cfg << "root"sv << PathToUtf8(RootDir); + Cfg << "project"sv << ProjectRootDir; + Cfg << "engine"sv << EngineRootDir; + Cfg << "projectfile"sv << ProjectFilePath; + + Cfg.Save(Mem); + + CreateDirectories(m_OplogStoragePath); + + std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv; + + ZEN_INFO("persisting config for project '{}' to {}", Identifier, ProjectStateFilePath); + + BasicFile Blob; + Blob.Open(ProjectStateFilePath, BasicFile::Mode::kTruncate); + Blob.Write(Mem.Data(), Mem.Size(), 0); + Blob.Flush(); +} + +spdlog::logger& +ProjectStore::Project::Log() +{ + return m_ProjectStore->Log(); +} + +std::filesystem::path +ProjectStore::Project::BasePathForOplog(std::string_view OplogId) +{ + return m_OplogStoragePath / OplogId; +} + +ProjectStore::Oplog* +ProjectStore::Project::NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath) +{ + RwLock::ExclusiveLockScope _(m_ProjectLock); + + std::filesystem::path OplogBasePath = BasePathForOplog(OplogId); + + try + { + Oplog* Log = m_Oplogs + .try_emplace(std::string{OplogId}, + std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, MarkerPath)) + .first->second.get(); + + Log->Write(); + return Log; + } + catch (std::exception&) + { + // In case of failure we need to ensure there's no half constructed entry around + // + // (This is probably already ensured by the try_emplace implementation?) + + m_Oplogs.erase(std::string{OplogId}); + + return nullptr; + } +} + +ProjectStore::Oplog* +ProjectStore::Project::OpenOplog(std::string_view OplogId) +{ + { + RwLock::SharedLockScope _(m_ProjectLock); + + auto OplogIt = m_Oplogs.find(std::string(OplogId)); + + if (OplogIt != m_Oplogs.end()) + { + return OplogIt->second.get(); + } + } + + RwLock::ExclusiveLockScope _(m_ProjectLock); + + std::filesystem::path OplogBasePath = BasePathForOplog(OplogId); + + if (Oplog::ExistsAt(OplogBasePath)) + { + // Do open of existing oplog + + try + { + Oplog* Log = + m_Oplogs + .try_emplace(std::string{OplogId}, + std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, std::filesystem::path{})) + .first->second.get(); + Log->Read(); + + return Log; + } + catch (std::exception& ex) + { + ZEN_WARN("failed to open oplog '{}' @ '{}': {}", OplogId, OplogBasePath, ex.what()); + + m_Oplogs.erase(std::string{OplogId}); + } + } + + return nullptr; +} + +void +ProjectStore::Project::DeleteOplog(std::string_view OplogId) +{ + std::filesystem::path DeletePath; + { + RwLock::ExclusiveLockScope _(m_ProjectLock); + + auto OplogIt = m_Oplogs.find(std::string(OplogId)); + + if (OplogIt != m_Oplogs.end()) + { + std::unique_ptr<Oplog>& Oplog = OplogIt->second; + DeletePath = Oplog->PrepareForDelete(true); + m_DeletedOplogs.emplace_back(std::move(Oplog)); + m_Oplogs.erase(OplogIt); + } + } + + // Erase content on disk + if (!DeletePath.empty()) + { + OplogStorage::Delete(DeletePath); + } +} + +std::vector<std::string> +ProjectStore::Project::ScanForOplogs() const +{ + DirectoryContent DirContent; + GetDirectoryContent(m_OplogStoragePath, DirectoryContent::IncludeDirsFlag, DirContent); + std::vector<std::string> Oplogs; + Oplogs.reserve(DirContent.Directories.size()); + for (const std::filesystem::path& DirPath : DirContent.Directories) + { + Oplogs.push_back(DirPath.filename().string()); + } + return Oplogs; +} + +void +ProjectStore::Project::IterateOplogs(std::function<void(const Oplog&)>&& Fn) const +{ + RwLock::SharedLockScope _(m_ProjectLock); + + for (auto& Kv : m_Oplogs) + { + Fn(*Kv.second); + } +} + +void +ProjectStore::Project::IterateOplogs(std::function<void(Oplog&)>&& Fn) +{ + RwLock::SharedLockScope _(m_ProjectLock); + + for (auto& Kv : m_Oplogs) + { + Fn(*Kv.second); + } +} + +void +ProjectStore::Project::Flush() +{ + // We only need to flush oplogs that we have already loaded + IterateOplogs([&](Oplog& Ops) { Ops.Flush(); }); +} + +void +ProjectStore::Project::Scrub(ScrubContext& Ctx) +{ + // Scrubbing needs to check all existing oplogs + std::vector<std::string> OpLogs = ScanForOplogs(); + for (const std::string& OpLogId : OpLogs) + { + OpenOplog(OpLogId); + } + IterateOplogs([&](const Oplog& Ops) { + if (!Ops.IsExpired()) + { + Ops.Scrub(Ctx); + } + }); +} + +void +ProjectStore::Project::GatherReferences(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("ProjectStore::Project::GatherReferences"); + + Stopwatch Timer; + const auto Guard = MakeGuard([&] { + ZEN_DEBUG("gathered references from project store project {} in {}", Identifier, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + // GatherReferences needs to check all existing oplogs + std::vector<std::string> OpLogs = ScanForOplogs(); + for (const std::string& OpLogId : OpLogs) + { + OpenOplog(OpLogId); + } + IterateOplogs([&](Oplog& Ops) { + if (!Ops.IsExpired()) + { + Ops.GatherReferences(GcCtx); + } + }); +} + +uint64_t +ProjectStore::Project::TotalSize() const +{ + uint64_t Result = 0; + { + RwLock::SharedLockScope _(m_ProjectLock); + for (const auto& It : m_Oplogs) + { + Result += It.second->TotalSize(); + } + } + return Result; +} + +bool +ProjectStore::Project::PrepareForDelete(std::filesystem::path& OutDeletePath) +{ + RwLock::ExclusiveLockScope _(m_ProjectLock); + + for (auto& It : m_Oplogs) + { + // We don't care about the moved folder + It.second->PrepareForDelete(false); + m_DeletedOplogs.emplace_back(std::move(It.second)); + } + + m_Oplogs.clear(); + + bool Success = PrepareDirectoryDelete(m_OplogStoragePath, OutDeletePath); + if (!Success) + { + return false; + } + m_OplogStoragePath.clear(); + return true; +} + +bool +ProjectStore::Project::IsExpired() const +{ + if (ProjectFilePath.empty()) + { + return false; + } + return !std::filesystem::exists(ProjectFilePath); +} + +////////////////////////////////////////////////////////////////////////// + +ProjectStore::ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc) +: GcStorage(Gc) +, GcContributor(Gc) +, m_Log(logging::Get("project")) +, m_CidStore(Store) +, m_ProjectBasePath(BasePath) +{ + ZEN_INFO("initializing project store at '{}'", BasePath); + // m_Log.set_level(spdlog::level::debug); +} + +ProjectStore::~ProjectStore() +{ + ZEN_INFO("closing project store ('{}')", m_ProjectBasePath); +} + +std::filesystem::path +ProjectStore::BasePathForProject(std::string_view ProjectId) +{ + return m_ProjectBasePath / ProjectId; +} + +void +ProjectStore::DiscoverProjects() +{ + if (!std::filesystem::exists(m_ProjectBasePath)) + { + return; + } + + DirectoryContent DirContent; + GetDirectoryContent(m_ProjectBasePath, DirectoryContent::IncludeDirsFlag, DirContent); + + for (const std::filesystem::path& DirPath : DirContent.Directories) + { + std::string DirName = PathToUtf8(DirPath.filename()); + OpenProject(DirName); + } +} + +void +ProjectStore::IterateProjects(std::function<void(Project& Prj)>&& Fn) +{ + RwLock::SharedLockScope _(m_ProjectsLock); + + for (auto& Kv : m_Projects) + { + Fn(*Kv.second.Get()); + } +} + +void +ProjectStore::Flush() +{ + std::vector<Ref<Project>> Projects; + { + RwLock::SharedLockScope _(m_ProjectsLock); + Projects.reserve(m_Projects.size()); + + for (auto& Kv : m_Projects) + { + Projects.push_back(Kv.second); + } + } + for (const Ref<Project>& Project : Projects) + { + Project->Flush(); + } +} + +void +ProjectStore::Scrub(ScrubContext& Ctx) +{ + DiscoverProjects(); + + std::vector<Ref<Project>> Projects; + { + RwLock::SharedLockScope _(m_ProjectsLock); + Projects.reserve(m_Projects.size()); + + for (auto& Kv : m_Projects) + { + if (Kv.second->IsExpired()) + { + continue; + } + Projects.push_back(Kv.second); + } + } + for (const Ref<Project>& Project : Projects) + { + Project->Scrub(Ctx); + } +} + +void +ProjectStore::GatherReferences(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("ProjectStore::GatherReferences"); + + size_t ProjectCount = 0; + size_t ExpiredProjectCount = 0; + Stopwatch Timer; + const auto Guard = MakeGuard([&] { + ZEN_DEBUG("gathered references from '{}' in {}, found {} active projects and {} expired projects", + m_ProjectBasePath.string(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs()), + ProjectCount, + ExpiredProjectCount); + }); + + DiscoverProjects(); + + std::vector<Ref<Project>> Projects; + { + RwLock::SharedLockScope _(m_ProjectsLock); + Projects.reserve(m_Projects.size()); + + for (auto& Kv : m_Projects) + { + if (Kv.second->IsExpired()) + { + ExpiredProjectCount++; + continue; + } + Projects.push_back(Kv.second); + } + } + ProjectCount = Projects.size(); + for (const Ref<Project>& Project : Projects) + { + Project->GatherReferences(GcCtx); + } +} + +void +ProjectStore::CollectGarbage(GcContext& GcCtx) +{ + ZEN_TRACE_CPU("ProjectStore::CollectGarbage"); + + size_t ProjectCount = 0; + size_t ExpiredProjectCount = 0; + + Stopwatch Timer; + const auto Guard = MakeGuard([&] { + ZEN_DEBUG("garbage collect from '{}' DONE after {}, found {} active projects and {} expired projects", + m_ProjectBasePath.string(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs()), + ProjectCount, + ExpiredProjectCount); + }); + std::vector<Ref<Project>> ExpiredProjects; + std::vector<Ref<Project>> Projects; + + { + RwLock::SharedLockScope _(m_ProjectsLock); + for (auto& Kv : m_Projects) + { + if (Kv.second->IsExpired()) + { + ExpiredProjects.push_back(Kv.second); + ExpiredProjectCount++; + continue; + } + Projects.push_back(Kv.second); + ProjectCount++; + } + } + + if (!GcCtx.IsDeletionMode()) + { + ZEN_DEBUG("garbage collect DISABLED, for '{}' ", m_ProjectBasePath.string()); + return; + } + + for (const Ref<Project>& Project : Projects) + { + std::vector<std::string> ExpiredOplogs; + { + RwLock::ExclusiveLockScope _(m_ProjectsLock); + Project->IterateOplogs([&ExpiredOplogs](ProjectStore::Oplog& Oplog) { + if (Oplog.IsExpired()) + { + ExpiredOplogs.push_back(Oplog.OplogId()); + } + }); + } + for (const std::string& OplogId : ExpiredOplogs) + { + ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected oplog '{}' in project '{}'. Removing storage on disk", + OplogId, + Project->Identifier); + Project->DeleteOplog(OplogId); + } + } + + if (ExpiredProjects.empty()) + { + ZEN_DEBUG("garbage collect for '{}', no expired projects found", m_ProjectBasePath.string()); + return; + } + + for (const Ref<Project>& Project : ExpiredProjects) + { + std::filesystem::path PathToRemove; + std::string ProjectId; + { + RwLock::ExclusiveLockScope _(m_ProjectsLock); + if (!Project->IsExpired()) + { + ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. Project no longer expired.", ProjectId); + continue; + } + bool Success = Project->PrepareForDelete(PathToRemove); + if (!Success) + { + ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. Project folder is locked.", ProjectId); + continue; + } + m_Projects.erase(Project->Identifier); + ProjectId = Project->Identifier; + } + + ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected project '{}'. Removing storage on disk", ProjectId); + if (PathToRemove.empty()) + { + continue; + } + + DeleteDirectories(PathToRemove); + } +} + +GcStorageSize +ProjectStore::StorageSize() const +{ + GcStorageSize Result; + { + RwLock::SharedLockScope _(m_ProjectsLock); + for (auto& Kv : m_Projects) + { + const Ref<Project>& Project = Kv.second; + Result.DiskSize += Project->TotalSize(); + } + } + return Result; +} + +Ref<ProjectStore::Project> +ProjectStore::OpenProject(std::string_view ProjectId) +{ + { + RwLock::SharedLockScope _(m_ProjectsLock); + + auto ProjIt = m_Projects.find(std::string{ProjectId}); + + if (ProjIt != m_Projects.end()) + { + return ProjIt->second; + } + } + + RwLock::ExclusiveLockScope _(m_ProjectsLock); + + std::filesystem::path BasePath = BasePathForProject(ProjectId); + + if (Project::Exists(BasePath)) + { + try + { + ZEN_INFO("opening project {} @ {}", ProjectId, BasePath); + + Ref<Project>& Prj = + m_Projects + .try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath))) + .first->second; + Prj->Identifier = ProjectId; + Prj->Read(); + return Prj; + } + catch (std::exception& e) + { + ZEN_WARN("failed to open {} @ {} ({})", ProjectId, BasePath, e.what()); + m_Projects.erase(std::string{ProjectId}); + } + } + + return {}; +} + +Ref<ProjectStore::Project> +ProjectStore::NewProject(std::filesystem::path BasePath, + std::string_view ProjectId, + std::string_view RootDir, + std::string_view EngineRootDir, + std::string_view ProjectRootDir, + std::string_view ProjectFilePath) +{ + RwLock::ExclusiveLockScope _(m_ProjectsLock); + + Ref<Project>& Prj = + m_Projects.try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath))) + .first->second; + Prj->Identifier = ProjectId; + Prj->RootDir = RootDir; + Prj->EngineRootDir = EngineRootDir; + Prj->ProjectRootDir = ProjectRootDir; + Prj->ProjectFilePath = ProjectFilePath; + Prj->Write(); + + return Prj; +} + +bool +ProjectStore::DeleteProject(std::string_view ProjectId) +{ + ZEN_INFO("deleting project {}", ProjectId); + + RwLock::ExclusiveLockScope ProjectsLock(m_ProjectsLock); + + auto ProjIt = m_Projects.find(std::string{ProjectId}); + + if (ProjIt == m_Projects.end()) + { + return true; + } + + std::filesystem::path DeletePath; + bool Success = ProjIt->second->PrepareForDelete(DeletePath); + + if (!Success) + { + return false; + } + m_Projects.erase(ProjIt); + ProjectsLock.ReleaseNow(); + + if (!DeletePath.empty()) + { + DeleteDirectories(DeletePath); + } + return true; +} + +bool +ProjectStore::Exists(std::string_view ProjectId) +{ + return Project::Exists(BasePathForProject(ProjectId)); +} + +CbArray +ProjectStore::GetProjectsList() +{ + using namespace std::literals; + + DiscoverProjects(); + + CbWriter Response; + Response.BeginArray(); + + IterateProjects([&Response](ProjectStore::Project& Prj) { + Response.BeginObject(); + Response << "Id"sv << Prj.Identifier; + Response << "RootDir"sv << Prj.RootDir.string(); + Response << "ProjectRootDir"sv << Prj.ProjectRootDir; + Response << "EngineRootDir"sv << Prj.EngineRootDir; + Response << "ProjectFilePath"sv << Prj.ProjectFilePath; + Response.EndObject(); + }); + Response.EndArray(); + return Response.Save().AsArray(); +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::GetProjectFiles(const std::string_view ProjectId, const std::string_view OplogId, bool FilterClient, CbObject& OutPayload) +{ + using namespace std::literals; + + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Project files request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Project files for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + CbObjectWriter Response; + Response.BeginArray("files"sv); + + FoundLog->IterateFileMap([&](const Oid& Id, const std::string_view& ServerPath, const std::string_view& ClientPath) { + Response.BeginObject(); + Response << "id"sv << Id; + Response << "clientpath"sv << ClientPath; + if (!FilterClient) + { + Response << "serverpath"sv << ServerPath; + } + Response.EndObject(); + }); + + Response.EndArray(); + OutPayload = Response.Save(); + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::GetChunkInfo(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + CbObject& OutPayload) +{ + using namespace std::literals; + + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + if (ChunkId.size() != 2 * sizeof(Oid::OidBits)) + { + return {HttpResponseCode::BadRequest, + fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)}; + } + + const Oid Obj = Oid::FromHexString(ChunkId); + + IoBuffer Chunk = FoundLog->FindChunk(Obj); + if (!Chunk) + { + return {HttpResponseCode::NotFound, {}}; + } + + uint64_t ChunkSize = Chunk.GetSize(); + if (Chunk.GetContentType() == HttpContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + bool IsCompressed = CompressedBuffer::ValidateCompressedHeader(Chunk, RawHash, RawSize); + ZEN_ASSERT(IsCompressed); + ChunkSize = RawSize; + } + + CbObjectWriter Response; + Response << "size"sv << ChunkSize; + OutPayload = Response.Save(); + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::GetChunkRange(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + uint64_t Offset, + uint64_t Size, + ZenContentType AcceptType, + IoBuffer& OutChunk) +{ + bool IsOffset = Offset != 0 || Size != ~(0ull); + + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (ChunkId.size() != 2 * sizeof(Oid::OidBits)) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)}; + } + + const Oid Obj = Oid::FromHexString(ChunkId); + + IoBuffer Chunk = FoundLog->FindChunk(Obj); + if (!Chunk) + { + return {HttpResponseCode::NotFound, {}}; + } + + OutChunk = Chunk; + HttpContentType ContentType = Chunk.GetContentType(); + + if (Chunk.GetContentType() == HttpContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(std::move(Chunk)), RawHash, RawSize); + ZEN_ASSERT(!Compressed.IsNull()); + + if (IsOffset) + { + if ((Offset + Size) > RawSize) + { + Size = RawSize - Offset; + } + + if (AcceptType == HttpContentType::kBinary) + { + OutChunk = Compressed.Decompress(Offset, Size).AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kBinary); + } + else + { + // Value will be a range of compressed blocks that covers the requested range + // The client will have to compensate for any offsets that do not land on an even block size multiple + OutChunk = Compressed.CopyRange(Offset, Size).GetCompressed().Flatten().AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kCompressedBinary); + } + } + else + { + if (AcceptType == HttpContentType::kBinary) + { + OutChunk = Compressed.Decompress().AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kBinary); + } + else + { + OutChunk = Compressed.GetCompressed().Flatten().AsIoBuffer(); + OutChunk.SetContentType(HttpContentType::kCompressedBinary); + } + } + } + else if (IsOffset) + { + if ((Offset + Size) > Chunk.GetSize()) + { + Size = Chunk.GetSize() - Offset; + } + OutChunk = IoBuffer(std::move(Chunk), Offset, Size); + OutChunk.SetContentType(ContentType); + } + + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::GetChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType AcceptType, + IoBuffer& OutChunk) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (Cid.length() != IoHash::StringLength) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, Cid)}; + } + + const IoHash Hash = IoHash::FromHexString(Cid); + OutChunk = m_CidStore.FindChunkByCid(Hash); + + if (!OutChunk) + { + return {HttpResponseCode::NotFound, fmt::format("chunk - '{}' MISSING", Cid)}; + } + + if (AcceptType == ZenContentType::kUnknownContentType || AcceptType == ZenContentType::kBinary) + { + CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(OutChunk)); + OutChunk = Compressed.Decompress().AsIoBuffer(); + OutChunk.SetContentType(ZenContentType::kBinary); + } + else + { + OutChunk.SetContentType(ZenContentType::kCompressedBinary); + } + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::PutChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType ContentType, + IoBuffer&& Chunk) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (Cid.length() != IoHash::StringLength) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk put request for invalid chunk hash '{}'", Cid)}; + } + + const IoHash Hash = IoHash::FromHexString(Cid); + + if (ContentType != HttpContentType::kCompressedBinary) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid content type for chunk '{}'", Cid)}; + } + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize); + if (RawHash != Hash) + { + return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid payload format for chunk '{}'", Cid)}; + } + + CidStore::InsertResult Result = m_CidStore.AddChunk(Chunk, Hash); + return {Result.New ? HttpResponseCode::Created : HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::WriteOplog(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload, CbObject& OutResponse) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + CbObject ContainerObject = LoadCompactBinaryObject(Payload); + if (!ContainerObject) + { + return {HttpResponseCode::BadRequest, "Invalid payload format"}; + } + + CidStore& ChunkStore = m_CidStore; + RwLock AttachmentsLock; + std::unordered_set<IoHash, IoHash::Hasher> Attachments; + + auto HasAttachment = [&ChunkStore](const IoHash& RawHash) { return ChunkStore.ContainsChunk(RawHash); }; + auto OnNeedBlock = [&AttachmentsLock, &Attachments](const IoHash& BlockHash, const std::vector<IoHash>&& ChunkHashes) { + RwLock::ExclusiveLockScope _(AttachmentsLock); + if (BlockHash != IoHash::Zero) + { + Attachments.insert(BlockHash); + } + else + { + Attachments.insert(ChunkHashes.begin(), ChunkHashes.end()); + } + }; + auto OnNeedAttachment = [&AttachmentsLock, &Attachments](const IoHash& RawHash) { + RwLock::ExclusiveLockScope _(AttachmentsLock); + Attachments.insert(RawHash); + }; + + RemoteProjectStore::Result RemoteResult = SaveOplogContainer(*Oplog, ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment); + + if (RemoteResult.ErrorCode) + { + return ConvertResult(RemoteResult); + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + { + for (const IoHash& Hash : Attachments) + { + ZEN_DEBUG("Need attachment {}", Hash); + Cbo << Hash; + } + } + Cbo.EndArray(); // "need" + + OutResponse = Cbo.Save(); + return {HttpResponseCode::OK, {}}; +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::ReadOplog(const std::string_view ProjectId, + const std::string_view OplogId, + const HttpServerRequest::QueryParams& Params, + CbObject& OutResponse) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + size_t MaxBlockSize = 128u * 1024u * 1024u; + if (auto Param = Params.GetValue("maxblocksize"); Param.empty() == false) + { + if (auto Value = ParseInt<size_t>(Param)) + { + MaxBlockSize = Value.value(); + } + } + size_t MaxChunkEmbedSize = 1024u * 1024u; + if (auto Param = Params.GetValue("maxchunkembedsize"); Param.empty() == false) + { + if (auto Value = ParseInt<size_t>(Param)) + { + MaxChunkEmbedSize = Value.value(); + } + } + + CidStore& ChunkStore = m_CidStore; + + RemoteProjectStore::LoadContainerResult ContainerResult = BuildContainer( + ChunkStore, + *Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + false, + [](CompressedBuffer&&, const IoHash) {}, + [](const IoHash&) {}, + [](const std::unordered_set<IoHash, IoHash::Hasher>) {}); + + OutResponse = std::move(ContainerResult.ContainerObject); + return ConvertResult(ContainerResult); +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::WriteBlock(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload) +{ + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return {HttpResponseCode::NotFound, fmt::format("Write block request for unknown project '{}'", ProjectId)}; + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return {HttpResponseCode::NotFound, fmt::format("Write block request for unknown oplog '{}/{}'", ProjectId, OplogId)}; + } + + if (!IterateBlock(std::move(Payload), [this](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { + IoBuffer Compressed = Chunk.GetCompressed().Flatten().AsIoBuffer(); + m_CidStore.AddChunk(Compressed, AttachmentRawHash); + ZEN_DEBUG("Saved attachment {} from block, size {}", AttachmentRawHash, Compressed.GetSize()); + })) + { + return {HttpResponseCode::BadRequest, "Invalid chunk in block"}; + } + + return {HttpResponseCode::OK, {}}; +} + +void +ProjectStore::Rpc(HttpServerRequest& HttpReq, + const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload, + AuthMgr& AuthManager) +{ + using namespace std::literals; + HttpContentType PayloadContentType = HttpReq.RequestContentType(); + CbPackage Package; + CbObject Cb; + switch (PayloadContentType) + { + case HttpContentType::kJSON: + case HttpContentType::kUnknownContentType: + case HttpContentType::kText: + { + std::string JsonText(reinterpret_cast<const char*>(Payload.GetData()), Payload.GetSize()); + Cb = LoadCompactBinaryFromJson(JsonText).AsObject(); + if (!Cb) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Content format not supported, expected JSON format"); + } + } + break; + case HttpContentType::kCbObject: + Cb = LoadCompactBinaryObject(Payload); + if (!Cb) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Content format not supported, expected compact binary format"); + } + break; + case HttpContentType::kCbPackage: + Package = ParsePackageMessage(Payload); + Cb = Package.GetObject(); + if (!Cb) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, + HttpContentType::kText, + "Content format not supported, expected package message format"); + } + break; + default: + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid request content type"); + } + + Ref<ProjectStore::Project> Project = OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("Rpc oplog request for unknown project '{}'", ProjectId)); + } + + ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId); + + if (!Oplog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("Rpc oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)); + } + + std::string_view Method = Cb["method"sv].AsString(); + + if (Method == "import") + { + std::pair<HttpResponseCode, std::string> Result = Import(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager); + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + else if (Method == "export") + { + std::pair<HttpResponseCode, std::string> Result = Export(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager); + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + else if (Method == "getchunks") + { + CbPackage ResponsePackage; + { + CbArrayView ChunksArray = Cb["chunks"sv].AsArrayView(); + CbObjectWriter ResponseWriter; + ResponseWriter.BeginArray("chunks"sv); + for (CbFieldView FieldView : ChunksArray) + { + IoHash RawHash = FieldView.AsHash(); + IoBuffer ChunkBuffer = m_CidStore.FindChunkByCid(RawHash); + if (ChunkBuffer) + { + ResponseWriter.AddHash(RawHash); + ResponsePackage.AddAttachment( + CbAttachment(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkBuffer)), RawHash)); + } + } + ResponseWriter.EndArray(); + ResponsePackage.SetObject(ResponseWriter.Save()); + } + CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage, FormatFlags::kDefault); + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer); + } + else if (Method == "putchunks") + { + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + for (const CbAttachment& Attachment : Attachments) + { + IoHash RawHash = Attachment.GetHash(); + CompressedBuffer Compressed = Attachment.AsCompressedBinary(); + m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), RawHash, CidStore::InsertMode::kCopyOnly); + } + return HttpReq.WriteResponse(HttpResponseCode::OK); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("Unknown rpc method '{}'", Method)); +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::Export(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager) +{ + using namespace std::literals; + + size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u); + size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u); + bool Force = Params["force"sv].AsBool(false); + + std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult = + CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize); + + if (RemoteStoreResult.first == nullptr) + { + return {HttpResponseCode::BadRequest, RemoteStoreResult.second}; + } + std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first); + RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo(); + + ZEN_INFO("Saving oplog '{}/{}' to {}, maxblocksize {}, maxchunkembedsize {}", + Project.Identifier, + Oplog.OplogId(), + StoreInfo.Description, + NiceBytes(MaxBlockSize), + NiceBytes(MaxChunkEmbedSize)); + + RemoteProjectStore::Result Result = SaveOplog(m_CidStore, + *RemoteStore, + Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + StoreInfo.CreateBlocks, + StoreInfo.UseTempBlockFiles, + Force); + + return ConvertResult(Result); +} + +std::pair<HttpResponseCode, std::string> +ProjectStore::Import(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager) +{ + using namespace std::literals; + + size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u); + size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u); + bool Force = Params["force"sv].AsBool(false); + + std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult = + CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize); + + if (RemoteStoreResult.first == nullptr) + { + return {HttpResponseCode::BadRequest, RemoteStoreResult.second}; + } + std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first); + RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo(); + + ZEN_INFO("Loading oplog '{}/{}' from {}", Project.Identifier, Oplog.OplogId(), StoreInfo.Description); + RemoteProjectStore::Result Result = LoadOplog(m_CidStore, *RemoteStore, Oplog, Force); + return ConvertResult(Result); +} + +////////////////////////////////////////////////////////////////////////// + +HttpProjectService::HttpProjectService(CidStore& Store, ProjectStore* Projects, HttpStatsService& StatsService, AuthMgr& AuthMgr) +: m_Log(logging::Get("project")) +, m_CidStore(Store) +, m_ProjectStore(Projects) +, m_StatsService(StatsService) +, m_AuthMgr(AuthMgr) +{ + using namespace std::literals; + + m_StatsService.RegisterHandler("prj", *this); + + m_Router.AddPattern("project", "([[:alnum:]_.]+)"); + m_Router.AddPattern("log", "([[:alnum:]_.]+)"); + m_Router.AddPattern("op", "([[:digit:]]+?)"); + m_Router.AddPattern("chunk", "([[:xdigit:]]{24})"); + m_Router.AddPattern("hash", "([[:xdigit:]]{40})"); + + m_Router.RegisterRoute( + "", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "list", + [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/batch", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + // Parse Request + + IoBuffer Payload = HttpReq.ReadPayload(); + BinaryReader Reader(Payload); + + struct RequestHeader + { + enum + { + kMagic = 0xAAAA'77AC + }; + uint32_t Magic; + uint32_t ChunkCount; + uint32_t Reserved1; + uint32_t Reserved2; + }; + + struct RequestChunkEntry + { + Oid ChunkId; + uint32_t CorrelationId; + uint64_t Offset; + uint64_t RequestBytes; + }; + + if (Payload.Size() <= sizeof(RequestHeader)) + { + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + RequestHeader RequestHdr; + Reader.Read(&RequestHdr, sizeof RequestHdr); + + if (RequestHdr.Magic != RequestHeader::kMagic) + { + HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + std::vector<RequestChunkEntry> RequestedChunks; + RequestedChunks.resize(RequestHdr.ChunkCount); + Reader.Read(RequestedChunks.data(), sizeof(RequestChunkEntry) * RequestHdr.ChunkCount); + + // Make Response + + struct ResponseHeader + { + uint32_t Magic = 0xbada'b00f; + uint32_t ChunkCount; + uint32_t Reserved1 = 0; + uint32_t Reserved2 = 0; + }; + + struct ResponseChunkEntry + { + uint32_t CorrelationId; + uint32_t Flags = 0; + uint64_t ChunkSize; + }; + + std::vector<IoBuffer> OutBlobs; + OutBlobs.emplace_back(sizeof(ResponseHeader) + RequestHdr.ChunkCount * sizeof(ResponseChunkEntry)); + for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex) + { + const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex]; + IoBuffer FoundChunk = FoundLog->FindChunk(RequestedChunk.ChunkId); + if (FoundChunk) + { + if (RequestedChunk.Offset > 0 || RequestedChunk.RequestBytes < uint64_t(-1)) + { + uint64_t Offset = RequestedChunk.Offset; + if (Offset > FoundChunk.Size()) + { + Offset = FoundChunk.Size(); + } + uint64_t Size = RequestedChunk.RequestBytes; + if ((Offset + Size) > FoundChunk.Size()) + { + Size = FoundChunk.Size() - Offset; + } + FoundChunk = IoBuffer(FoundChunk, Offset, Size); + } + } + OutBlobs.emplace_back(std::move(FoundChunk)); + } + uint8_t* ResponsePtr = reinterpret_cast<uint8_t*>(OutBlobs[0].MutableData()); + ResponseHeader ResponseHdr; + ResponseHdr.ChunkCount = RequestHdr.ChunkCount; + memcpy(ResponsePtr, &ResponseHdr, sizeof(ResponseHdr)); + ResponsePtr += sizeof(ResponseHdr); + for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex) + { + // const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex]; + const IoBuffer& FoundChunk(OutBlobs[ChunkIndex + 1]); + ResponseChunkEntry ResponseChunk; + ResponseChunk.CorrelationId = ChunkIndex; + if (FoundChunk) + { + ResponseChunk.ChunkSize = FoundChunk.Size(); + } + else + { + ResponseChunk.ChunkSize = uint64_t(-1); + } + memcpy(ResponsePtr, &ResponseChunk, sizeof(ResponseChunk)); + ResponsePtr += sizeof(ResponseChunk); + } + return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, OutBlobs); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/files", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + // File manifest fetch, returns the client file list + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + + const bool FilterClient = Params.GetValue("filter"sv) == "client"sv; + + CbObject ResponsePayload; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->GetProjectFiles(ProjectId, OplogId, FilterClient, ResponsePayload); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{chunk}/info", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& ChunkId = Req.GetCapture(3); + + CbObject ResponsePayload; + std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->GetChunkInfo(ProjectId, OplogId, ChunkId, ResponsePayload); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload); + } + else if (Result.first == HttpResponseCode::NotFound) + { + ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{chunk}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& ChunkId = Req.GetCapture(3); + + uint64_t Offset = 0; + uint64_t Size = ~(0ull); + + auto QueryParms = Req.ServerRequest().GetQueryParams(); + + if (auto OffsetParm = QueryParms.GetValue("offset"); OffsetParm.empty() == false) + { + if (auto OffsetVal = ParseInt<uint64_t>(OffsetParm)) + { + Offset = OffsetVal.value(); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + } + + if (auto SizeParm = QueryParms.GetValue("size"); SizeParm.empty() == false) + { + if (auto SizeVal = ParseInt<uint64_t>(SizeParm)) + { + Size = SizeVal.value(); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + } + + HttpContentType AcceptType = HttpReq.AcceptContentType(); + + IoBuffer Chunk; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->GetChunkRange(ProjectId, OplogId, ChunkId, Offset, Size, AcceptType, Chunk); + if (Result.first == HttpResponseCode::OK) + { + ZEN_DEBUG("chunk - '{}/{}/{}' '{}'", ProjectId, OplogId, ChunkId, ToString(Chunk.GetContentType())); + return HttpReq.WriteResponse(HttpResponseCode::OK, Chunk.GetContentType(), Chunk); + } + else if (Result.first == HttpResponseCode::NotFound) + { + ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet | HttpVerb::kHead); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{hash}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& Cid = Req.GetCapture(3); + HttpContentType AcceptType = HttpReq.AcceptContentType(); + HttpContentType RequestType = HttpReq.RequestContentType(); + + switch (Req.ServerRequest().RequestVerb()) + { + case HttpVerb::kGet: + { + IoBuffer Value; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->GetChunk(ProjectId, OplogId, Cid, AcceptType, Value); + + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Value.GetContentType(), Value); + } + else if (Result.first == HttpResponseCode::NotFound) + { + ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, Cid); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + case HttpVerb::kPost: + { + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->PutChunk(ProjectId, OplogId, Cid, RequestType, HttpReq.ReadPayload()); + if (Result.first == HttpResponseCode::OK || Result.first == HttpResponseCode::Created) + { + return HttpReq.WriteResponse(Result.first); + } + else + { + ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`", + ToString(HttpReq.RequestVerb()), + HttpReq.QueryString(), + static_cast<int>(Result.first), + Result.second); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + } + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/prep", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + // This operation takes a list of referenced hashes and decides which + // chunks are not present on this server. This list is then returned in + // the "need" list in the response + + IoBuffer Payload = HttpReq.ReadPayload(); + CbObject RequestObject = LoadCompactBinaryObject(Payload); + + std::vector<IoHash> NeedList; + + for (auto Entry : RequestObject["have"sv]) + { + const IoHash FileHash = Entry.AsHash(); + + if (!m_CidStore.ContainsChunk(FileHash)) + { + ZEN_DEBUG("prep - NEED: {}", FileHash); + + NeedList.push_back(FileHash); + } + } + + CbObjectWriter Cbo; + Cbo.BeginArray("need"); + + for (const IoHash& Hash : NeedList) + { + Cbo << Hash; + } + + Cbo.EndArray(); + CbObject Response = Cbo.Save(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/new", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + + bool IsUsingSalt = false; + IoHash SaltHash = IoHash::Zero; + + if (std::string_view SaltParam = Params.GetValue("salt"); SaltParam.empty() == false) + { + const uint32_t Salt = std::stoi(std::string(SaltParam)); + SaltHash = IoHash::HashBuffer(&Salt, sizeof Salt); + IsUsingSalt = true; + } + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog& Oplog = *FoundLog; + + IoBuffer Payload = HttpReq.ReadPayload(); + + // This will attempt to open files which may not exist for the case where + // the prep step rejected the chunk. This should be fixed since there's + // a performance cost associated with any file system activity + + bool IsValid = true; + std::vector<IoHash> MissingChunks; + + CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer { + if (m_CidStore.ContainsChunk(Hash)) + { + // Return null attachment as we already have it, no point in reading it and storing it again + return {}; + } + + IoHash AttachmentId; + if (IsUsingSalt) + { + IoHash AttachmentSpec[]{SaltHash, Hash}; + AttachmentId = IoHash::HashBuffer(MakeMemoryView(AttachmentSpec)); + } + else + { + AttachmentId = Hash; + } + + std::filesystem::path AttachmentPath = Oplog.TempPath() / AttachmentId.ToHexString(); + if (IoBuffer Data = IoBufferBuilder::MakeFromTemporaryFile(AttachmentPath)) + { + return SharedBuffer(std::move(Data)); + } + else + { + IsValid = false; + MissingChunks.push_back(Hash); + + return {}; + } + }; + + CbPackage Package; + + if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver)) + { + std::filesystem::path BadPackagePath = + Oplog.TempPath() / "bad_packages"sv / fmt::format("session{}_request{}"sv, HttpReq.SessionId(), HttpReq.RequestId()); + + ZEN_WARN("Received malformed package! Saving payload to '{}'", BadPackagePath); + + WriteFile(BadPackagePath, Payload); + + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"); + } + + if (!IsValid) + { + // TODO: emit diagnostics identifying missing chunks + + return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Missing chunk reference"); + } + + CbObject Core = Package.GetObject(); + + if (!Core["key"sv]) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "No oplog entry key specified"); + } + + // Write core to oplog + + const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Package); + + if (OpLsn == ProjectStore::Oplog::kInvalidOp) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest); + } + + ZEN_DEBUG("'{}/{}' op #{} ({}) - '{}'", ProjectId, OplogId, OpLsn, NiceBytes(Payload.Size()), Core["key"sv].AsString()); + + HttpReq.WriteResponse(HttpResponseCode::Created); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/{op}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const std::string& ProjectId = Req.GetCapture(1); + const std::string& OplogId = Req.GetCapture(2); + const std::string& OpIdString = Req.GetCapture(3); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog& Oplog = *FoundLog; + + if (const std::optional<int32_t> OpId = ParseInt<uint32_t>(OpIdString)) + { + if (std::optional<CbObject> MaybeOp = Oplog.GetOpByIndex(OpId.value())) + { + CbObject& Op = MaybeOp.value(); + if (Req.ServerRequest().AcceptContentType() == ZenContentType::kCbPackage) + { + CbPackage Package; + Package.SetObject(Op); + + Op.IterateAttachments([&](CbFieldView FieldView) { + const IoHash AttachmentHash = FieldView.AsAttachment(); + IoBuffer Payload = m_CidStore.FindChunkByCid(AttachmentHash); + + // We force this for now as content type is not consistently tracked (will + // be fixed in CidStore refactor) + Payload.SetContentType(ZenContentType::kCompressedBinary); + + if (Payload) + { + switch (Payload.GetContentType()) + { + case ZenContentType::kCbObject: + if (CbObject Object = LoadCompactBinaryObject(Payload)) + { + Package.AddAttachment(CbAttachment(Object)); + } + else + { + // Error - malformed object + + ZEN_WARN("malformed object returned for {}", AttachmentHash); + } + break; + + case ZenContentType::kCompressedBinary: + if (CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Payload))) + { + Package.AddAttachment(CbAttachment(Compressed, AttachmentHash)); + } + else + { + // Error - not compressed! + + ZEN_WARN("invalid compressed binary returned for {}", AttachmentHash); + } + break; + + default: + Package.AddAttachment(CbAttachment(SharedBuffer(Payload))); + break; + } + } + }); + + return HttpReq.WriteResponse(HttpResponseCode::Accepted, Package); + } + else + { + // Client cannot accept a package, so we only send the core object + return HttpReq.WriteResponse(HttpResponseCode::Accepted, Op); + } + } + } + + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}/oplog/{log}", + [this](HttpRouterRequest& Req) { + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + + if (!Project) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("project {} not found", ProjectId)); + } + + switch (Req.ServerRequest().RequestVerb()) + { + case HttpVerb::kGet: + { + ProjectStore::Oplog* OplogIt = Project->OpenOplog(OplogId); + + if (!OplogIt) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("oplog {} not found in project {}", OplogId, ProjectId)); + } + + ProjectStore::Oplog& Log = *OplogIt; + + CbObjectWriter Cb; + Cb << "id"sv << Log.OplogId() << "project"sv << Project->Identifier << "tempdir"sv << Log.TempPath().c_str() + << "markerpath"sv << Log.MarkerPath().c_str() << "totalsize"sv << Log.TotalSize() << "opcount" + << Log.OplogCount() << "expired"sv << Log.IsExpired(); + + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cb.Save()); + } + break; + + case HttpVerb::kPost: + { + std::filesystem::path OplogMarkerPath; + if (CbObject Params = Req.ServerRequest().ReadPayloadObject()) + { + OplogMarkerPath = Params["gcpath"sv].AsString(); + } + + ProjectStore::Oplog* OplogIt = Project->OpenOplog(OplogId); + + if (!OplogIt) + { + if (!Project->NewOplog(OplogId, OplogMarkerPath)) + { + // TODO: indicate why the operation failed! + return Req.ServerRequest().WriteResponse(HttpResponseCode::InternalServerError); + } + + ZEN_INFO("established oplog '{}/{}', gc marker file at '{}'", ProjectId, OplogId, OplogMarkerPath); + + return Req.ServerRequest().WriteResponse(HttpResponseCode::Created); + } + + // I guess this should ultimately be used to execute RPCs but for now, it + // does absolutely nothing + + return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest); + } + break; + + case HttpVerb::kDelete: + { + ZEN_INFO("deleting oplog '{}/{}'", ProjectId, OplogId); + + Project->DeleteOplog(OplogId); + + return Req.ServerRequest().WriteResponse(HttpResponseCode::OK); + } + break; + + default: + break; + } + }, + HttpVerb::kPost | HttpVerb::kGet | HttpVerb::kDelete); + + m_Router.RegisterRoute( + "{project}/oplog/{log}/entries", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + if (!Project) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + CbObjectWriter Response; + + if (FoundLog->OplogCount() > 0) + { + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + + if (auto OpKey = Params.GetValue("opkey"); !OpKey.empty()) + { + Oid OpKeyId = OpKeyStringAsOId(OpKey); + std::optional<CbObject> Op = FoundLog->GetOpByKey(OpKeyId); + + if (Op.has_value()) + { + Response << "entry"sv << Op.value(); + } + else + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + } + else + { + Response.BeginArray("entries"sv); + + FoundLog->IterateOplog([&Response](CbObject Op) { Response << Op; }); + + Response.EndArray(); + } + } + + return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "{project}", + [this](HttpRouterRequest& Req) { + const std::string ProjectId = Req.GetCapture(1); + + switch (Req.ServerRequest().RequestVerb()) + { + case HttpVerb::kPost: + { + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + CbObject Params = LoadCompactBinaryObject(Payload); + std::string_view Id = Params["id"sv].AsString(); + std::string_view Root = Params["root"sv].AsString(); + std::string_view EngineRoot = Params["engine"sv].AsString(); + std::string_view ProjectRoot = Params["project"sv].AsString(); + std::string_view ProjectFilePath = Params["projectfile"sv].AsString(); + + const std::filesystem::path BasePath = m_ProjectStore->BasePath() / ProjectId; + m_ProjectStore->NewProject(BasePath, ProjectId, Root, EngineRoot, ProjectRoot, ProjectFilePath); + + ZEN_INFO("established project - {} (id: '{}', roots: '{}', '{}', '{}', '{}'{})", + ProjectId, + Id, + Root, + EngineRoot, + ProjectRoot, + ProjectFilePath, + ProjectFilePath.empty() ? ", project will not be GCd due to empty project file path" : ""); + + Req.ServerRequest().WriteResponse(HttpResponseCode::Created); + } + break; + + case HttpVerb::kGet: + { + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + + if (!Project) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("project {} not found", ProjectId)); + } + + std::vector<std::string> OpLogs = Project->ScanForOplogs(); + + CbObjectWriter Response; + Response << "id"sv << Project->Identifier; + Response << "root"sv << PathToUtf8(Project->RootDir); + Response << "engine"sv << PathToUtf8(Project->EngineRootDir); + Response << "project"sv << PathToUtf8(Project->ProjectRootDir); + Response << "projectfile"sv << PathToUtf8(Project->ProjectFilePath); + + Response.BeginArray("oplogs"sv); + for (const std::string& OplogId : OpLogs) + { + Response.BeginObject(); + Response << "id"sv << OplogId; + Response.EndObject(); + } + Response.EndArray(); // oplogs + + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save()); + } + break; + + case HttpVerb::kDelete: + { + Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId); + + if (!Project) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound, + HttpContentType::kText, + fmt::format("project {} not found", ProjectId)); + } + + ZEN_INFO("deleting project '{}'", ProjectId); + if (!m_ProjectStore->DeleteProject(ProjectId)) + { + return Req.ServerRequest().WriteResponse(HttpResponseCode::Locked, + HttpContentType::kText, + fmt::format("project {} is in use", ProjectId)); + } + + return Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent); + } + break; + + default: + break; + } + }, + HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kDelete); + + // Push a oplog container + m_Router.RegisterRoute( + "{project}/oplog/{log}/save", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + if (HttpReq.RequestContentType() != HttpContentType::kCbObject) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid content type"); + } + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + + CbObject Response; + std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->WriteOplog(ProjectId, OplogId, std::move(Payload), Response); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kPost); + + // Pull a oplog container + m_Router.RegisterRoute( + "{project}/oplog/{log}/load", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + if (HttpReq.AcceptContentType() != HttpContentType::kCbObject) + { + return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid accept content type"); + } + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + + CbObject Response; + std::pair<HttpResponseCode, std::string> Result = + m_ProjectStore->ReadOplog(ProjectId, OplogId, Req.ServerRequest().GetQueryParams(), Response); + if (Result.first == HttpResponseCode::OK) + { + return HttpReq.WriteResponse(HttpResponseCode::OK, Response); + } + if (Result.second.empty()) + { + return HttpReq.WriteResponse(Result.first); + } + return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second); + }, + HttpVerb::kGet); + + // Do an rpc style operation on project/oplog + m_Router.RegisterRoute( + "{project}/oplog/{log}/rpc", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + IoBuffer Payload = Req.ServerRequest().ReadPayload(); + + m_ProjectStore->Rpc(HttpReq, ProjectId, OplogId, std::move(Payload), m_AuthMgr); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "details\\$", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) { + Project.IterateOplogs([&](ProjectStore::Oplog& Oplog) { + Oplog.IterateOplogWithKey( + [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) { + CSVWriteOp(m_CidStore, + Project.Identifier, + Oplog.OplogId(), + Details, + AttachmentDetails, + LSN, + Key, + Op, + CSVWriter); + }); + }); + }); + + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("projects"); + { + m_ProjectStore->DiscoverProjects(); + + m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) { + std::vector<std::string> OpLogs = Project.ScanForOplogs(); + CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo); + }); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "details\\$/{project}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId); + if (!FoundProject) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + ProjectStore::Project& Project = *FoundProject.Get(); + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + FoundProject->IterateOplogs([&](ProjectStore::Oplog& Oplog) { + Oplog.IterateOplogWithKey([this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, + const Oid& Key, + CbObject Op) { + CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter); + }); + }); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + std::vector<std::string> OpLogs = FoundProject->ScanForOplogs(); + Cbo.BeginArray("projects"); + { + CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "details\\$/{project}/{log}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId); + if (!FoundProject) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + ProjectStore::Project& Project = *FoundProject.Get(); + ProjectStore::Oplog& Oplog = *FoundLog; + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + Oplog.IterateOplogWithKey( + [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) { + CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter); + }); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("oplogs"); + { + CbWriteOplog(m_CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "details\\$/{project}/{log}/{chunk}", + [this](HttpRouterRequest& Req) { + HttpServerRequest& HttpReq = Req.ServerRequest(); + const auto& ProjectId = Req.GetCapture(1); + const auto& OplogId = Req.GetCapture(2); + const auto& ChunkId = Req.GetCapture(3); + + HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams(); + bool CSV = Params.GetValue("csv") == "true"; + bool Details = Params.GetValue("details") == "true"; + bool OpDetails = Params.GetValue("opdetails") == "true"; + bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true"; + + Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId); + if (!FoundProject) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId); + + if (!FoundLog) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + if (ChunkId.size() != 2 * sizeof(Oid::OidBits)) + { + return HttpReq.WriteResponse( + HttpResponseCode::BadRequest, + HttpContentType::kText, + fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)); + } + + const Oid ObjId = Oid::FromHexString(ChunkId); + ProjectStore::Project& Project = *FoundProject.Get(); + ProjectStore::Oplog& Oplog = *FoundLog; + + int LSN = Oplog.GetOpIndexByKey(ObjId); + if (LSN == -1) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + std::optional<CbObject> Op = Oplog.GetOpByIndex(LSN); + if (!Op.has_value()) + { + return HttpReq.WriteResponse(HttpResponseCode::NotFound); + } + + if (CSV) + { + ExtendableStringBuilder<4096> CSVWriter; + CSVHeader(Details, AttachmentDetails, CSVWriter); + + CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, ObjId, Op.value(), CSVWriter); + HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView()); + } + else + { + CbObjectWriter Cbo; + Cbo.BeginArray("ops"); + { + CbWriteOp(m_CidStore, Details, OpDetails, AttachmentDetails, LSN, ObjId, Op.value(), Cbo); + } + Cbo.EndArray(); + HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + }, + HttpVerb::kGet); +} + +HttpProjectService::~HttpProjectService() +{ + m_StatsService.UnregisterHandler("prj", *this); +} + +const char* +HttpProjectService::BaseUri() const +{ + return "/prj/"; +} + +void +HttpProjectService::HandleRequest(HttpServerRequest& Request) +{ + if (m_Router.HandleRequest(Request) == false) + { + ZEN_WARN("No route found for {0}", Request.RelativeUri()); + } +} + +void +HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq) +{ + const GcStorageSize StoreSize = m_ProjectStore->StorageSize(); + const CidStoreSize CidSize = m_CidStore.TotalSize(); + + CbObjectWriter Cbo; + Cbo.BeginObject("store"); + { + Cbo.BeginObject("size"); + { + Cbo << "disk" << StoreSize.DiskSize; + Cbo << "memory" << StoreSize.MemorySize; + } + Cbo.EndObject(); + } + Cbo.EndObject(); + + Cbo.BeginObject("cid"); + { + Cbo.BeginObject("size"); + { + Cbo << "tiny" << CidSize.TinySize; + Cbo << "small" << CidSize.SmallSize; + Cbo << "large" << CidSize.LargeSize; + Cbo << "total" << CidSize.TotalSize; + } + Cbo.EndObject(); + } + Cbo.EndObject(); + + return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save()); +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace testutils { + using namespace std::literals; + + std::string OidAsString(const Oid& Id) + { + StringBuilder<25> OidStringBuilder; + Id.ToString(OidStringBuilder); + return OidStringBuilder.ToString(); + } + + CbPackage CreateOplogPackage(const Oid& Id, const std::span<const std::pair<Oid, CompressedBuffer>>& Attachments) + { + CbPackage Package; + CbObjectWriter Object; + Object << "key"sv << OidAsString(Id); + if (!Attachments.empty()) + { + Object.BeginArray("bulkdata"); + for (const auto& Attachment : Attachments) + { + CbAttachment Attach(Attachment.second, Attachment.second.DecodeRawHash()); + Object.BeginObject(); + Object << "id"sv << Attachment.first; + Object << "type"sv + << "Standard"sv; + Object << "data"sv << Attach; + Object.EndObject(); + + Package.AddAttachment(Attach); + } + Object.EndArray(); + } + Package.SetObject(Object.Save()); + return Package; + }; + + std::vector<std::pair<Oid, CompressedBuffer>> CreateAttachments(const std::span<const size_t>& Sizes) + { + std::vector<std::pair<Oid, CompressedBuffer>> Result; + Result.reserve(Sizes.size()); + for (size_t Size : Sizes) + { + std::vector<uint8_t> Data; + Data.resize(Size); + uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data()); + for (size_t Idx = 0; Idx < Size / 2; ++Idx) + { + DataPtr[Idx] = static_cast<uint16_t>(Idx % 0xffffu); + } + if (Size & 1) + { + Data[Size - 1] = static_cast<uint8_t>((Size - 1) & 0xff); + } + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size())); + Result.emplace_back(std::pair<Oid, CompressedBuffer>(Oid::NewOid(), Compressed)); + } + return Result; + } + + uint64 GetCompressedOffset(const CompressedBuffer& Buffer, uint64 RawOffset) + { + if (RawOffset > 0) + { + uint64 BlockSize = 0; + OodleCompressor Compressor; + OodleCompressionLevel CompressionLevel; + if (!Buffer.TryGetCompressParameters(Compressor, CompressionLevel, BlockSize)) + { + return 0; + } + return BlockSize > 0 ? RawOffset % BlockSize : 0; + } + return 0; + } + +} // namespace testutils + +TEST_CASE("project.store.create") +{ + using namespace std::literals; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + std::string_view ProjectName("proj1"sv); + std::filesystem::path BasePath = TempDir.Path() / "projectstore"; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + std::filesystem::path ProjectRootDir = TempDir.Path() / "game"; + std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject"; + + Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / ProjectName, + ProjectName, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + CHECK(ProjectStore.DeleteProject(ProjectName)); + CHECK(!Project->Exists(BasePath)); +} + +TEST_CASE("project.store.lifetimes") +{ + using namespace std::literals; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + std::filesystem::path BasePath = TempDir.Path() / "projectstore"; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + std::filesystem::path ProjectRootDir = TempDir.Path() / "game"; + std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject"; + + Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + ProjectRootDir.string(), + ProjectFilePath.string())); + ProjectStore::Oplog* Oplog = Project->NewOplog("oplog1", {}); + CHECK(Oplog != nullptr); + + std::filesystem::path DeletePath; + CHECK(Project->PrepareForDelete(DeletePath)); + CHECK(!DeletePath.empty()); + CHECK(Project->OpenOplog("oplog1") == nullptr); + // Oplog is now invalid, but pointer can still be accessed since we store old oplog pointers + CHECK(Oplog->OplogCount() == 0); + // Project is still valid since we have a Ref to it + CHECK(Project->Identifier == "proj1"sv); +} + +TEST_CASE("project.store.gc") +{ + using namespace std::literals; + using namespace testutils; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + std::filesystem::path BasePath = TempDir.Path() / "projectstore"; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"; + + std::filesystem::path Project1RootDir = TempDir.Path() / "game1"; + std::filesystem::path Project1FilePath = TempDir.Path() / "game1" / "game.uproject"; + { + CreateDirectories(Project1FilePath.parent_path()); + BasicFile ProjectFile; + ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate); + } + + std::filesystem::path Project2RootDir = TempDir.Path() / "game2"; + std::filesystem::path Project2FilePath = TempDir.Path() / "game2" / "game.uproject"; + { + CreateDirectories(Project2FilePath.parent_path()); + BasicFile ProjectFile; + ProjectFile.Open(Project2FilePath, BasicFile::Mode::kTruncate); + } + + { + Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + Project1RootDir.string(), + Project1FilePath.string())); + ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1", {}); + CHECK(Oplog != nullptr); + + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {})); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{77}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{7123, 583, 690, 99}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{55, 122}))); + } + + { + Ref<ProjectStore::Project> Project2(ProjectStore.NewProject(BasePath / "proj2"sv, + "proj2"sv, + RootDir.string(), + EngineRootDir.string(), + Project2RootDir.string(), + Project2FilePath.string())); + ProjectStore::Oplog* Oplog = Project2->NewOplog("oplog1", {}); + CHECK(Oplog != nullptr); + + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {})); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{177}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{9123, 383, 590, 96}))); + Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{535, 221}))); + } + + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + ProjectStore.GatherReferences(GcCtx); + size_t RefCount = 0; + GcCtx.IterateCids([&RefCount](const IoHash&) { RefCount++; }); + CHECK(RefCount == 14); + ProjectStore.CollectGarbage(GcCtx); + CHECK(ProjectStore.OpenProject("proj1"sv)); + CHECK(ProjectStore.OpenProject("proj2"sv)); + } + + std::filesystem::remove(Project1FilePath); + + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + ProjectStore.GatherReferences(GcCtx); + size_t RefCount = 0; + GcCtx.IterateCids([&RefCount](const IoHash&) { RefCount++; }); + CHECK(RefCount == 7); + ProjectStore.CollectGarbage(GcCtx); + CHECK(!ProjectStore.OpenProject("proj1"sv)); + CHECK(ProjectStore.OpenProject("proj2"sv)); + } +} + +TEST_CASE("project.store.partial.read") +{ + using namespace std::literals; + using namespace testutils; + + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CidStore CidStore(Gc); + CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas"sv, .TinyValueThreshold = 1024, .HugeValueThreshold = 4096}; + CidStore.Initialize(CidConfig); + + std::filesystem::path BasePath = TempDir.Path() / "projectstore"sv; + ProjectStore ProjectStore(CidStore, BasePath, Gc); + std::filesystem::path RootDir = TempDir.Path() / "root"sv; + std::filesystem::path EngineRootDir = TempDir.Path() / "engine"sv; + + std::filesystem::path Project1RootDir = TempDir.Path() / "game1"sv; + std::filesystem::path Project1FilePath = TempDir.Path() / "game1"sv / "game.uproject"sv; + { + CreateDirectories(Project1FilePath.parent_path()); + BasicFile ProjectFile; + ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate); + } + + std::vector<Oid> OpIds; + OpIds.insert(OpIds.end(), {Oid::NewOid(), Oid::NewOid(), Oid::NewOid(), Oid::NewOid()}); + std::unordered_map<Oid, std::vector<std::pair<Oid, CompressedBuffer>>, Oid::Hasher> Attachments; + { + Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv, + "proj1"sv, + RootDir.string(), + EngineRootDir.string(), + Project1RootDir.string(), + Project1FilePath.string())); + ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1"sv, {}); + CHECK(Oplog != nullptr); + Attachments[OpIds[0]] = {}; + Attachments[OpIds[1]] = CreateAttachments(std::initializer_list<size_t>{77}); + Attachments[OpIds[2]] = CreateAttachments(std::initializer_list<size_t>{7123, 9583, 690, 99}); + Attachments[OpIds[3]] = CreateAttachments(std::initializer_list<size_t>{55, 122}); + for (auto It : Attachments) + { + Oplog->AppendNewOplogEntry(CreateOplogPackage(It.first, It.second)); + } + } + { + IoBuffer Chunk; + CHECK(ProjectStore + .GetChunk("proj1"sv, + "oplog1"sv, + Attachments[OpIds[1]][0].second.DecodeRawHash().ToHexString(), + HttpContentType::kCompressedBinary, + Chunk) + .first == HttpResponseCode::OK); + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Attachment = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize); + CHECK(RawSize == Attachments[OpIds[1]][0].second.DecodeRawSize()); + } + + IoBuffer ChunkResult; + CHECK(ProjectStore + .GetChunkRange("proj1"sv, + "oplog1"sv, + OidAsString(Attachments[OpIds[2]][1].first), + 0, + ~0ull, + HttpContentType::kCompressedBinary, + ChunkResult) + .first == HttpResponseCode::OK); + CHECK(ChunkResult); + CHECK(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult)).DecodeRawSize() == + Attachments[OpIds[2]][1].second.DecodeRawSize()); + + IoBuffer PartialChunkResult; + CHECK(ProjectStore + .GetChunkRange("proj1"sv, + "oplog1"sv, + OidAsString(Attachments[OpIds[2]][1].first), + 5, + 1773, + HttpContentType::kCompressedBinary, + PartialChunkResult) + .first == HttpResponseCode::OK); + CHECK(PartialChunkResult); + IoHash PartialRawHash; + uint64_t PartialRawSize; + CompressedBuffer PartialCompressedResult = + CompressedBuffer::FromCompressed(SharedBuffer(PartialChunkResult), PartialRawHash, PartialRawSize); + CHECK(PartialRawSize >= 1773); + + uint64_t RawOffsetInPartialCompressed = GetCompressedOffset(PartialCompressedResult, 5); + SharedBuffer PartialDecompressed = PartialCompressedResult.Decompress(RawOffsetInPartialCompressed); + SharedBuffer FullDecompressed = Attachments[OpIds[2]][1].second.Decompress(); + const uint8_t* FullDataPtr = &(reinterpret_cast<const uint8_t*>(FullDecompressed.GetView().GetData())[5]); + const uint8_t* PartialDataPtr = reinterpret_cast<const uint8_t*>(PartialDecompressed.GetView().GetData()); + CHECK(FullDataPtr[0] == PartialDataPtr[0]); +} + +TEST_CASE("project.store.block") +{ + using namespace std::literals; + using namespace testutils; + + std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489, + 7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, 1916, 8210, 2235, 4024, 1582, 5251, + 491, 5464, 4607, 8135, 3767, 4045, 4415, 5007, 8876, 6761, 3359, 8526, 4097, 4855, 8225}); + + std::vector<std::pair<Oid, CompressedBuffer>> AttachmentsWithId = CreateAttachments(AttachmentSizes); + std::vector<SharedBuffer> Chunks; + Chunks.reserve(AttachmentSizes.size()); + for (const auto& It : AttachmentsWithId) + { + Chunks.push_back(It.second.GetCompressed().Flatten()); + } + CompressedBuffer Block = GenerateBlock(std::move(Chunks)); + IoBuffer BlockBuffer = Block.GetCompressed().Flatten().AsIoBuffer(); + CHECK(IterateBlock(std::move(BlockBuffer), [](CompressedBuffer&&, const IoHash&) {})); +} + +#endif + +void +prj_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenserver/projectstore/projectstore.h b/src/zenserver/projectstore/projectstore.h new file mode 100644 index 000000000..e4f664b85 --- /dev/null +++ b/src/zenserver/projectstore/projectstore.h @@ -0,0 +1,372 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/uid.h> +#include <zencore/xxhash.h> +#include <zenhttp/httpserver.h> +#include <zenstore/gc.h> + +#include "monitoring/httpstats.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +class CbPackage; +class CidStore; +class AuthMgr; +class ScrubContext; + +struct OplogEntry +{ + uint32_t OpLsn; + uint32_t OpCoreOffset; // note: Multiple of alignment! + uint32_t OpCoreSize; + uint32_t OpCoreHash; // Used as checksum + XXH3_128 OpKeyHash; // XXH128_canonical_t + + inline Oid OpKeyAsOId() const + { + Oid Id; + memcpy(Id.OidBits, &OpKeyHash, sizeof Id.OidBits); + return Id; + } +}; + +struct OplogEntryAddress +{ + uint64_t Offset; + uint64_t Size; +}; + +static_assert(IsPow2(sizeof(OplogEntry))); + +/** Project Store + + A project store consists of a number of Projects. + + Each project contains a number of oplogs (short for "operation log"). UE uses + one oplog per target platform to store the output of the cook process. + + An oplog consists of a sequence of "op" entries. Each entry is a structured object + containing references to attachments. Attachments are typically the serialized + package data split into separate chunks for bulk data, exports and header + information. + */ +class ProjectStore : public RefCounted, public GcStorage, public GcContributor +{ + struct OplogStorage; + +public: + ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc); + ~ProjectStore(); + + struct Project; + + struct Oplog + { + Oplog(std::string_view Id, + Project* Project, + CidStore& Store, + std::filesystem::path BasePath, + const std::filesystem::path& MarkerPath); + ~Oplog(); + + [[nodiscard]] static bool ExistsAt(std::filesystem::path BasePath); + + void Read(); + void Write(); + + void IterateFileMap(std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn); + void IterateOplog(std::function<void(CbObject)>&& Fn); + void IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Fn); + std::optional<CbObject> GetOpByKey(const Oid& Key); + std::optional<CbObject> GetOpByIndex(int Index); + int GetOpIndexByKey(const Oid& Key); + + IoBuffer FindChunk(Oid ChunkId); + + inline static const uint32_t kInvalidOp = ~0u; + + /** Persist a new oplog entry + * + * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected + */ + uint32_t AppendNewOplogEntry(CbPackage Op); + + uint32_t AppendNewOplogEntry(CbObject Core); + + enum UpdateType + { + kUpdateNewEntry, + kUpdateReplay + }; + + const std::string& OplogId() const { return m_OplogId; } + + const std::filesystem::path& TempPath() const { return m_TempPath; } + const std::filesystem::path& MarkerPath() const { return m_MarkerPath; } + + spdlog::logger& Log() { return m_OuterProject->Log(); } + void Flush(); + void Scrub(ScrubContext& Ctx) const; + void GatherReferences(GcContext& GcCtx); + uint64_t TotalSize() const; + + std::size_t OplogCount() const + { + RwLock::SharedLockScope _(m_OplogLock); + return m_LatestOpMap.size(); + } + + bool IsExpired() const; + std::filesystem::path PrepareForDelete(bool MoveFolder); + + private: + struct FileMapEntry + { + std::string ServerPath; + std::string ClientPath; + }; + + template<class V> + using OidMap = tsl::robin_map<Oid, V, Oid::Hasher>; + + Project* m_OuterProject = nullptr; + CidStore& m_CidStore; + std::filesystem::path m_BasePath; + std::filesystem::path m_MarkerPath; + std::filesystem::path m_TempPath; + + mutable RwLock m_OplogLock; + OidMap<IoHash> m_ChunkMap; // output data chunk id -> CAS address + OidMap<IoHash> m_MetaMap; // meta chunk id -> CAS address + OidMap<FileMapEntry> m_FileMap; // file id -> file map entry + int32_t m_ManifestVersion; // File system manifest version + tsl::robin_map<int, OplogEntryAddress> m_OpAddressMap; // Index LSN -> op data in ops blob file + OidMap<int> m_LatestOpMap; // op key -> latest op LSN for key + + RefPtr<OplogStorage> m_Storage; + std::string m_OplogId; + + /** Scan oplog and register each entry, thus updating the in-memory tracking tables + */ + void ReplayLog(); + + struct OplogEntryMapping + { + struct Mapping + { + Oid Id; + IoHash Hash; + }; + struct FileMapping : public Mapping + { + std::string ServerPath; + std::string ClientPath; + }; + std::vector<Mapping> Chunks; + std::vector<Mapping> Meta; + std::vector<FileMapping> Files; + }; + + OplogEntryMapping GetMapping(CbObject Core); + + /** Update tracking metadata for a new oplog entry + * + * This is used during replay (and gets called as part of new op append) + * + * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected + */ + uint32_t RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock, + const OplogEntryMapping& OpMapping, + const OplogEntry& OpEntry, + UpdateType TypeOfUpdate); + + void AddFileMapping(const RwLock::ExclusiveLockScope& OplogLock, + Oid FileId, + IoHash Hash, + std::string_view ServerPath, + std::string_view ClientPath); + void AddChunkMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash); + void AddMetaMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash); + }; + + struct Project : public RefCounted + { + std::string Identifier; + std::filesystem::path RootDir; + std::string EngineRootDir; + std::string ProjectRootDir; + std::string ProjectFilePath; + + Oplog* NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath); + Oplog* OpenOplog(std::string_view OplogId); + void DeleteOplog(std::string_view OplogId); + void IterateOplogs(std::function<void(const Oplog&)>&& Fn) const; + void IterateOplogs(std::function<void(Oplog&)>&& Fn); + std::vector<std::string> ScanForOplogs() const; + bool IsExpired() const; + + Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath); + virtual ~Project(); + + void Read(); + void Write(); + [[nodiscard]] static bool Exists(std::filesystem::path BasePath); + void Flush(); + void Scrub(ScrubContext& Ctx); + spdlog::logger& Log(); + void GatherReferences(GcContext& GcCtx); + uint64_t TotalSize() const; + bool PrepareForDelete(std::filesystem::path& OutDeletePath); + + private: + ProjectStore* m_ProjectStore; + CidStore& m_CidStore; + mutable RwLock m_ProjectLock; + std::map<std::string, std::unique_ptr<Oplog>> m_Oplogs; + std::vector<std::unique_ptr<Oplog>> m_DeletedOplogs; + std::filesystem::path m_OplogStoragePath; + + std::filesystem::path BasePathForOplog(std::string_view OplogId); + }; + + // Oplog* OpenProjectOplog(std::string_view ProjectId, std::string_view OplogId); + + Ref<Project> OpenProject(std::string_view ProjectId); + Ref<Project> NewProject(std::filesystem::path BasePath, + std::string_view ProjectId, + std::string_view RootDir, + std::string_view EngineRootDir, + std::string_view ProjectRootDir, + std::string_view ProjectFilePath); + bool DeleteProject(std::string_view ProjectId); + bool Exists(std::string_view ProjectId); + void Flush(); + void Scrub(ScrubContext& Ctx); + void DiscoverProjects(); + void IterateProjects(std::function<void(Project& Prj)>&& Fn); + + spdlog::logger& Log() { return m_Log; } + const std::filesystem::path& BasePath() const { return m_ProjectBasePath; } + + virtual void GatherReferences(GcContext& GcCtx) override; + virtual void CollectGarbage(GcContext& GcCtx) override; + virtual GcStorageSize StorageSize() const override; + + CbArray GetProjectsList(); + std::pair<HttpResponseCode, std::string> GetProjectFiles(const std::string_view ProjectId, + const std::string_view OplogId, + bool FilterClient, + CbObject& OutPayload); + std::pair<HttpResponseCode, std::string> GetChunkInfo(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + CbObject& OutPayload); + std::pair<HttpResponseCode, std::string> GetChunkRange(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view ChunkId, + uint64_t Offset, + uint64_t Size, + ZenContentType AcceptType, + IoBuffer& OutChunk); + std::pair<HttpResponseCode, std::string> GetChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType AcceptType, + IoBuffer& OutChunk); + + std::pair<HttpResponseCode, std::string> PutChunk(const std::string_view ProjectId, + const std::string_view OplogId, + const std::string_view Cid, + ZenContentType ContentType, + IoBuffer&& Chunk); + + std::pair<HttpResponseCode, std::string> WriteOplog(const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload, + CbObject& OutResponse); + + std::pair<HttpResponseCode, std::string> ReadOplog(const std::string_view ProjectId, + const std::string_view OplogId, + const HttpServerRequest::QueryParams& Params, + CbObject& OutResponse); + + std::pair<HttpResponseCode, std::string> WriteBlock(const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload); + + void Rpc(HttpServerRequest& HttpReq, + const std::string_view ProjectId, + const std::string_view OplogId, + IoBuffer&& Payload, + AuthMgr& AuthManager); + + std::pair<HttpResponseCode, std::string> Export(ProjectStore::Project& Project, + ProjectStore::Oplog& Oplog, + CbObjectView&& Params, + AuthMgr& AuthManager); + + std::pair<HttpResponseCode, std::string> Import(ProjectStore::Project& Project, + ProjectStore::Oplog& Oplog, + CbObjectView&& Params, + AuthMgr& AuthManager); + +private: + spdlog::logger& m_Log; + CidStore& m_CidStore; + std::filesystem::path m_ProjectBasePath; + mutable RwLock m_ProjectsLock; + std::map<std::string, Ref<Project>> m_Projects; + + std::filesystem::path BasePathForProject(std::string_view ProjectId); +}; + +////////////////////////////////////////////////////////////////////////// +// +// {project} a project identifier +// {target} a variation of the project, typically a build target +// {lsn} oplog entry sequence number +// +// /prj/{project} +// /prj/{project}/oplog/{target} +// /prj/{project}/oplog/{target}/{lsn} +// +// oplog entry +// +// id: {id} +// key: {} +// meta: {} +// data: [] +// refs: +// + +class HttpProjectService : public HttpService, public IHttpStatsProvider +{ +public: + HttpProjectService(CidStore& Store, ProjectStore* InProjectStore, HttpStatsService& StatsService, AuthMgr& AuthMgr); + ~HttpProjectService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + + virtual void HandleStatsRequest(HttpServerRequest& Request) override; + +private: + inline spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + CidStore& m_CidStore; + HttpRequestRouter m_Router; + Ref<ProjectStore> m_ProjectStore; + HttpStatsService& m_StatsService; + AuthMgr& m_AuthMgr; +}; + +void prj_forcelink(); + +} // namespace zen diff --git a/src/zenserver/projectstore/remoteprojectstore.cpp b/src/zenserver/projectstore/remoteprojectstore.cpp new file mode 100644 index 000000000..1e6ca51a1 --- /dev/null +++ b/src/zenserver/projectstore/remoteprojectstore.cpp @@ -0,0 +1,1036 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "remoteprojectstore.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compress.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/workthreadpool.h> +#include <zenstore/cidstore.h> + +namespace zen { + +/* + OplogContainer + Binary("ops") // Compressed CompactBinary object to hide attachment references, also makes the oplog smaller + { + CbArray("ops") + { + CbObject Op + (CbFieldType::BinaryAttachment Attachments[]) + (OpData) + } + } + CbArray("blocks") + CbObject + CbFieldType::BinaryAttachment "rawhash" // Optional, only if we are creating blocks (Jupiter/File) + CbArray("chunks") + CbFieldType::Hash // Chunk hashes + CbArray("chunks") // Optional, only if we are not creating blocks (Zen) + CbFieldType::BinaryAttachment // Chunk attachment hashes + + CompressedBinary ChunkBlock + { + VarUInt ChunkCount + VarUInt ChunkSizes[ChunkCount] + uint8_t[chunksize])[ChunkCount] + } +*/ + +////////////////////////////// AsyncRemoteResult + +struct AsyncRemoteResult +{ + void SetError(int32_t ErrorCode, const std::string& ErrorReason, const std::string ErrorText) + { + int32_t Expected = 0; + if (m_ErrorCode.compare_exchange_weak(Expected, ErrorCode ? ErrorCode : -1)) + { + m_ErrorReason = ErrorReason; + m_ErrorText = ErrorText; + } + } + bool IsError() const { return m_ErrorCode.load() != 0; } + int GetError() const { return m_ErrorCode.load(); }; + const std::string& GetErrorReason() const { return m_ErrorReason; }; + const std::string& GetErrorText() const { return m_ErrorText; }; + RemoteProjectStore::Result ConvertResult(double ElapsedSeconds = 0.0) const + { + return RemoteProjectStore::Result{m_ErrorCode, ElapsedSeconds, m_ErrorReason, m_ErrorText}; + } + +private: + std::atomic<int32_t> m_ErrorCode = 0; + std::string m_ErrorReason; + std::string m_ErrorText; +}; + +bool +IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor) +{ + IoBuffer BlockPayload = CompressedBuffer::FromCompressedNoValidate(std::move(CompressedBlock)).Decompress().AsIoBuffer(); + + MemoryView BlockView = BlockPayload.GetView(); + const uint8_t* ReadPtr = reinterpret_cast<const uint8_t*>(BlockView.GetData()); + uint32_t NumberSize; + uint64_t ChunkCount = ReadVarUInt(ReadPtr, NumberSize); + ReadPtr += NumberSize; + std::vector<uint64_t> ChunkSizes; + ChunkSizes.reserve(ChunkCount); + while (ChunkCount--) + { + ChunkSizes.push_back(ReadVarUInt(ReadPtr, NumberSize)); + ReadPtr += NumberSize; + } + ptrdiff_t TempBufferLength = std::distance(reinterpret_cast<const uint8_t*>(BlockView.GetData()), ReadPtr); + ZEN_ASSERT(TempBufferLength > 0); + for (uint64_t ChunkSize : ChunkSizes) + { + IoBuffer Chunk(IoBuffer::Wrap, ReadPtr, ChunkSize); + IoHash AttachmentRawHash; + uint64_t AttachmentRawSize; + CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), AttachmentRawHash, AttachmentRawSize); + + if (!CompressedChunk) + { + ZEN_ERROR("Invalid chunk in block"); + return false; + } + Visitor(std::move(CompressedChunk), AttachmentRawHash); + ReadPtr += ChunkSize; + ZEN_ASSERT(ReadPtr <= BlockView.GetDataEnd()); + } + return true; +}; + +CompressedBuffer +GenerateBlock(std::vector<SharedBuffer>&& Chunks) +{ + size_t ChunkCount = Chunks.size(); + SharedBuffer SizeBuffer; + { + IoBuffer TempBuffer(ChunkCount * 9); + MutableMemoryView View = TempBuffer.GetMutableView(); + uint8_t* BufferStartPtr = reinterpret_cast<uint8_t*>(View.GetData()); + uint8_t* BufferEndPtr = BufferStartPtr; + BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(ChunkCount), BufferEndPtr); + auto It = Chunks.begin(); + while (It != Chunks.end()) + { + BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(It->GetSize()), BufferEndPtr); + It++; + } + ZEN_ASSERT(BufferEndPtr <= View.GetDataEnd()); + ptrdiff_t TempBufferLength = std::distance(BufferStartPtr, BufferEndPtr); + SizeBuffer = SharedBuffer(IoBuffer(TempBuffer, 0, gsl::narrow<size_t>(TempBufferLength))); + } + CompositeBuffer AllBuffers(std::move(SizeBuffer), CompositeBuffer(std::move(Chunks))); + + CompressedBuffer CompressedBlock = + CompressedBuffer::Compress(std::move(AllBuffers), OodleCompressor::Mermaid, OodleCompressionLevel::None); + + return CompressedBlock; +} + +struct Block +{ + IoHash BlockHash; + std::vector<IoHash> ChunksInBlock; +}; + +void +CreateBlock(WorkerThreadPool& WorkerPool, + Latch& OpSectionsLatch, + std::vector<SharedBuffer>&& ChunksInBlock, + RwLock& SectionsLock, + std::vector<Block>& Blocks, + size_t BlockIndex, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + AsyncRemoteResult& RemoteResult) +{ + OpSectionsLatch.AddCount(1); + WorkerPool.ScheduleWork( + [&Blocks, &SectionsLock, &OpSectionsLatch, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, &RemoteResult]() mutable { + auto _ = MakeGuard([&OpSectionsLatch] { OpSectionsLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + if (!Chunks.empty()) + { + CompressedBuffer CompressedBlock = GenerateBlock(std::move(Chunks)); // Move to callback and return IoHash + IoHash BlockHash = CompressedBlock.DecodeRawHash(); + AsyncOnBlock(std::move(CompressedBlock), BlockHash); + { + // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + RwLock::SharedLockScope __(SectionsLock); + Blocks[BlockIndex].BlockHash = BlockHash; + } + } + }); +} + +size_t +AddBlock(RwLock& BlocksLock, std::vector<Block>& Blocks) +{ + size_t BlockIndex; + { + RwLock::ExclusiveLockScope _(BlocksLock); + BlockIndex = Blocks.size(); + Blocks.resize(BlockIndex + 1); + } + return BlockIndex; +} + +CbObject +BuildContainer(CidStore& ChunkStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + WorkerThreadPool& WorkerPool, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + const std::function<void(const IoHash&)>& OnLargeAttachment, + const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks, + AsyncRemoteResult& RemoteResult) +{ + using namespace std::literals; + + std::unordered_set<IoHash, IoHash::Hasher> LargeChunkHashes; + CbObjectWriter SectionOpsWriter; + SectionOpsWriter.BeginArray("ops"sv); + + size_t OpCount = 0; + + CbObject OplogContainerObject; + { + RwLock BlocksLock; + std::vector<Block> Blocks; + CompressedBuffer OpsBuffer; + + Latch BlockCreateLatch(1); + + std::unordered_set<IoHash, IoHash::Hasher> BlockAttachmentHashes; + + size_t BlockSize = 0; + std::vector<SharedBuffer> ChunksInBlock; + + std::unordered_set<IoHash, IoHash::Hasher> Attachments; + Oplog.IterateOplog([&Attachments, &SectionOpsWriter, &OpCount](CbObject Op) { + Op.IterateAttachments([&](CbFieldView FieldView) { Attachments.insert(FieldView.AsAttachment()); }); + (SectionOpsWriter) << Op; + OpCount++; + }); + + for (const IoHash& AttachmentHash : Attachments) + { + IoBuffer Payload = ChunkStore.FindChunkByCid(AttachmentHash); + if (!Payload) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound), + fmt::format("Failed to find attachment {} for op", AttachmentHash), + {}); + ZEN_ERROR("Failed to build container ({}). Reason: '{}'", RemoteResult.GetError(), RemoteResult.GetErrorReason()); + return {}; + } + uint64_t PayloadSize = Payload.GetSize(); + if (PayloadSize > MaxChunkEmbedSize) + { + if (LargeChunkHashes.insert(AttachmentHash).second) + { + OnLargeAttachment(AttachmentHash); + } + continue; + } + + if (!BlockAttachmentHashes.insert(AttachmentHash).second) + { + continue; + } + + BlockSize += PayloadSize; + if (BuildBlocks) + { + ChunksInBlock.emplace_back(SharedBuffer(std::move(Payload))); + } + else + { + Payload = {}; + } + + if (BlockSize >= MaxBlockSize) + { + size_t BlockIndex = AddBlock(BlocksLock, Blocks); + if (BuildBlocks) + { + CreateBlock(WorkerPool, + BlockCreateLatch, + std::move(ChunksInBlock), + BlocksLock, + Blocks, + BlockIndex, + AsyncOnBlock, + RemoteResult); + } + else + { + OnBlockChunks(BlockAttachmentHashes); + } + { + // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + RwLock::SharedLockScope _(BlocksLock); + Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(), + BlockAttachmentHashes.begin(), + BlockAttachmentHashes.end()); + } + BlockAttachmentHashes.clear(); + ChunksInBlock.clear(); + BlockSize = 0; + } + } + if (BlockSize > 0) + { + size_t BlockIndex = AddBlock(BlocksLock, Blocks); + if (BuildBlocks) + { + CreateBlock(WorkerPool, + BlockCreateLatch, + std::move(ChunksInBlock), + BlocksLock, + Blocks, + BlockIndex, + AsyncOnBlock, + RemoteResult); + } + else + { + OnBlockChunks(BlockAttachmentHashes); + } + { + // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index + RwLock::SharedLockScope _(BlocksLock); + Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(), + BlockAttachmentHashes.begin(), + BlockAttachmentHashes.end()); + } + BlockAttachmentHashes.clear(); + ChunksInBlock.clear(); + BlockSize = 0; + } + SectionOpsWriter.EndArray(); // "ops" + + CompressedBuffer CompressedOpsSection = CompressedBuffer::Compress(SectionOpsWriter.Save().GetBuffer()); + ZEN_DEBUG("Added oplog section {}, {}", CompressedOpsSection.DecodeRawHash(), NiceBytes(CompressedOpsSection.GetCompressedSize())); + + BlockCreateLatch.CountDown(); + while (!BlockCreateLatch.Wait(1000)) + { + ZEN_INFO("Creating blocks, {} remaining...", BlockCreateLatch.Remaining()); + } + + if (!RemoteResult.IsError()) + { + CbObjectWriter OplogContinerWriter; + RwLock::SharedLockScope _(BlocksLock); + OplogContinerWriter.AddBinary("ops"sv, CompressedOpsSection.GetCompressed().Flatten().AsIoBuffer()); + + OplogContinerWriter.BeginArray("blocks"sv); + { + for (const Block& B : Blocks) + { + ZEN_ASSERT(!B.ChunksInBlock.empty()); + if (BuildBlocks) + { + ZEN_ASSERT(B.BlockHash != IoHash::Zero); + + OplogContinerWriter.BeginObject(); + { + OplogContinerWriter.AddBinaryAttachment("rawhash"sv, B.BlockHash); + OplogContinerWriter.BeginArray("chunks"sv); + { + for (const IoHash& RawHash : B.ChunksInBlock) + { + OplogContinerWriter.AddHash(RawHash); + } + } + OplogContinerWriter.EndArray(); // "chunks" + } + OplogContinerWriter.EndObject(); + continue; + } + + ZEN_ASSERT(B.BlockHash == IoHash::Zero); + OplogContinerWriter.BeginObject(); + { + OplogContinerWriter.BeginArray("chunks"sv); + { + for (const IoHash& RawHash : B.ChunksInBlock) + { + OplogContinerWriter.AddBinaryAttachment(RawHash); + } + } + OplogContinerWriter.EndArray(); + } + OplogContinerWriter.EndObject(); + } + } + OplogContinerWriter.EndArray(); // "blocks"sv + + OplogContinerWriter.BeginArray("chunks"sv); + { + for (const IoHash& AttachmentHash : LargeChunkHashes) + { + OplogContinerWriter.AddBinaryAttachment(AttachmentHash); + } + } + OplogContinerWriter.EndArray(); // "chunks" + + OplogContainerObject = OplogContinerWriter.Save(); + } + } + return OplogContainerObject; +} + +RemoteProjectStore::LoadContainerResult +BuildContainer(CidStore& ChunkStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + const std::function<void(const IoHash&)>& OnLargeAttachment, + const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks) +{ + // We are creating a worker thread pool here since we are uploading a lot of attachments in one go and we dont want to keep a + // WorkerThreadPool alive + size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u); + WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount)); + + AsyncRemoteResult RemoteResult; + CbObject ContainerObject = BuildContainer(ChunkStore, + Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + BuildBlocks, + WorkerPool, + AsyncOnBlock, + OnLargeAttachment, + OnBlockChunks, + RemoteResult); + return RemoteProjectStore::LoadContainerResult{RemoteResult.ConvertResult(), ContainerObject}; +} + +RemoteProjectStore::Result +SaveOplog(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + bool UseTempBlocks, + bool ForceUpload) +{ + using namespace std::literals; + + Stopwatch Timer; + + // We are creating a worker thread pool here since we are uploading a lot of attachments in one go + // Doing upload is a rare and transient occation so we don't want to keep a WorkerThreadPool alive. + size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u); + WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount)); + + std::filesystem::path AttachmentTempPath; + if (UseTempBlocks) + { + AttachmentTempPath = Oplog.TempPath(); + AttachmentTempPath.append(".pending"); + CreateDirectories(AttachmentTempPath); + } + + AsyncRemoteResult RemoteResult; + RwLock AttachmentsLock; + std::unordered_set<IoHash, IoHash::Hasher> LargeAttachments; + std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> CreatedBlocks; + + auto MakeTempBlock = [AttachmentTempPath, &RemoteResult, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock, + const IoHash& BlockHash) { + std::filesystem::path BlockPath = AttachmentTempPath; + BlockPath.append(BlockHash.ToHexString()); + if (!std::filesystem::exists(BlockPath)) + { + IoBuffer BlockBuffer; + try + { + BasicFile BlockFile; + BlockFile.Open(BlockPath, BasicFile::Mode::kTruncateDelete); + uint64_t Offset = 0; + for (const SharedBuffer& Buffer : CompressedBlock.GetCompressed().GetSegments()) + { + BlockFile.Write(Buffer.GetView(), Offset); + Offset += Buffer.GetSize(); + } + void* FileHandle = BlockFile.Detach(); + BlockBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, Offset); + } + catch (std::exception& Ex) + { + RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), + Ex.what(), + "Unable to create temp block file"); + return; + } + + BlockBuffer.MarkAsDeleteOnClose(); + { + RwLock::ExclusiveLockScope __(AttachmentsLock); + CreatedBlocks.insert({BlockHash, std::move(BlockBuffer)}); + } + ZEN_DEBUG("Saved temp block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize())); + } + }; + + auto UploadBlock = [&RemoteStore, &RemoteResult](CompressedBuffer&& CompressedBlock, const IoHash& BlockHash) { + RemoteProjectStore::SaveAttachmentResult Result = RemoteStore.SaveAttachment(CompressedBlock.GetCompressed(), BlockHash); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachment ({}). Reason: '{}'", RemoteResult.GetErrorReason(), RemoteResult.GetError()); + return; + } + ZEN_DEBUG("Saved block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize())); + }; + + std::vector<std::vector<IoHash>> BlockChunks; + auto OnBlockChunks = [&BlockChunks](const std::unordered_set<IoHash, IoHash::Hasher>& Chunks) { + BlockChunks.push_back({Chunks.begin(), Chunks.end()}); + ZEN_DEBUG("Found {} block chunks", Chunks.size()); + }; + + auto OnLargeAttachment = [&AttachmentsLock, &LargeAttachments](const IoHash& AttachmentHash) { + { + RwLock::ExclusiveLockScope _(AttachmentsLock); + LargeAttachments.insert(AttachmentHash); + } + ZEN_DEBUG("Found attachment {}", AttachmentHash); + }; + + std::function<void(CompressedBuffer&&, const IoHash&)> OnBlock; + if (UseTempBlocks) + { + OnBlock = MakeTempBlock; + } + else + { + OnBlock = UploadBlock; + } + + CbObject OplogContainerObject = BuildContainer(ChunkStore, + Oplog, + MaxBlockSize, + MaxChunkEmbedSize, + BuildBlocks, + WorkerPool, + OnBlock, + OnLargeAttachment, + OnBlockChunks, + RemoteResult); + + if (!RemoteResult.IsError()) + { + uint64_t ChunkCount = OplogContainerObject["chunks"sv].AsArrayView().Num(); + uint64_t BlockCount = OplogContainerObject["blocks"sv].AsArrayView().Num(); + ZEN_INFO("Saving oplog container with {} attachments and {} blocks...", ChunkCount, BlockCount); + RemoteProjectStore::SaveResult ContainerSaveResult = RemoteStore.SaveContainer(OplogContainerObject.GetBuffer().AsIoBuffer()); + if (ContainerSaveResult.ErrorCode) + { + RemoteResult.SetError(ContainerSaveResult.ErrorCode, ContainerSaveResult.Reason, "Failed to save oplog container"); + ZEN_ERROR("Failed to save oplog container ({}). Reason: '{}'", RemoteResult.GetErrorReason(), RemoteResult.GetError()); + } + ZEN_DEBUG("Saved container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerSaveResult.ElapsedSeconds * 1000))); + if (!ContainerSaveResult.Needs.empty()) + { + ZEN_INFO("Filtering needed attachments..."); + std::vector<IoHash> NeededLargeAttachments; + std::unordered_set<IoHash, IoHash::Hasher> NeededOtherAttachments; + NeededLargeAttachments.reserve(LargeAttachments.size()); + NeededOtherAttachments.reserve(CreatedBlocks.size()); + if (ForceUpload) + { + NeededLargeAttachments.insert(NeededLargeAttachments.end(), LargeAttachments.begin(), LargeAttachments.end()); + } + else + { + for (const IoHash& RawHash : ContainerSaveResult.Needs) + { + if (LargeAttachments.contains(RawHash)) + { + NeededLargeAttachments.push_back(RawHash); + continue; + } + NeededOtherAttachments.insert(RawHash); + } + } + + Latch SaveAttachmentsLatch(1); + if (!NeededLargeAttachments.empty()) + { + ZEN_INFO("Saving large attachments..."); + for (const IoHash& RawHash : NeededLargeAttachments) + { + if (RemoteResult.IsError()) + { + break; + } + SaveAttachmentsLatch.AddCount(1); + WorkerPool.ScheduleWork([&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, RawHash, &CreatedBlocks]() { + auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + + IoBuffer Payload; + if (auto It = CreatedBlocks.find(RawHash); It != CreatedBlocks.end()) + { + Payload = std::move(It->second); + } + else + { + Payload = ChunkStore.FindChunkByCid(RawHash); + } + if (!Payload) + { + RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound), + fmt::format("Failed to find attachment {}", RawHash), + {}); + ZEN_ERROR("Failed to build container ({}). Reason: '{}'", + RemoteResult.GetErrorReason(), + RemoteResult.GetError()); + return; + } + + RemoteProjectStore::SaveAttachmentResult Result = + RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachment '{}', {} ({}). Reason: '{}'", + RawHash, + NiceBytes(Payload.GetSize()), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Saved attachment {}, {} in {}", + RawHash, + NiceBytes(Payload.GetSize()), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + return; + }); + } + } + + if (!CreatedBlocks.empty()) + { + ZEN_INFO("Saving created block attachments..."); + for (auto& It : CreatedBlocks) + { + if (RemoteResult.IsError()) + { + break; + } + const IoHash& RawHash = It.first; + if (ForceUpload || NeededOtherAttachments.contains(RawHash)) + { + IoBuffer Payload = It.second; + ZEN_ASSERT(Payload); + SaveAttachmentsLatch.AddCount(1); + WorkerPool.ScheduleWork( + [&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, Payload = std::move(Payload), RawHash]() { + auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + + RemoteProjectStore::SaveAttachmentResult Result = + RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachment '{}', {} ({}). Reason: '{}'", + RawHash, + NiceBytes(Payload.GetSize()), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + + ZEN_DEBUG("Saved attachment {}, {} in {}", + RawHash, + NiceBytes(Payload.GetSize()), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + return; + }); + } + It.second = {}; + } + } + + if (!BlockChunks.empty()) + { + ZEN_INFO("Saving chunk block attachments..."); + for (const std::vector<IoHash>& Chunks : BlockChunks) + { + if (RemoteResult.IsError()) + { + break; + } + std::vector<IoHash> NeededChunks; + if (ForceUpload) + { + NeededChunks = Chunks; + } + else + { + NeededChunks.reserve(Chunks.size()); + for (const IoHash& Chunk : Chunks) + { + if (NeededOtherAttachments.contains(Chunk)) + { + NeededChunks.push_back(Chunk); + } + } + if (NeededChunks.empty()) + { + continue; + } + } + SaveAttachmentsLatch.AddCount(1); + WorkerPool.ScheduleWork([&RemoteStore, + &ChunkStore, + &SaveAttachmentsLatch, + &RemoteResult, + &Chunks, + NeededChunks = std::move(NeededChunks), + ForceUpload]() { + auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); }); + std::vector<SharedBuffer> ChunkBuffers; + ChunkBuffers.reserve(NeededChunks.size()); + for (const IoHash& Chunk : NeededChunks) + { + IoBuffer ChunkPayload = ChunkStore.FindChunkByCid(Chunk); + if (!ChunkPayload) + { + RemoteResult.SetError(static_cast<int32_t>(HttpResponseCode::NotFound), + fmt::format("Missing chunk {}"sv, Chunk), + fmt::format("Unable to fetch attachment {} required by the oplog"sv, Chunk)); + ChunkBuffers.clear(); + break; + } + ChunkBuffers.emplace_back(SharedBuffer(std::move(ChunkPayload))); + } + RemoteProjectStore::SaveAttachmentsResult Result = RemoteStore.SaveAttachments(ChunkBuffers); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to save attachments with {} chunks ({}). Reason: '{}'", + Chunks.size(), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Saved {} bulk attachments in {}", + Chunks.size(), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + }); + } + } + SaveAttachmentsLatch.CountDown(); + while (!SaveAttachmentsLatch.Wait(1000)) + { + ZEN_INFO("Saving attachments, {} remaining...", SaveAttachmentsLatch.Remaining()); + } + SaveAttachmentsLatch.Wait(); + } + + if (!RemoteResult.IsError()) + { + ZEN_INFO("Finalizing oplog container..."); + RemoteProjectStore::Result ContainerFinalizeResult = RemoteStore.FinalizeContainer(ContainerSaveResult.RawHash); + if (ContainerFinalizeResult.ErrorCode) + { + RemoteResult.SetError(ContainerFinalizeResult.ErrorCode, ContainerFinalizeResult.Reason, ContainerFinalizeResult.Text); + ZEN_ERROR("Failed to finalize oplog container {} ({}). Reason: '{}'", + ContainerSaveResult.RawHash, + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + } + ZEN_DEBUG("Finalized container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerFinalizeResult.ElapsedSeconds * 1000))); + } + } + + RemoteProjectStore::Result Result = RemoteResult.ConvertResult(); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + ZEN_INFO("Saved oplog {} in {}", + RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + return Result; +}; + +RemoteProjectStore::Result +SaveOplogContainer(ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function<bool(const IoHash& RawHash)>& HasAttachment, + const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock, + const std::function<void(const IoHash& RawHash)>& OnNeedAttachment) +{ + using namespace std::literals; + + Stopwatch Timer; + + CbArrayView LargeChunksArray = ContainerObject["chunks"sv].AsArrayView(); + for (CbFieldView LargeChunksField : LargeChunksArray) + { + IoHash AttachmentHash = LargeChunksField.AsBinaryAttachment(); + if (HasAttachment(AttachmentHash)) + { + continue; + } + OnNeedAttachment(AttachmentHash); + }; + + CbArrayView BlocksArray = ContainerObject["blocks"sv].AsArrayView(); + for (CbFieldView BlockField : BlocksArray) + { + CbObjectView BlockView = BlockField.AsObjectView(); + IoHash BlockHash = BlockView["rawhash"sv].AsBinaryAttachment(); + + CbArrayView ChunksArray = BlockView["chunks"sv].AsArrayView(); + if (BlockHash == IoHash::Zero) + { + std::vector<IoHash> NeededChunks; + NeededChunks.reserve(ChunksArray.GetSize()); + for (CbFieldView ChunkField : ChunksArray) + { + IoHash ChunkHash = ChunkField.AsBinaryAttachment(); + if (HasAttachment(ChunkHash)) + { + continue; + } + NeededChunks.emplace_back(ChunkHash); + } + + if (!NeededChunks.empty()) + { + OnNeedBlock(IoHash::Zero, std::move(NeededChunks)); + } + continue; + } + + for (CbFieldView ChunkField : ChunksArray) + { + IoHash ChunkHash = ChunkField.AsHash(); + if (HasAttachment(ChunkHash)) + { + continue; + } + + OnNeedBlock(BlockHash, {}); + break; + } + }; + + MemoryView OpsSection = ContainerObject["ops"sv].AsBinaryView(); + IoBuffer OpsBuffer(IoBuffer::Wrap, OpsSection.GetData(), OpsSection.GetSize()); + IoBuffer SectionPayload = CompressedBuffer::FromCompressedNoValidate(std::move(OpsBuffer)).Decompress().AsIoBuffer(); + + CbObject SectionObject = LoadCompactBinaryObject(SectionPayload); + if (!SectionObject) + { + ZEN_ERROR("Failed to save oplog container. Reason: '{}'", "Section has unexpected data type"); + return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest), + Timer.GetElapsedTimeMs() / 1000.500, + "Section has unexpected data type", + "Failed to save oplog container"}; + } + + CbArrayView OpsArray = SectionObject["ops"sv].AsArrayView(); + for (CbFieldView OpEntry : OpsArray) + { + CbObjectView Core = OpEntry.AsObjectView(); + BinaryWriter Writer; + Core.CopyTo(Writer); + MemoryView OpView = Writer.GetView(); + IoBuffer OpBuffer(IoBuffer::Wrap, OpView.GetData(), OpView.GetSize()); + CbObject Op(SharedBuffer(OpBuffer), CbFieldType::HasFieldType); + const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Op); + if (OpLsn == ProjectStore::Oplog::kInvalidOp) + { + return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest), + Timer.GetElapsedTimeMs() / 1000.500, + "Failed saving op", + "Failed to save oplog container"}; + } + ZEN_DEBUG("oplog entry #{}", OpLsn); + } + return RemoteProjectStore::Result{.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500}; +} + +RemoteProjectStore::Result +LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload) +{ + using namespace std::literals; + + Stopwatch Timer; + + // We are creating a worker thread pool here since we are download a lot of attachments in one go and we dont want to keep a + // WorkerThreadPool alive + size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u); + WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount)); + + std::unordered_set<IoHash, IoHash::Hasher> Attachments; + std::vector<std::vector<IoHash>> ChunksInBlocks; + + RemoteProjectStore::LoadContainerResult LoadContainerResult = RemoteStore.LoadContainer(); + if (LoadContainerResult.ErrorCode) + { + ZEN_WARN("Failed to load oplog container, reason: '{}', error code: {}", LoadContainerResult.Reason, LoadContainerResult.ErrorCode); + return RemoteProjectStore::Result{.ErrorCode = LoadContainerResult.ErrorCode, + .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500, + .Reason = LoadContainerResult.Reason, + .Text = LoadContainerResult.Text}; + } + ZEN_DEBUG("Loaded container in {}", NiceTimeSpanMs(static_cast<uint64_t>(LoadContainerResult.ElapsedSeconds * 1000))); + + AsyncRemoteResult RemoteResult; + Latch AttachmentsWorkLatch(1); + + auto HasAttachment = [&ChunkStore, ForceDownload](const IoHash& RawHash) { + return !ForceDownload && ChunkStore.ContainsChunk(RawHash); + }; + auto OnNeedBlock = [&RemoteStore, &ChunkStore, &WorkerPool, &ChunksInBlocks, &AttachmentsWorkLatch, &RemoteResult]( + const IoHash& BlockHash, + std::vector<IoHash>&& Chunks) { + if (BlockHash == IoHash::Zero) + { + AttachmentsWorkLatch.AddCount(1); + WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &AttachmentsWorkLatch, &RemoteResult, Chunks = std::move(Chunks)]() { + auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + + RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks); + if (Result.ErrorCode) + { + RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text); + ZEN_ERROR("Failed to attachments with {} chunks ({}). Reason: '{}'", + Chunks.size(), + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Loaded {} bulk attachments in {}", + Chunks.size(), + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000))); + for (const auto& It : Result.Chunks) + { + ChunkStore.AddChunk(It.second.GetCompressed().Flatten().AsIoBuffer(), It.first, CidStore::InsertMode::kCopyOnly); + } + }); + return; + } + AttachmentsWorkLatch.AddCount(1); + WorkerPool.ScheduleWork([&AttachmentsWorkLatch, &ChunkStore, &RemoteStore, BlockHash, &RemoteResult]() { + auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash); + if (BlockResult.ErrorCode) + { + RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text); + ZEN_ERROR("Failed to load oplog container, missing attachment {} ({}). Reason: '{}'", + BlockHash, + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + ZEN_DEBUG("Loaded block attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000))); + + if (!IterateBlock(std::move(BlockResult.Bytes), [&ChunkStore](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) { + ChunkStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), AttachmentRawHash); + })) + { + RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError), + fmt::format("Invalid format for block {}", BlockHash), + {}); + ZEN_ERROR("Failed to load oplog container, attachment {} has invalid format ({}). Reason: '{}'", + BlockHash, + RemoteResult.GetError(), + RemoteResult.GetErrorReason()); + return; + } + }); + }; + + auto OnNeedAttachment = + [&RemoteStore, &ChunkStore, &WorkerPool, &AttachmentsWorkLatch, &RemoteResult, &Attachments](const IoHash& RawHash) { + if (!Attachments.insert(RawHash).second) + { + return; + } + + AttachmentsWorkLatch.AddCount(1); + WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &RemoteResult, &AttachmentsWorkLatch, RawHash]() { + auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); }); + if (RemoteResult.IsError()) + { + return; + } + RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash); + if (AttachmentResult.ErrorCode) + { + RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text); + ZEN_ERROR("Failed to download attachment {}, reason: '{}', error code: {}", + RawHash, + AttachmentResult.Reason, + AttachmentResult.ErrorCode); + return; + } + ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000))); + ChunkStore.AddChunk(AttachmentResult.Bytes, RawHash); + }); + }; + + RemoteProjectStore::Result Result = + SaveOplogContainer(Oplog, LoadContainerResult.ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment); + + AttachmentsWorkLatch.CountDown(); + while (!AttachmentsWorkLatch.Wait(1000)) + { + ZEN_INFO("Loading attachments, {} remaining...", AttachmentsWorkLatch.Remaining()); + } + AttachmentsWorkLatch.Wait(); + if (Result.ErrorCode == 0) + { + Result = RemoteResult.ConvertResult(); + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + + ZEN_INFO("Loaded oplog {} in {}", + RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE", + NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0))); + + return Result; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/remoteprojectstore.h b/src/zenserver/projectstore/remoteprojectstore.h new file mode 100644 index 000000000..dcabaedd4 --- /dev/null +++ b/src/zenserver/projectstore/remoteprojectstore.h @@ -0,0 +1,111 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "projectstore.h" + +#include <unordered_set> + +namespace zen { + +class CidStore; +class WorkerThreadPool; + +class RemoteProjectStore +{ +public: + struct Result + { + int32_t ErrorCode{}; + double ElapsedSeconds{}; + std::string Reason; + std::string Text; + }; + + struct SaveResult : public Result + { + std::unordered_set<IoHash, IoHash::Hasher> Needs; + IoHash RawHash; + }; + + struct SaveAttachmentResult : public Result + { + }; + + struct SaveAttachmentsResult : public Result + { + }; + + struct LoadAttachmentResult : public Result + { + IoBuffer Bytes; + }; + + struct LoadContainerResult : public Result + { + CbObject ContainerObject; + }; + + struct LoadAttachmentsResult : public Result + { + std::vector<std::pair<IoHash, CompressedBuffer>> Chunks; + }; + + struct RemoteStoreInfo + { + bool CreateBlocks; + bool UseTempBlockFiles; + std::string Description; + }; + + virtual ~RemoteProjectStore() {} + + virtual RemoteStoreInfo GetInfo() const = 0; + + virtual SaveResult SaveContainer(const IoBuffer& Payload) = 0; + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) = 0; + virtual Result FinalizeContainer(const IoHash& RawHash) = 0; + virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Payloads) = 0; + + virtual LoadContainerResult LoadContainer() = 0; + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0; + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0; +}; + +struct RemoteStoreOptions +{ + size_t MaxBlockSize = 128u * 1024u * 1024u; + size_t MaxChunkEmbedSize = 1024u * 1024u; +}; + +RemoteProjectStore::LoadContainerResult BuildContainer( + CidStore& ChunkStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock, + const std::function<void(const IoHash&)>& OnLargeAttachment, + const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks); + +RemoteProjectStore::Result SaveOplogContainer(ProjectStore::Oplog& Oplog, + const CbObject& ContainerObject, + const std::function<bool(const IoHash& RawHash)>& HasAttachment, + const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock, + const std::function<void(const IoHash& RawHash)>& OnNeedAttachment); + +RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore, + RemoteProjectStore& RemoteStore, + ProjectStore::Oplog& Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize, + bool BuildBlocks, + bool UseTempBlocks, + bool ForceUpload); + +RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload); + +CompressedBuffer GenerateBlock(std::vector<SharedBuffer>&& Chunks); +bool IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor); + +} // namespace zen diff --git a/src/zenserver/projectstore/zenremoteprojectstore.cpp b/src/zenserver/projectstore/zenremoteprojectstore.cpp new file mode 100644 index 000000000..6ff471ae5 --- /dev/null +++ b/src/zenserver/projectstore/zenremoteprojectstore.cpp @@ -0,0 +1,341 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenremoteprojectstore.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compositebuffer.h> +#include <zencore/fmtutils.h> +#include <zencore/scopeguard.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zenhttp/httpshared.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +using namespace std::literals; + +class ZenRemoteStore : public RemoteProjectStore +{ +public: + ZenRemoteStore(std::string_view HostAddress, + std::string_view Project, + std::string_view Oplog, + size_t MaxBlockSize, + size_t MaxChunkEmbedSize) + : m_HostAddress(HostAddress) + , m_ProjectStoreUrl(fmt::format("{}/prj"sv, m_HostAddress)) + , m_Project(Project) + , m_Oplog(Oplog) + , m_MaxBlockSize(MaxBlockSize) + , m_MaxChunkEmbedSize(MaxChunkEmbedSize) + { + } + + virtual RemoteStoreInfo GetInfo() const override + { + return {.CreateBlocks = false, .UseTempBlockFiles = false, .Description = fmt::format("[zen] {}"sv, m_HostAddress)}; + } + + virtual SaveResult SaveContainer(const IoBuffer& Payload) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/save"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + Session->SetUrl({SaveRequest}); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))}}); + MemoryView Data(Payload.GetView()); + Session->SetBody({reinterpret_cast<const char*>(Data.GetData()), Data.GetSize()}); + cpr::Response Response = Session->Post(); + SaveResult Result = SaveResult{ConvertResult(Response)}; + + if (Result.ErrorCode) + { + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + IoBuffer ResponsePayload(IoBuffer::Wrap, Response.text.data(), Response.text.size()); + CbObject ResponseObject = LoadCompactBinaryObject(ResponsePayload); + if (!ResponseObject) + { + Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv, + m_ProjectStoreUrl, + m_Project, + m_Oplog); + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + CbArrayView NeedsArray = ResponseObject["need"sv].AsArrayView(); + for (CbFieldView FieldView : NeedsArray) + { + IoHash ChunkHash = FieldView.AsHash(); + Result.Needs.insert(ChunkHash); + } + + Result.RawHash = IoHash::HashBuffer(Payload); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash); + Session->SetUrl({SaveRequest}); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}}); + uint64_t SizeLeft = Payload.GetSize(); + CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); + auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, SizeLeft); + MutableMemoryView Data(buffer, size); + Payload.CopyTo(Data, BufferIt); + SizeLeft -= size; + return true; + }; + Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback)); + cpr::Response Response = Session->Post(); + SaveAttachmentResult Result = SaveAttachmentResult{ConvertResult(Response)}; + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override + { + Stopwatch Timer; + + CbPackage RequestPackage; + { + CbObjectWriter RequestWriter; + RequestWriter.AddString("method"sv, "putchunks"sv); + RequestWriter.BeginArray("chunks"sv); + { + for (const SharedBuffer& Chunk : Chunks) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(Chunk, RawHash, RawSize); + RequestWriter.AddHash(RawHash); + RequestPackage.AddAttachment(CbAttachment(Compressed, RawHash)); + } + } + RequestWriter.EndArray(); // "chunks" + RequestPackage.SetObject(RequestWriter.Save()); + } + CompositeBuffer Payload = FormatPackageMessageBuffer(RequestPackage, FormatFlags::kDefault); + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + Session->SetUrl({SaveRequest}); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}}); + + uint64_t SizeLeft = Payload.GetSize(); + CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); + auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, SizeLeft); + MutableMemoryView Data(buffer, size); + Payload.CopyTo(Data, BufferIt); + SizeLeft -= size; + return true; + }; + Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback)); + cpr::Response Response = Session->Post(); + SaveAttachmentsResult Result = SaveAttachmentsResult{ConvertResult(Response)}; + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + + CbObject Request; + { + CbObjectWriter RequestWriter; + RequestWriter.AddString("method"sv, "getchunks"sv); + RequestWriter.BeginArray("chunks"sv); + { + for (const IoHash& RawHash : RawHashes) + { + RequestWriter.AddHash(RawHash); + } + } + RequestWriter.EndArray(); // "chunks" + Request = RequestWriter.Save(); + } + IoBuffer Payload = Request.GetBuffer().AsIoBuffer(); + Session->SetBody(cpr::Body{(const char*)Payload.GetData(), Payload.GetSize()}); + Session->SetUrl(SaveRequest); + Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))}, + {"Accept", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}}); + + cpr::Response Response = Session->Post(); + LoadAttachmentsResult Result = LoadAttachmentsResult{ConvertResult(Response)}; + if (!Result.ErrorCode) + { + CbPackage Package = ParsePackageMessage(IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size())); + std::span<const CbAttachment> Attachments = Package.GetAttachments(); + Result.Chunks.reserve(Attachments.size()); + for (const CbAttachment& Attachment : Attachments) + { + Result.Chunks.emplace_back( + std::pair<IoHash, CompressedBuffer>{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()}); + } + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + }; + + virtual Result FinalizeContainer(const IoHash&) override + { + Stopwatch Timer; + + RwLock::ExclusiveLockScope _(SessionsLock); + Sessions.clear(); + return {.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500}; + } + + virtual LoadContainerResult LoadContainer() override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + std::string SaveRequest = fmt::format("{}/{}/oplog/{}/load"sv, m_ProjectStoreUrl, m_Project, m_Oplog); + Session->SetUrl(SaveRequest); + Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCbObject))}}); + Session->SetParameters( + {{"maxblocksize", fmt::format("{}", m_MaxBlockSize)}, {"maxchunkembedsize", fmt::format("{}", m_MaxChunkEmbedSize)}}); + cpr::Response Response = Session->Get(); + + LoadContainerResult Result = LoadContainerResult{ConvertResult(Response)}; + if (!Result.ErrorCode) + { + Result.ContainerObject = LoadCompactBinaryObject(IoBuffer(IoBuffer::Clone, Response.text.data(), Response.text.size())); + if (!Result.ContainerObject) + { + Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv, + m_ProjectStoreUrl, + m_Project, + m_Oplog); + Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError); + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + + virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override + { + Stopwatch Timer; + + std::unique_ptr<cpr::Session> Session(AllocateSession()); + auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); }); + + std::string LoadRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash); + Session->SetUrl({LoadRequest}); + Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}}); + cpr::Response Response = Session->Get(); + LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)}; + if (!Result.ErrorCode) + { + Result.Bytes = IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()); + } + Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500; + return Result; + } + +private: + std::unique_ptr<cpr::Session> AllocateSession() + { + RwLock::ExclusiveLockScope _(SessionsLock); + if (Sessions.empty()) + { + Sessions.emplace_back(std::make_unique<cpr::Session>()); + } + std::unique_ptr<cpr::Session> Session = std::move(Sessions.back()); + Sessions.pop_back(); + return Session; + } + + void ReleaseSession(std::unique_ptr<cpr::Session>&& Session) + { + RwLock::ExclusiveLockScope _(SessionsLock); + Sessions.emplace_back(std::move(Session)); + } + + static Result ConvertResult(const cpr::Response& Response) + { + std::string Text; + std::string Reason = Response.reason; + int32_t ErrorCode = 0; + if (Response.error.code != cpr::ErrorCode::OK) + { + ErrorCode = static_cast<int32_t>(Response.error.code); + if (!Response.error.message.empty()) + { + Reason = Response.error.message; + } + } + else if (!IsHttpSuccessCode(Response.status_code)) + { + ErrorCode = static_cast<int32_t>(Response.status_code); + + if (auto It = Response.header.find("Content-Type"); It != Response.header.end()) + { + zen::HttpContentType ContentType = zen::ParseContentType(It->second); + if (ContentType == zen::HttpContentType::kText) + { + Text = Response.text; + } + } + + Reason = fmt::format("{}"sv, Response.status_code); + } + return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.elapsed, .Reason = Reason, .Text = Text}; + } + + RwLock SessionsLock; + std::vector<std::unique_ptr<cpr::Session>> Sessions; + + const std::string m_HostAddress; + const std::string m_ProjectStoreUrl; + const std::string m_Project; + const std::string m_Oplog; + const size_t m_MaxBlockSize; + const size_t m_MaxChunkEmbedSize; +}; + +std::unique_ptr<RemoteProjectStore> +CreateZenRemoteStore(const ZenRemoteStoreOptions& Options) +{ + std::string Url = Options.Url; + if (Url.find("://"sv) == std::string::npos) + { + // Assume https URL + Url = fmt::format("http://{}"sv, Url); + } + std::unique_ptr<RemoteProjectStore> RemoteStore = + std::make_unique<ZenRemoteStore>(Url, Options.ProjectId, Options.OplogId, Options.MaxBlockSize, Options.MaxChunkEmbedSize); + return RemoteStore; +} + +} // namespace zen diff --git a/src/zenserver/projectstore/zenremoteprojectstore.h b/src/zenserver/projectstore/zenremoteprojectstore.h new file mode 100644 index 000000000..ef9dcad8c --- /dev/null +++ b/src/zenserver/projectstore/zenremoteprojectstore.h @@ -0,0 +1,18 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "remoteprojectstore.h" + +namespace zen { + +struct ZenRemoteStoreOptions : RemoteStoreOptions +{ + std::string Url; + std::string ProjectId; + std::string OplogId; +}; + +std::unique_ptr<RemoteProjectStore> CreateZenRemoteStore(const ZenRemoteStoreOptions& Options); + +} // namespace zen diff --git a/src/zenserver/resource.h b/src/zenserver/resource.h new file mode 100644 index 000000000..f2e3b471b --- /dev/null +++ b/src/zenserver/resource.h @@ -0,0 +1,18 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +//{{NO_DEPENDENCIES}} +// Microsoft Visual C++ generated include file. +// Used by zenserver.rc +// +#define IDI_ICON1 101 + +// Next default values for new objects +// +#ifdef APSTUDIO_INVOKED +# ifndef APSTUDIO_READONLY_SYMBOLS +# define _APS_NEXT_RESOURCE_VALUE 102 +# define _APS_NEXT_COMMAND_VALUE 40001 +# define _APS_NEXT_CONTROL_VALUE 1001 +# define _APS_NEXT_SYMED_VALUE 101 +# endif +#endif diff --git a/src/zenserver/targetver.h b/src/zenserver/targetver.h new file mode 100644 index 000000000..d432d6993 --- /dev/null +++ b/src/zenserver/targetver.h @@ -0,0 +1,10 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +// Including SDKDDKVer.h defines the highest available Windows platform. + +// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and +// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h. + +#include <SDKDDKVer.h> diff --git a/src/zenserver/testing/httptest.cpp b/src/zenserver/testing/httptest.cpp new file mode 100644 index 000000000..349a95ab3 --- /dev/null +++ b/src/zenserver/testing/httptest.cpp @@ -0,0 +1,207 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "httptest.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/timer.h> + +namespace zen { + +using namespace std::literals; + +HttpTestingService::HttpTestingService() +{ + m_Router.RegisterRoute( + "hello", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_slow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Stopwatch Timer; + Sleep(1000); + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kText, + fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs()))); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_veryslow", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) { + Stopwatch Timer; + Sleep(60000); + Request.WriteResponse(HttpResponseCode::OK, + HttpContentType::kText, + fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs()))); + }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_throw", + [](HttpRouterRequest& Req) { + Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) { throw std::runtime_error("intentional error"); }); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "hello_noresponse", + [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) {}); }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "metrics", + [this](HttpRouterRequest& Req) { + metrics::OperationTiming::Scope _(m_TimingStats); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "get_metrics", + [this](HttpRouterRequest& Req) { + CbObjectWriter Cbo; + EmitSnapshot("requests", m_TimingStats, Cbo); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "json", + [this](HttpRouterRequest& Req) { + CbObjectWriter Obj; + Obj.AddBool("ok", true); + Obj.AddInteger("counter", ++m_Counter); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save()); + }, + HttpVerb::kGet); + + m_Router.RegisterRoute( + "echo", + [](HttpRouterRequest& Req) { + IoBuffer Body = Req.ServerRequest().ReadPayload(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Body); + }, + HttpVerb::kPost); + + m_Router.RegisterRoute( + "package", + [](HttpRouterRequest& Req) { + CbPackage Pkg = Req.ServerRequest().ReadPayloadPackage(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Pkg); + }, + HttpVerb::kPost); +} + +HttpTestingService::~HttpTestingService() +{ +} + +const char* +HttpTestingService::BaseUri() const +{ + return "/testing/"; +} + +void +HttpTestingService::HandleRequest(HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +Ref<IHttpPackageHandler> +HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) +{ + RwLock::ExclusiveLockScope _(m_RwLock); + + const uint32_t RequestId = HttpServiceRequest.RequestId(); + + if (auto It = m_HandlerMap.find(RequestId); It != m_HandlerMap.end()) + { + Ref<HttpTestingService::PackageHandler> Handler = std::move(It->second); + + m_HandlerMap.erase(It); + + return Handler; + } + + auto InsertResult = m_HandlerMap.insert({RequestId, Ref<PackageHandler>()}); + + _.ReleaseNow(); + + return (InsertResult.first->second = Ref<PackageHandler>(new PackageHandler(*this, RequestId))); +} + +void +HttpTestingService::RegisterHandlers(WebSocketServer& Server) +{ + Server.RegisterRequestHandler("SayHello"sv, *this); +} + +bool +HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg) +{ + CbObjectView Request = RequestMsg.Body().GetObject(); + + std::string_view Method = Request["Method"].AsString(); + + if (Method != "SayHello"sv) + { + return false; + } + + CbObjectWriter Response; + Response.AddString("Result"sv, "Hello Friend!!"); + + WebSocketMessage ResponseMsg; + ResponseMsg.SetMessageType(WebSocketMessageType::kResponse); + ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId()); + ResponseMsg.SetSocketId(RequestMsg.SocketId()); + ResponseMsg.SetBody(Response.Save()); + + SocketServer().SendResponse(std::move(ResponseMsg)); + + return true; +} + +////////////////////////////////////////////////////////////////////////// + +HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId) +{ +} + +HttpTestingService::PackageHandler::~PackageHandler() +{ +} + +void +HttpTestingService::PackageHandler::FilterOffer(std::vector<IoHash>& OfferCids) +{ + ZEN_UNUSED(OfferCids); + // No-op + return; +} +void +HttpTestingService::PackageHandler::OnRequestBegin() +{ +} + +void +HttpTestingService::PackageHandler::OnRequestComplete() +{ +} + +IoBuffer +HttpTestingService::PackageHandler::CreateTarget(const IoHash& Cid, uint64_t StorageSize) +{ + ZEN_UNUSED(Cid); + return IoBuffer{StorageSize}; +} + +} // namespace zen diff --git a/src/zenserver/testing/httptest.h b/src/zenserver/testing/httptest.h new file mode 100644 index 000000000..57d2d63f3 --- /dev/null +++ b/src/zenserver/testing/httptest.h @@ -0,0 +1,55 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/logging.h> +#include <zencore/stats.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> + +#include <atomic> + +namespace zen { + +/** + * Test service to facilitate testing the HTTP framework and client interactions + */ +class HttpTestingService : public HttpService, public WebSocketService +{ +public: + HttpTestingService(); + ~HttpTestingService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(HttpServerRequest& Request) override; + virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest) override; + + class PackageHandler : public IHttpPackageHandler + { + public: + PackageHandler(HttpTestingService& Svc, uint32_t RequestId); + ~PackageHandler(); + + virtual void FilterOffer(std::vector<IoHash>& OfferCids) override; + virtual void OnRequestBegin() override; + virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) override; + virtual void OnRequestComplete() override; + + private: + HttpTestingService& m_Svc; + uint32_t m_RequestId; + }; + +private: + virtual void RegisterHandlers(WebSocketServer& Server) override; + virtual bool HandleRequest(const WebSocketMessage& Request) override; + + HttpRequestRouter m_Router; + std::atomic<uint32_t> m_Counter{0}; + metrics::OperationTiming m_TimingStats; + + RwLock m_RwLock; + std::unordered_map<uint32_t, Ref<PackageHandler>> m_HandlerMap; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/hordecompute.cpp b/src/zenserver/upstream/hordecompute.cpp new file mode 100644 index 000000000..64d9fff72 --- /dev/null +++ b/src/zenserver/upstream/hordecompute.cpp @@ -0,0 +1,1457 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "upstreamapply.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include "jupiter.h" + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/compactbinarypackage.h> +# include <zencore/compactbinaryvalidation.h> +# include <zencore/fmtutils.h> +# include <zencore/session.h> +# include <zencore/stream.h> +# include <zencore/thread.h> +# include <zencore/timer.h> +# include <zencore/workthreadpool.h> + +# include <zenstore/cidstore.h> + +# include <auth/authmgr.h> +# include <upstream/upstreamcache.h> + +# include "cache/structuredcachestore.h" +# include "diag/logging.h" + +# include <fmt/format.h> + +# include <algorithm> +# include <atomic> +# include <set> +# include <stack> + +namespace zen { + +using namespace std::literals; + +static const IoBuffer EmptyBuffer; +static const IoHash EmptyBufferId = IoHash::HashBuffer(EmptyBuffer); + +namespace detail { + + class HordeUpstreamApplyEndpoint final : public UpstreamApplyEndpoint + { + public: + HordeUpstreamApplyEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr) + : m_Log(logging::Get("upstream-apply")) + , m_CidStore(CidStore) + , m_AuthMgr(Mgr) + { + m_DisplayName = fmt::format("{} - '{}'+'{}'", ComputeOptions.Name, ComputeOptions.ServiceUrl, StorageOptions.ServiceUrl); + m_ChannelId = fmt::format("zen-{}", zen::GetSessionIdString()); + + { + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (ComputeAuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = ComputeAuthConfig.OAuthUrl, + .ClientId = ComputeAuthConfig.OAuthClientId, + .ClientSecret = ComputeAuthConfig.OAuthClientSecret}); + } + else if (ComputeAuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(ComputeAuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(ComputeAuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_Client = new CloudCacheClient(ComputeOptions, std::move(TokenProvider)); + } + + { + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (StorageAuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = StorageAuthConfig.OAuthUrl, + .ClientId = StorageAuthConfig.OAuthClientId, + .ClientSecret = StorageAuthConfig.OAuthClientSecret}); + } + else if (StorageAuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(StorageAuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(StorageAuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_StorageClient = new CloudCacheClient(StorageOptions, std::move(TokenProvider)); + } + } + + virtual ~HordeUpstreamApplyEndpoint() = default; + + virtual UpstreamEndpointHealth Initialize() override { return CheckHealth(); } + + virtual bool IsHealthy() const override { return m_HealthOk.load(); } + + virtual UpstreamEndpointHealth CheckHealth() override + { + try + { + CloudCacheSession Session(m_Client); + CloudCacheResult Result = Session.Authenticate(); + + m_HealthOk = Result.ErrorCode == 0; + + return {.Reason = std::move(Result.Reason), .Ok = Result.Success}; + } + catch (std::exception& Err) + { + return {.Reason = Err.what(), .Ok = false}; + } + } + + virtual std::string_view DisplayName() const override { return m_DisplayName; } + + virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) override + { + PostUpstreamApplyResult ApplyResult{}; + ApplyResult.Timepoints.merge(ApplyRecord.Timepoints); + + try + { + UpstreamData UpstreamData; + if (!ProcessApplyKey(ApplyRecord, UpstreamData)) + { + return {.Error{.ErrorCode = -1, .Reason = "Failed to generate task data"}}; + } + + { + ApplyResult.Timepoints["zen-storage-build-ref"] = DateTime::NowTicks(); + + bool AlreadyQueued; + { + std::scoped_lock Lock(m_TaskMutex); + AlreadyQueued = m_PendingTasks.contains(UpstreamData.TaskId); + } + if (AlreadyQueued) + { + // Pending task is already queued, return success + ApplyResult.Success = true; + return ApplyResult; + } + m_PendingTasks[UpstreamData.TaskId] = std::move(ApplyRecord); + } + + CloudCacheSession ComputeSession(m_Client); + CloudCacheSession StorageSession(m_StorageClient); + + { + CloudCacheResult Result = BatchPutBlobsIfMissing(StorageSession, UpstreamData.Blobs, UpstreamData.CasIds); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-blobs"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload blobs"}; + return ApplyResult; + } + UpstreamData.Blobs.clear(); + UpstreamData.CasIds.clear(); + } + + { + CloudCacheResult Result = BatchPutCompressedBlobsIfMissing(StorageSession, UpstreamData.Cids); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-compressed-blobs"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = { + .ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload compressed blobs"}; + return ApplyResult; + } + UpstreamData.Cids.clear(); + } + + { + CloudCacheResult Result = BatchPutObjectsIfMissing(StorageSession, UpstreamData.Objects); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-upload-objects"] = DateTime::NowTicks(); + if (!Result.Success) + { + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload objects"}; + return ApplyResult; + } + } + + { + PutRefResult RefResult = StorageSession.PutRef(StorageSession.Client().DefaultBlobStoreNamespace(), + "requests"sv, + UpstreamData.TaskId, + UpstreamData.Objects[UpstreamData.TaskId].GetBuffer().AsIoBuffer(), + ZenContentType::kCbObject); + Log().debug("Put ref {} Need={} Bytes={} Duration={}s Result={}", + UpstreamData.TaskId, + RefResult.Needs.size(), + RefResult.Bytes, + RefResult.ElapsedSeconds, + RefResult.Success); + ApplyResult.Bytes += RefResult.Bytes; + ApplyResult.ElapsedSeconds += RefResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-put-ref"] = DateTime::NowTicks(); + + if (RefResult.Needs.size() > 0) + { + Log().error("Failed to add task ref {} due to {} missing blobs", UpstreamData.TaskId, RefResult.Needs.size()); + for (const auto& Hash : RefResult.Needs) + { + Log().debug("Task ref {} missing blob {}", UpstreamData.TaskId, Hash); + } + + ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode, + .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) + : "Failed to add task ref due to missing blob"}; + return ApplyResult; + } + + if (!RefResult.Success) + { + ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode, + .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) : "Failed to add task ref"}; + return ApplyResult; + } + UpstreamData.Objects.clear(); + } + + { + CbObjectWriter Writer; + Writer.AddString("c"sv, m_ChannelId); + Writer.AddObjectAttachment("r"sv, UpstreamData.RequirementsId); + Writer.BeginArray("t"sv); + Writer.AddObjectAttachment(UpstreamData.TaskId); + Writer.EndArray(); + CbObject TasksObject = Writer.Save(); + IoBuffer TasksData = TasksObject.GetBuffer().AsIoBuffer(); + + CloudCacheResult Result = ComputeSession.PostComputeTasks(TasksData); + Log().debug("Post compute task {} Bytes={} Duration={}s Result={}", + TasksObject.GetHash(), + Result.Bytes, + Result.ElapsedSeconds, + Result.Success); + ApplyResult.Bytes += Result.Bytes; + ApplyResult.ElapsedSeconds += Result.ElapsedSeconds; + ApplyResult.Timepoints["zen-horde-post-task"] = DateTime::NowTicks(); + if (!Result.Success) + { + { + std::scoped_lock Lock(m_TaskMutex); + m_PendingTasks.erase(UpstreamData.TaskId); + } + + ApplyResult.Error = {.ErrorCode = Result.ErrorCode, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to post compute task"}; + return ApplyResult; + } + } + + Log().info("Task posted {}", UpstreamData.TaskId); + ApplyResult.Success = true; + return ApplyResult; + } + catch (std::exception& Err) + { + m_HealthOk = false; + return {.Error{.ErrorCode = -1, .Reason = Err.what()}}; + } + } + + [[nodiscard]] CloudCacheResult BatchPutBlobsIfMissing(CloudCacheSession& Session, + const std::map<IoHash, IoBuffer>& Blobs, + const std::set<IoHash>& CasIds) + { + if (Blobs.size() == 0 && CasIds.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing blobs + std::set<IoHash> Keys; + std::transform(Blobs.begin(), Blobs.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; }); + Keys.insert(CasIds.begin(), CasIds.end()); + + CloudCacheExistsResult ExistsResult = Session.BlobExists(Session.Client().DefaultBlobStoreNamespace(), Keys); + Log().debug("Queried {} missing blobs Need={} Duration={}s Result={}", + Keys.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if blobs exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + IoBuffer DataBuffer; + if (Blobs.contains(Hash)) + { + DataBuffer = Blobs.at(Hash); + } + else + { + DataBuffer = m_CidStore.FindChunkByCid(Hash); + if (!DataBuffer) + { + Log().warn("Put blob FAILED, input chunk '{}' missing", Hash); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put blobs"}; + } + } + + CloudCacheResult Result = Session.PutBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer); + Log().debug("Put blob {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put blobs"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + [[nodiscard]] CloudCacheResult BatchPutCompressedBlobsIfMissing(CloudCacheSession& Session, const std::set<IoHash>& Cids) + { + if (Cids.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing compressed blobs + CloudCacheExistsResult ExistsResult = Session.CompressedBlobExists(Session.Client().DefaultBlobStoreNamespace(), Cids); + Log().debug("Queried {} missing compressed blobs Need={} Duration={}s Result={}", + Cids.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return { + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if compressed blobs exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + IoBuffer DataBuffer = m_CidStore.FindChunkByCid(Hash); + if (!DataBuffer) + { + Log().warn("Put compressed blob FAILED, input CID chunk '{}' missing", Hash); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put compressed blobs"}; + } + + CloudCacheResult Result = Session.PutCompressedBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer); + Log().debug("Put compressed blob {} Bytes={} Duration={}s Result={}", + Hash, + Result.Bytes, + Result.ElapsedSeconds, + Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put compressed blobs"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + [[nodiscard]] CloudCacheResult BatchPutObjectsIfMissing(CloudCacheSession& Session, const std::map<IoHash, CbObject>& Objects) + { + if (Objects.size() == 0) + { + return {.Success = true}; + } + + int64_t Bytes{}; + double ElapsedSeconds{}; + + // Batch check for missing objects + std::set<IoHash> Keys; + std::transform(Objects.begin(), Objects.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; }); + + CloudCacheExistsResult ExistsResult = Session.ObjectExists(Session.Client().DefaultBlobStoreNamespace(), Keys); + Log().debug("Queried {} missing objects Need={} Duration={}s Result={}", + Keys.size(), + ExistsResult.Needs.size(), + ExistsResult.ElapsedSeconds, + ExistsResult.Success); + ElapsedSeconds += ExistsResult.ElapsedSeconds; + if (!ExistsResult.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1, + .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if objects exist"}; + } + + for (const auto& Hash : ExistsResult.Needs) + { + CloudCacheResult Result = + Session.PutObject(Session.Client().DefaultBlobStoreNamespace(), Hash, Objects.at(Hash).GetBuffer().AsIoBuffer()); + Log().debug("Put object {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success); + Bytes += Result.Bytes; + ElapsedSeconds += Result.ElapsedSeconds; + if (!Result.Success) + { + return {.Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1, + .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put objects"}; + } + } + + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + + enum class ComputeTaskState : int32_t + { + Queued = 0, + Executing = 1, + Complete = 2, + }; + + enum class ComputeTaskOutcome : int32_t + { + Success = 0, + Failed = 1, + Cancelled = 2, + NoResult = 3, + Exipred = 4, + BlobNotFound = 5, + Exception = 6, + }; + + [[nodiscard]] static std::string_view ComputeTaskStateToString(const ComputeTaskState Outcome) + { + switch (Outcome) + { + case ComputeTaskState::Queued: + return "Queued"sv; + case ComputeTaskState::Executing: + return "Executing"sv; + case ComputeTaskState::Complete: + return "Complete"sv; + }; + return "Unknown"sv; + } + + [[nodiscard]] static std::string_view ComputeTaskOutcomeToString(const ComputeTaskOutcome Outcome) + { + switch (Outcome) + { + case ComputeTaskOutcome::Success: + return "Success"sv; + case ComputeTaskOutcome::Failed: + return "Failed"sv; + case ComputeTaskOutcome::Cancelled: + return "Cancelled"sv; + case ComputeTaskOutcome::NoResult: + return "NoResult"sv; + case ComputeTaskOutcome::Exipred: + return "Exipred"sv; + case ComputeTaskOutcome::BlobNotFound: + return "BlobNotFound"sv; + case ComputeTaskOutcome::Exception: + return "Exception"sv; + }; + return "Unknown"sv; + } + + virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) override + { + int64_t Bytes{}; + double ElapsedSeconds{}; + + { + std::scoped_lock Lock(m_TaskMutex); + if (m_PendingTasks.empty()) + { + if (m_CompletedTasks.empty()) + { + // Nothing to do. + return {.Success = true}; + } + + UpstreamApplyCompleted CompletedTasks; + std::swap(CompletedTasks, m_CompletedTasks); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true}; + } + } + + try + { + CloudCacheSession ComputeSession(m_Client); + + CloudCacheResult UpdatesResult = ComputeSession.GetComputeUpdates(m_ChannelId); + Log().debug("Get compute updates Bytes={} Duration={}s Result={}", + UpdatesResult.Bytes, + UpdatesResult.ElapsedSeconds, + UpdatesResult.Success); + Bytes += UpdatesResult.Bytes; + ElapsedSeconds += UpdatesResult.ElapsedSeconds; + if (!UpdatesResult.Success) + { + return {.Error{.ErrorCode = UpdatesResult.ErrorCode, .Reason = std::move(UpdatesResult.Reason)}, + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds}; + } + + if (!UpdatesResult.Success) + { + return {.Error{.ErrorCode = -1, .Reason = "Failed get task updates"}, .Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds}; + } + + CbObject TaskStatus = LoadCompactBinaryObject(std::move(UpdatesResult.Response)); + + for (auto& It : TaskStatus["u"sv]) + { + CbObjectView Status = It.AsObjectView(); + IoHash TaskId = Status["h"sv].AsHash(); + const ComputeTaskState State = (ComputeTaskState)Status["s"sv].AsInt32(); + const ComputeTaskOutcome Outcome = (ComputeTaskOutcome)Status["o"sv].AsInt32(); + + Log().info("Task {} State={}", TaskId, ComputeTaskStateToString(State)); + + // Only completed tasks need to be processed + if (State != ComputeTaskState::Complete) + { + continue; + } + + IoHash WorkerId{}; + IoHash ActionId{}; + UpstreamApplyType ApplyType{}; + + { + std::scoped_lock Lock(m_TaskMutex); + auto TaskIt = m_PendingTasks.find(TaskId); + if (TaskIt != m_PendingTasks.end()) + { + WorkerId = TaskIt->second.WorkerDescriptor.GetHash(); + ActionId = TaskIt->second.Action.GetHash(); + ApplyType = TaskIt->second.Type; + m_PendingTasks.erase(TaskIt); + } + } + + if (WorkerId == IoHash::Zero) + { + Log().warn("Task {} missing from pending tasks", TaskId); + continue; + } + + std::map<std::string, uint64_t> Timepoints; + ProcessQueueTimings(Status["qs"sv].AsObjectView(), Timepoints); + ProcessExecuteTimings(Status["es"sv].AsObjectView(), Timepoints); + + if (Outcome != ComputeTaskOutcome::Success) + { + const std::string_view Detail = Status["d"sv].AsString(); + { + std::scoped_lock Lock(m_TaskMutex); + m_CompletedTasks[WorkerId][ActionId] = { + .Error{.ErrorCode = -1, .Reason = fmt::format("Task {} {}", ComputeTaskOutcomeToString(Outcome), Detail)}, + .Timepoints = std::move(Timepoints)}; + } + continue; + } + + Timepoints["zen-complete-queue-added"] = DateTime::NowTicks(); + ThreadPool.ScheduleWork([this, + ApplyType, + ResultHash = Status["r"sv].AsHash(), + Timepoints = std::move(Timepoints), + TaskId = std::move(TaskId), + WorkerId = std::move(WorkerId), + ActionId = std::move(ActionId)]() mutable { + Timepoints["zen-complete-queue-dispatched"] = DateTime::NowTicks(); + GetUpstreamApplyResult Result = ProcessTaskStatus(ApplyType, ResultHash); + Timepoints["zen-complete-queue-complete"] = DateTime::NowTicks(); + Result.Timepoints.merge(Timepoints); + + Log().debug("Task Processed {} Files={} Attachments={} ExitCode={}", + TaskId, + Result.OutputFiles.size(), + Result.OutputPackage.GetAttachments().size(), + Result.Error.ErrorCode); + { + std::scoped_lock Lock(m_TaskMutex); + m_CompletedTasks[WorkerId][ActionId] = std::move(Result); + } + }); + } + + { + std::scoped_lock Lock(m_TaskMutex); + if (m_CompletedTasks.empty()) + { + // Nothing to do. + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true}; + } + UpstreamApplyCompleted CompletedTasks; + std::swap(CompletedTasks, m_CompletedTasks); + return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true}; + } + } + catch (std::exception& Err) + { + m_HealthOk = false; + return { + .Error{.ErrorCode = -1, .Reason = Err.what()}, + .Bytes = Bytes, + .ElapsedSeconds = ElapsedSeconds, + }; + } + } + + virtual UpstreamApplyEndpointStats& Stats() override { return m_Stats; } + + private: + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + CidStore& m_CidStore; + AuthMgr& m_AuthMgr; + std::string m_DisplayName; + RefPtr<CloudCacheClient> m_Client; + RefPtr<CloudCacheClient> m_StorageClient; + UpstreamApplyEndpointStats m_Stats; + std::atomic_bool m_HealthOk{false}; + std::string m_ChannelId; + + std::mutex m_TaskMutex; + std::unordered_map<IoHash, UpstreamApplyRecord> m_PendingTasks; + UpstreamApplyCompleted m_CompletedTasks; + + struct UpstreamData + { + std::map<IoHash, IoBuffer> Blobs; + std::map<IoHash, CbObject> Objects; + std::set<IoHash> CasIds; + std::set<IoHash> Cids; + IoHash TaskId; + IoHash RequirementsId; + }; + + struct UpstreamDirectory + { + std::filesystem::path Path; + std::map<std::string, UpstreamDirectory> Directories; + std::set<std::string> Files; + }; + + static void ProcessQueueTimings(CbObjectView QueueStats, std::map<std::string, uint64_t>& Timepoints) + { + uint64_t Ticks = QueueStats["t"sv].AsDateTimeTicks(); + if (Ticks == 0) + { + return; + } + + // Scope is an array of miliseconds after start time + // TODO: cleanup + Timepoints["horde-queue-added"] = Ticks; + int Index = 0; + for (auto& Item : QueueStats["s"sv].AsArrayView()) + { + Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond; + switch (Index) + { + case 0: + Timepoints["horde-queue-dispatched"] = Ticks; + break; + case 1: + Timepoints["horde-queue-complete"] = Ticks; + break; + } + Index++; + } + } + + static void ProcessExecuteTimings(CbObjectView ExecutionStats, std::map<std::string, uint64_t>& Timepoints) + { + uint64_t Ticks = ExecutionStats["t"sv].AsDateTimeTicks(); + if (Ticks == 0) + { + return; + } + + // Scope is an array of miliseconds after start time + // TODO: cleanup + Timepoints["horde-execution-start"] = Ticks; + int Index = 0; + for (auto& Item : ExecutionStats["s"sv].AsArrayView()) + { + Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond; + switch (Index) + { + case 0: + Timepoints["horde-execution-download-ref"] = Ticks; + break; + case 1: + Timepoints["horde-execution-download-input"] = Ticks; + break; + case 2: + Timepoints["horde-execution-execute"] = Ticks; + break; + case 3: + Timepoints["horde-execution-upload-log"] = Ticks; + break; + case 4: + Timepoints["horde-execution-upload-output"] = Ticks; + break; + case 5: + Timepoints["horde-execution-upload-ref"] = Ticks; + break; + } + Index++; + } + } + + [[nodiscard]] GetUpstreamApplyResult ProcessTaskStatus(const UpstreamApplyType ApplyType, const IoHash& ResultHash) + { + try + { + CloudCacheSession Session(m_StorageClient); + + GetUpstreamApplyResult ApplyResult{}; + + IoHash StdOutHash; + IoHash StdErrHash; + IoHash OutputHash; + + std::map<IoHash, IoBuffer> BinaryData; + + { + CloudCacheResult ObjectRefResult = + Session.GetRef(Session.Client().DefaultBlobStoreNamespace(), "responses"sv, ResultHash, ZenContentType::kCbObject); + Log().debug("Get ref {} Bytes={} Duration={}s Result={}", + ResultHash, + ObjectRefResult.Bytes, + ObjectRefResult.ElapsedSeconds, + ObjectRefResult.Success); + ApplyResult.Bytes += ObjectRefResult.Bytes; + ApplyResult.ElapsedSeconds += ObjectRefResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-get-ref"] = DateTime::NowTicks(); + + if (!ObjectRefResult.Success) + { + ApplyResult.Error.Reason = "Failed to get result object data"; + return ApplyResult; + } + + CbObject ResultObject = LoadCompactBinaryObject(ObjectRefResult.Response); + ApplyResult.Error.ErrorCode = ResultObject["e"sv].AsInt32(); + StdOutHash = ResultObject["so"sv].AsBinaryAttachment(); + StdErrHash = ResultObject["se"sv].AsBinaryAttachment(); + OutputHash = ResultObject["o"sv].AsObjectAttachment(); + } + + { + std::set<IoHash> NeededData; + if (OutputHash != IoHash::Zero) + { + GetObjectReferencesResult ObjectReferenceResult = + Session.GetObjectReferences(Session.Client().DefaultBlobStoreNamespace(), OutputHash); + Log().debug("Get object references {} References={} Bytes={} Duration={}s Result={}", + ResultHash, + ObjectReferenceResult.References.size(), + ObjectReferenceResult.Bytes, + ObjectReferenceResult.ElapsedSeconds, + ObjectReferenceResult.Success); + ApplyResult.Bytes += ObjectReferenceResult.Bytes; + ApplyResult.ElapsedSeconds += ObjectReferenceResult.ElapsedSeconds; + ApplyResult.Timepoints["zen-storage-get-object-references"] = DateTime::NowTicks(); + + if (!ObjectReferenceResult.Success) + { + ApplyResult.Error.Reason = "Failed to get result object references"; + return ApplyResult; + } + + NeededData = std::move(ObjectReferenceResult.References); + } + + NeededData.insert(OutputHash); + NeededData.insert(StdOutHash); + NeededData.insert(StdErrHash); + + for (const auto& Hash : NeededData) + { + if (Hash == IoHash::Zero) + { + continue; + } + CloudCacheResult BlobResult = Session.GetBlob(Session.Client().DefaultBlobStoreNamespace(), Hash); + Log().debug("Get blob {} Bytes={} Duration={}s Result={}", + Hash, + BlobResult.Bytes, + BlobResult.ElapsedSeconds, + BlobResult.Success); + ApplyResult.Bytes += BlobResult.Bytes; + ApplyResult.ElapsedSeconds += BlobResult.ElapsedSeconds; + if (!BlobResult.Success) + { + ApplyResult.Error.Reason = "Failed to get blob"; + return ApplyResult; + } + BinaryData[Hash] = std::move(BlobResult.Response); + } + ApplyResult.Timepoints["zen-storage-get-blobs"] = DateTime::NowTicks(); + } + + ApplyResult.StdOut = StdOutHash != IoHash::Zero + ? std::string((const char*)BinaryData[StdOutHash].GetData(), BinaryData[StdOutHash].GetSize()) + : ""; + ApplyResult.StdErr = StdErrHash != IoHash::Zero + ? std::string((const char*)BinaryData[StdErrHash].GetData(), BinaryData[StdErrHash].GetSize()) + : ""; + + if (OutputHash == IoHash::Zero) + { + ApplyResult.Error.Reason = "Task completed with no output object"; + return ApplyResult; + } + + CbObject OutputObject = LoadCompactBinaryObject(BinaryData[OutputHash]); + + switch (ApplyType) + { + case UpstreamApplyType::Simple: + { + ResolveMerkleTreeDirectory(""sv, OutputHash, BinaryData, ApplyResult.OutputFiles); + for (const auto& Pair : BinaryData) + { + ApplyResult.FileData[Pair.first] = std::move(BinaryData.at(Pair.first)); + } + + ApplyResult.Success = ApplyResult.Error.ErrorCode == 0; + return ApplyResult; + } + break; + case UpstreamApplyType::Asset: + { + if (ApplyResult.Error.ErrorCode != 0) + { + ApplyResult.Error.Reason = "Task completed with errors"; + return ApplyResult; + } + + // Get build.output + IoHash BuildOutputId; + IoBuffer BuildOutput; + for (auto& It : OutputObject["f"sv]) + { + const CbObjectView FileObject = It.AsObjectView(); + if (FileObject["n"sv].AsString() == "Build.output"sv) + { + BuildOutputId = FileObject["h"sv].AsBinaryAttachment(); + BuildOutput = BinaryData[BuildOutputId]; + break; + } + } + + if (BuildOutput.GetSize() == 0) + { + ApplyResult.Error.Reason = "Build.output file not found in task results"; + return ApplyResult; + } + + // Get Output directory node + IoBuffer OutputDirectoryTree; + for (auto& It : OutputObject["d"sv]) + { + const CbObjectView DirectoryObject = It.AsObjectView(); + if (DirectoryObject["n"sv].AsString() == "Outputs"sv) + { + OutputDirectoryTree = BinaryData[DirectoryObject["h"sv].AsObjectAttachment()]; + break; + } + } + + if (OutputDirectoryTree.GetSize() == 0) + { + ApplyResult.Error.Reason = "Outputs directory not found in task results"; + return ApplyResult; + } + + // load build.output as CbObject + + // Move Outputs from Horde to CbPackage + + std::unordered_map<IoHash, IoHash> CidToCompressedId; + CbPackage OutputPackage; + CbObject OutputDirectoryTreeObject = LoadCompactBinaryObject(OutputDirectoryTree); + + for (auto& It : OutputDirectoryTreeObject["f"sv]) + { + CbObjectView FileObject = It.AsObjectView(); + // Name is the uncompressed hash + IoHash DecompressedId = IoHash::FromHexString(FileObject["n"sv].AsString()); + // Hash is the compressed data hash, and how it is stored in Horde + IoHash CompressedId = FileObject["h"sv].AsBinaryAttachment(); + + if (!BinaryData.contains(CompressedId)) + { + Log().warn("Object attachment chunk not retrieved from Horde {}", CompressedId); + ApplyResult.Error.Reason = "Object attachment chunk not retrieved from Horde"; + return ApplyResult; + } + CidToCompressedId[DecompressedId] = CompressedId; + } + + // Iterate attachments, verify all chunks exist, and add to CbPackage + bool AnyErrors = false; + CbObject BuildOutputObject = LoadCompactBinaryObject(BuildOutput); + BuildOutputObject.IterateAttachments([&](CbFieldView Field) { + const IoHash DecompressedId = Field.AsHash(); + if (!CidToCompressedId.contains(DecompressedId)) + { + Log().warn("Attachment not found {}", DecompressedId); + AnyErrors = true; + return; + } + const IoHash& CompressedId = CidToCompressedId.at(DecompressedId); + + if (!BinaryData.contains(CompressedId)) + { + Log().warn("Missing output {} compressed {} uncompressed", CompressedId, DecompressedId); + AnyErrors = true; + return; + } + + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer AttachmentBuffer = + CompressedBuffer::FromCompressed(SharedBuffer(BinaryData[CompressedId]), RawHash, RawSize); + + if (!AttachmentBuffer || RawHash != DecompressedId) + { + Log().warn( + "Invalid output encountered (not valid CompressedBuffer format) {} compressed {} uncompressed", + CompressedId, + DecompressedId); + AnyErrors = true; + return; + } + + ApplyResult.TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize(); + ApplyResult.TotalRawAttachmentBytes += RawSize; + + CbAttachment Attachment(AttachmentBuffer, DecompressedId); + OutputPackage.AddAttachment(Attachment); + }); + + if (AnyErrors) + { + ApplyResult.Error.Reason = "Failed to get result object attachment data"; + return ApplyResult; + } + + OutputPackage.SetObject(BuildOutputObject); + ApplyResult.OutputPackage = std::move(OutputPackage); + + ApplyResult.Success = ApplyResult.Error.ErrorCode == 0; + return ApplyResult; + } + break; + } + + ApplyResult.Error.Reason = "Unknown apply type"; + return ApplyResult; + } + catch (std::exception& Err) + { + return {.Error{.ErrorCode = -1, .Reason = Err.what()}}; + } + } + + [[nodiscard]] bool ProcessApplyKey(const UpstreamApplyRecord& ApplyRecord, UpstreamData& Data) + { + std::string ExecutablePath; + std::string WorkingDirectory; + std::vector<std::string> Arguments; + std::map<std::string, std::string> Environment; + std::set<std::filesystem::path> InputFiles; + std::set<std::string> Outputs; + std::map<std::filesystem::path, IoHash> InputFileHashes; + + ExecutablePath = ApplyRecord.WorkerDescriptor["path"sv].AsString(); + if (ExecutablePath.empty()) + { + Log().warn("process apply upstream FAILED, '{}', path missing from worker descriptor", + ApplyRecord.WorkerDescriptor.GetHash()); + return false; + } + + WorkingDirectory = ApplyRecord.WorkerDescriptor["workdir"sv].AsString(); + + for (auto& It : ApplyRecord.WorkerDescriptor["executables"sv]) + { + CbObjectView FileEntry = It.AsObjectView(); + if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds)) + { + return false; + } + } + + for (auto& It : ApplyRecord.WorkerDescriptor["files"sv]) + { + CbObjectView FileEntry = It.AsObjectView(); + if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds)) + { + return false; + } + } + + for (auto& It : ApplyRecord.WorkerDescriptor["dirs"sv]) + { + std::string_view Directory = It.AsString(); + std::string DummyFile = fmt::format("{}/.zen_empty_file", Directory); + InputFiles.insert(DummyFile); + Data.Blobs[EmptyBufferId] = EmptyBuffer; + InputFileHashes[DummyFile] = EmptyBufferId; + } + + if (!WorkingDirectory.empty()) + { + std::string DummyFile = fmt::format("{}/.zen_empty_file", WorkingDirectory); + InputFiles.insert(DummyFile); + Data.Blobs[EmptyBufferId] = EmptyBuffer; + InputFileHashes[DummyFile] = EmptyBufferId; + } + + for (auto& It : ApplyRecord.WorkerDescriptor["environment"sv]) + { + std::string_view Env = It.AsString(); + auto Index = Env.find('='); + if (Index == std::string_view::npos) + { + Log().warn("process apply upstream FAILED, environment '{}' malformed", Env); + return false; + } + + Environment[std::string(Env.substr(0, Index))] = Env.substr(Index + 1); + } + + switch (ApplyRecord.Type) + { + case UpstreamApplyType::Simple: + { + for (auto& It : ApplyRecord.WorkerDescriptor["arguments"sv]) + { + Arguments.push_back(std::string(It.AsString())); + } + + for (auto& It : ApplyRecord.WorkerDescriptor["outputs"sv]) + { + Outputs.insert(std::string(It.AsString())); + } + } + break; + case UpstreamApplyType::Asset: + { + static const std::filesystem::path BuildActionPath = "Build.action"sv; + static const std::filesystem::path InputPath = "Inputs"sv; + const IoHash ActionId = ApplyRecord.Action.GetHash(); + + Arguments.push_back("-Build=build.action"); + Outputs.insert("Build.output"); + Outputs.insert("Outputs"); + + InputFiles.insert(BuildActionPath); + InputFileHashes[BuildActionPath] = ActionId; + Data.Blobs[ActionId] = IoBufferBuilder::MakeCloneFromMemory(ApplyRecord.Action.GetBuffer().GetData(), + ApplyRecord.Action.GetBuffer().GetSize()); + + bool AnyErrors = false; + ApplyRecord.Action.IterateAttachments([&](CbFieldView Field) { + const IoHash Cid = Field.AsHash(); + const std::filesystem::path FilePath = {InputPath / Cid.ToHexString()}; + + if (!m_CidStore.ContainsChunk(Cid)) + { + Log().warn("process apply upstream FAILED, input CID chunk '{}' missing", Cid); + AnyErrors = true; + return; + } + + if (InputFiles.contains(FilePath)) + { + return; + } + + InputFiles.insert(FilePath); + InputFileHashes[FilePath] = Cid; + Data.Cids.insert(Cid); + }); + + if (AnyErrors) + { + return false; + } + } + break; + } + + const UpstreamDirectory RootDirectory = BuildDirectoryTree(InputFiles); + + CbObject Sandbox = BuildMerkleTreeDirectory(RootDirectory, InputFileHashes, Data.Cids, Data.Objects); + const IoHash SandboxHash = Sandbox.GetHash(); + Data.Objects[SandboxHash] = std::move(Sandbox); + + { + std::string_view HostPlatform = ApplyRecord.WorkerDescriptor["host"sv].AsString(); + if (HostPlatform.empty()) + { + Log().warn("process apply upstream FAILED, 'host' platform not provided"); + return false; + } + + int32_t LogicalCores = ApplyRecord.WorkerDescriptor["cores"sv].AsInt32(); + int64_t Memory = ApplyRecord.WorkerDescriptor["memory"sv].AsInt64(); + bool Exclusive = ApplyRecord.WorkerDescriptor["exclusive"sv].AsBool(); + + std::string Condition = fmt::format("Platform == '{}'", HostPlatform); + if (HostPlatform == "Win64") + { + // TODO + // Condition += " && Pool == 'Win-RemoteExec'"; + } + + std::map<std::string_view, int64_t> Resources; + if (LogicalCores > 0) + { + Resources["LogicalCores"sv] = LogicalCores; + } + if (Memory > 0) + { + Resources["RAM"sv] = std::max(Memory / 1024LL / 1024LL / 1024LL, 1LL); + } + + CbObject Requirements = BuildRequirements(Condition, Resources, Exclusive); + const IoHash RequirementsId = Requirements.GetHash(); + Data.Objects[RequirementsId] = std::move(Requirements); + Data.RequirementsId = RequirementsId; + } + + CbObject Task = BuildTask(ExecutablePath, Arguments, Environment, WorkingDirectory, SandboxHash, Data.RequirementsId, Outputs); + + const IoHash TaskId = Task.GetHash(); + Data.Objects[TaskId] = std::move(Task); + Data.TaskId = TaskId; + + return true; + } + + [[nodiscard]] bool ProcessFileEntry(const CbObjectView& FileEntry, + std::set<std::filesystem::path>& InputFiles, + std::map<std::filesystem::path, IoHash>& InputFileHashes, + std::set<IoHash>& CasIds) + { + const std::filesystem::path FilePath = FileEntry["name"sv].AsString(); + const IoHash ChunkId = FileEntry["hash"sv].AsHash(); + const uint64_t Size = FileEntry["size"sv].AsUInt64(); + + if (!m_CidStore.ContainsChunk(ChunkId)) + { + Log().warn("process apply upstream FAILED, worker CAS chunk '{}' missing", ChunkId); + return false; + } + + if (InputFiles.contains(FilePath)) + { + Log().warn("process apply upstream FAILED, worker CAS chunk '{}' size: {} duplicate filename {}", ChunkId, Size, FilePath); + return false; + } + + InputFiles.insert(FilePath); + InputFileHashes[FilePath] = ChunkId; + CasIds.insert(ChunkId); + return true; + } + + [[nodiscard]] UpstreamDirectory BuildDirectoryTree(const std::set<std::filesystem::path>& InputFiles) + { + static const std::filesystem::path RootPath; + std::map<std::filesystem::path, UpstreamDirectory*> AllDirectories; + UpstreamDirectory RootDirectory = {.Path = RootPath}; + + AllDirectories[RootPath] = &RootDirectory; + + // Build tree from flat list + for (const auto& Path : InputFiles) + { + if (Path.has_parent_path()) + { + if (!AllDirectories.contains(Path.parent_path())) + { + std::stack<std::string> PathSplit; + { + std::filesystem::path ParentPath = Path.parent_path(); + PathSplit.push(ParentPath.filename().string()); + while (ParentPath.has_parent_path()) + { + ParentPath = ParentPath.parent_path(); + PathSplit.push(ParentPath.filename().string()); + } + } + UpstreamDirectory* ParentPtr = &RootDirectory; + while (!PathSplit.empty()) + { + if (!ParentPtr->Directories.contains(PathSplit.top())) + { + std::filesystem::path NewParentPath = {ParentPtr->Path / PathSplit.top()}; + ParentPtr->Directories[PathSplit.top()] = {.Path = NewParentPath}; + AllDirectories[NewParentPath] = &ParentPtr->Directories[PathSplit.top()]; + } + ParentPtr = &ParentPtr->Directories[PathSplit.top()]; + PathSplit.pop(); + } + } + + AllDirectories[Path.parent_path()]->Files.insert(Path.filename().string()); + } + else + { + RootDirectory.Files.insert(Path.filename().string()); + } + } + + return RootDirectory; + } + + [[nodiscard]] CbObject BuildMerkleTreeDirectory(const UpstreamDirectory& RootDirectory, + const std::map<std::filesystem::path, IoHash>& InputFileHashes, + const std::set<IoHash>& Cids, + std::map<IoHash, CbObject>& Objects) + { + CbObjectWriter DirectoryTreeWriter; + + if (!RootDirectory.Files.empty()) + { + DirectoryTreeWriter.BeginArray("f"sv); + for (const auto& File : RootDirectory.Files) + { + const std::filesystem::path FilePath = {RootDirectory.Path / File}; + const IoHash& FileHash = InputFileHashes.at(FilePath); + const bool Compressed = Cids.contains(FileHash); + DirectoryTreeWriter.BeginObject(); + DirectoryTreeWriter.AddString("n"sv, File); + DirectoryTreeWriter.AddBinaryAttachment("h"sv, FileHash); + DirectoryTreeWriter.AddBool("c"sv, Compressed); + DirectoryTreeWriter.EndObject(); + } + DirectoryTreeWriter.EndArray(); + } + + if (!RootDirectory.Directories.empty()) + { + DirectoryTreeWriter.BeginArray("d"sv); + for (const auto& Item : RootDirectory.Directories) + { + CbObject Directory = BuildMerkleTreeDirectory(Item.second, InputFileHashes, Cids, Objects); + const IoHash DirectoryHash = Directory.GetHash(); + Objects[DirectoryHash] = std::move(Directory); + + DirectoryTreeWriter.BeginObject(); + DirectoryTreeWriter.AddString("n"sv, Item.first); + DirectoryTreeWriter.AddObjectAttachment("h"sv, DirectoryHash); + DirectoryTreeWriter.EndObject(); + } + DirectoryTreeWriter.EndArray(); + } + + return DirectoryTreeWriter.Save(); + } + + void ResolveMerkleTreeDirectory(const std::filesystem::path& ParentDirectory, + const IoHash& DirectoryHash, + const std::map<IoHash, IoBuffer>& Objects, + std::map<std::filesystem::path, IoHash>& OutputFiles) + { + CbObject Directory = LoadCompactBinaryObject(Objects.at(DirectoryHash)); + + for (auto& It : Directory["f"sv]) + { + const CbObjectView FileObject = It.AsObjectView(); + const std::filesystem::path Path = ParentDirectory / FileObject["n"sv].AsString(); + + OutputFiles[Path] = FileObject["h"sv].AsBinaryAttachment(); + } + + for (auto& It : Directory["d"sv]) + { + const CbObjectView DirectoryObject = It.AsObjectView(); + + ResolveMerkleTreeDirectory(ParentDirectory / DirectoryObject["n"sv].AsString(), + DirectoryObject["h"sv].AsObjectAttachment(), + Objects, + OutputFiles); + } + } + + [[nodiscard]] CbObject BuildRequirements(const std::string_view Condition, + const std::map<std::string_view, int64_t>& Resources, + const bool Exclusive) + { + CbObjectWriter Writer; + Writer.AddString("c", Condition); + if (!Resources.empty()) + { + Writer.BeginArray("r"); + for (const auto& Resource : Resources) + { + Writer.BeginArray(); + Writer.AddString(Resource.first); + Writer.AddInteger(Resource.second); + Writer.EndArray(); + } + Writer.EndArray(); + } + Writer.AddBool("e", Exclusive); + return Writer.Save(); + } + + [[nodiscard]] CbObject BuildTask(const std::string_view Executable, + const std::vector<std::string>& Arguments, + const std::map<std::string, std::string>& Environment, + const std::string_view WorkingDirectory, + const IoHash& SandboxHash, + const IoHash& RequirementsId, + const std::set<std::string>& Outputs) + { + CbObjectWriter TaskWriter; + TaskWriter.AddString("e"sv, Executable); + + if (!Arguments.empty()) + { + TaskWriter.BeginArray("a"sv); + for (const auto& Argument : Arguments) + { + TaskWriter.AddString(Argument); + } + TaskWriter.EndArray(); + } + + if (!Environment.empty()) + { + TaskWriter.BeginArray("v"sv); + for (const auto& Env : Environment) + { + TaskWriter.BeginArray(); + TaskWriter.AddString(Env.first); + TaskWriter.AddString(Env.second); + TaskWriter.EndArray(); + } + TaskWriter.EndArray(); + } + + if (!WorkingDirectory.empty()) + { + TaskWriter.AddString("w"sv, WorkingDirectory); + } + + TaskWriter.AddObjectAttachment("s"sv, SandboxHash); + TaskWriter.AddObjectAttachment("r"sv, RequirementsId); + + // Outputs + if (!Outputs.empty()) + { + TaskWriter.BeginArray("o"sv); + for (const auto& Output : Outputs) + { + TaskWriter.AddString(Output); + } + TaskWriter.EndArray(); + } + + return TaskWriter.Save(); + } + }; +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<UpstreamApplyEndpoint> +UpstreamApplyEndpoint::CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr) +{ + return std::make_unique<detail::HordeUpstreamApplyEndpoint>(ComputeOptions, + ComputeAuthConfig, + StorageOptions, + StorageAuthConfig, + CidStore, + Mgr); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/jupiter.cpp b/src/zenserver/upstream/jupiter.cpp new file mode 100644 index 000000000..dbb185bec --- /dev/null +++ b/src/zenserver/upstream/jupiter.cpp @@ -0,0 +1,965 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "jupiter.h" + +#include "diag/formatters.h" +#include "diag/logging.h" + +#include <zencore/compactbinary.h> +#include <zencore/compositebuffer.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/trace.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +#include <fmt/format.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "Crypt32.lib") +# pragma comment(lib, "Wldap32.lib") +#endif + +#include <json11.hpp> + +using namespace std::literals; + +namespace zen { + +namespace detail { + struct CloudCacheSessionState + { + CloudCacheSessionState(CloudCacheClient& Client) : m_Client(Client) {} + + const CloudCacheAccessToken& GetAccessToken(bool RefreshToken) + { + if (RefreshToken) + { + m_AccessToken = m_Client.AcquireAccessToken(); + } + + return m_AccessToken; + } + + cpr::Session& GetSession() { return m_Session; } + + void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout) + { + m_Session.SetBody({}); + m_Session.SetHeader({}); + m_Session.SetConnectTimeout(ConnectTimeout); + m_Session.SetTimeout(Timeout); + } + + private: + friend class zen::CloudCacheClient; + + CloudCacheClient& m_Client; + CloudCacheAccessToken m_AccessToken; + cpr::Session m_Session; + }; + +} // namespace detail + +CloudCacheSession::CloudCacheSession(CloudCacheClient* CacheClient) : m_Log(CacheClient->Logger()), m_CacheClient(CacheClient) +{ + m_SessionState = m_CacheClient->AllocSessionState(); +} + +CloudCacheSession::~CloudCacheSession() +{ + m_CacheClient->FreeSessionState(m_SessionState); +} + +CloudCacheResult +CloudCacheSession::Authenticate() +{ + const bool RefreshToken = true; + const CloudCacheAccessToken& AccessToken = GetAccessToken(RefreshToken); + + return {.Success = AccessToken.IsValid()}; +} + +CloudCacheResult +CloudCacheSession::GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType) +{ + const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", ContentType}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetBlob(std::string_view Namespace, const IoHash& Key) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/octet-stream"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = + Success && Response.text.size() > 0 ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetCompressedBlob(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-comp"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash) +{ + ZEN_TRACE_CPU("HordeClient::GetInlineBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-jupiter-inline"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + if (auto It = Response.header.find("X-Jupiter-InlinePayloadHash"); It != Response.header.end()) + { + const std::string& PayloadHashHeader = It->second; + if (PayloadHashHeader.length() == IoHash::StringLength) + { + OutPayloadHash = IoHash::FromHexString(PayloadHashHeader); + } + } + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +CloudCacheResult +CloudCacheSession::GetObject(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetObject"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +PutRefResult +CloudCacheSession::PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType) +{ + ZEN_TRACE_CPU("HordeClient::PutRef"); + + IoHash Hash = IoHash::HashBuffer(Ref.Data(), Ref.Size()); + + const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption( + cpr::Header{{"Authorization", AccessToken.Value}, {"X-Jupiter-IoHash", Hash.ToHexString()}, {"Content-Type", ContentType}}); + Session.SetBody(cpr::Body{(const char*)Ref.Data(), Ref.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + PutRefResult Result; + Result.ErrorCode = static_cast<int32_t>(Response.error.code); + Result.Reason = std::move(Response.error.message); + return Result; + } + else if (!VerifyAccessToken(Response.status_code)) + { + PutRefResult Result; + Result.ErrorCode = 401; + Result.Reason = "Invalid access token"sv; + return Result; + } + + PutRefResult Result; + Result.Success = (Response.status_code == 200 || Response.status_code == 201); + Result.Bytes = Response.uploaded_bytes; + Result.ElapsedSeconds = Response.elapsed; + + if (Result.Success) + { + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + if (JsonError.empty()) + { + json11::Json::array Needs = Json["needs"].array_items(); + for (const auto& Need : Needs) + { + Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value())); + } + } + } + + return Result; +} + +FinalizeRefResult +CloudCacheSession::FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHash) +{ + ZEN_TRACE_CPU("HordeClient::FinalizeRef"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString() << "/finalize/" + << RefHash.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, + {"X-Jupiter-IoHash", RefHash.ToHexString()}, + {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + FinalizeRefResult Result; + Result.ErrorCode = static_cast<int32_t>(Response.error.code); + Result.Reason = std::move(Response.error.message); + return Result; + } + else if (!VerifyAccessToken(Response.status_code)) + { + FinalizeRefResult Result; + Result.ErrorCode = 401; + Result.Reason = "Invalid access token"sv; + return Result; + } + + FinalizeRefResult Result; + Result.Success = (Response.status_code == 200 || Response.status_code == 201); + Result.Bytes = Response.uploaded_bytes; + Result.ElapsedSeconds = Response.elapsed; + + if (Result.Success) + { + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + if (JsonError.empty()) + { + json11::Json::array Needs = Json["needs"].array_items(); + for (const auto& Need : Needs) + { + Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value())); + } + } + } + + return Result; +} + +CloudCacheResult +CloudCacheSession::PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob) +{ + ZEN_TRACE_CPU("HordeClient::PutBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/octet-stream"}}); + Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob) +{ + ZEN_TRACE_CPU("HordeClient::PutCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}}); + Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Payload) +{ + ZEN_TRACE_CPU("HordeClient::PutCompressedBlob"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}}); + uint64_t SizeLeft = Payload.GetSize(); + CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0); + auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) { + size = Min<size_t>(size, SizeLeft); + MutableMemoryView Data(buffer, size); + Payload.CopyTo(Data, BufferIt); + SizeLeft -= size; + return true; + }; + Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback)); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object) +{ + ZEN_TRACE_CPU("HordeClient::PutObject"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{(const char*)Object.Data(), Object.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Success = (Response.status_code == 200 || Response.status_code == 201)}; +} + +CloudCacheResult +CloudCacheSession::RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::RefExists"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Head(); + ZEN_DEBUG("HEAD {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +GetObjectReferencesResult +CloudCacheSession::GetObjectReferences(std::string_view Namespace, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::GetObjectReferences"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString() << "/references"; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}}; + } + + GetObjectReferencesResult Result{ + CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}}; + + if (Result.Success) + { + IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + const CbObject ReferencesResponse = LoadCompactBinaryObject(Buffer); + for (auto& Item : ReferencesResponse["references"sv]) + { + Result.References.insert(Item.AsHash()); + } + } + + return Result; +} + +CloudCacheResult +CloudCacheSession::BlobExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "blobs"sv, Key); +} + +CloudCacheResult +CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "compressed-blobs"sv, Key); +} + +CloudCacheResult +CloudCacheSession::ObjectExists(std::string_view Namespace, const IoHash& Key) +{ + return CacheTypeExists(Namespace, "objects"sv, Key); +} + +CloudCacheExistsResult +CloudCacheSession::BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "blobs"sv, Keys); +} + +CloudCacheExistsResult +CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "compressed-blobs"sv, Keys); +} + +CloudCacheExistsResult +CloudCacheSession::ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys) +{ + return CacheTypeExists(Namespace, "objects"sv, Keys); +} + +CloudCacheResult +CloudCacheSession::PostComputeTasks(IoBuffer TasksData) +{ + ZEN_TRACE_CPU("HordeClient::PostComputeTasks"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}}); + Session.SetBody(cpr::Body{(const char*)TasksData.Data(), TasksData.Size()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +CloudCacheResult +CloudCacheSession::GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds) +{ + ZEN_TRACE_CPU("HordeClient::GetComputeUpdates"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster() << "/updates/" << ChannelId + << "?wait=" << WaitSeconds; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +std::vector<IoHash> +CloudCacheSession::Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl(); + Uri << "/api/v1/s/" << Namespace; + + ZEN_UNUSED(BucketId, ChunkHashes); + + return {}; +} + +cpr::Session& +CloudCacheSession::GetSession() +{ + return m_SessionState->GetSession(); +} + +CloudCacheAccessToken +CloudCacheSession::GetAccessToken(bool RefreshToken) +{ + return m_SessionState->GetAccessToken(RefreshToken); +} + +bool +CloudCacheSession::VerifyAccessToken(long StatusCode) +{ + return StatusCode != 401; +} + +CloudCacheResult +CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key) +{ + ZEN_TRACE_CPU("HordeClient::CacheTypeExists"); + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/" << Key.ToHexString(); + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}}); + Session.SetOption(cpr::Body{}); + + cpr::Response Response = Session.Head(); + ZEN_DEBUG("HEAD {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {.ErrorCode = 401, .Reason = std::string("Invalid access token")}; + } + + return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +CloudCacheExistsResult +CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys) +{ + ZEN_TRACE_CPU("HordeClient::CacheTypeExists"); + + ExtendableStringBuilder<256> Body; + Body << "["; + for (const auto& Key : Keys) + { + Body << (Body.Size() != 1 ? ",\"" : "\"") << Key.ToHexString() << "\""; + } + Body << "]"; + + ExtendableStringBuilder<256> Uri; + Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/exist"; + + cpr::Session& Session = GetSession(); + const CloudCacheAccessToken& AccessToken = GetAccessToken(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetOption( + cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}, {"Content-Type", "application/json"}}); + Session.SetOption(cpr::Body(Body.ToString())); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}}; + } + else if (!VerifyAccessToken(Response.status_code)) + { + return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}}; + } + + CloudCacheExistsResult Result{ + CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}}; + + if (Result.Success) + { + IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size()); + const CbObject ExistsResponse = LoadCompactBinaryObject(Buffer); + for (auto& Item : ExistsResponse["needs"sv]) + { + Result.Needs.insert(Item.AsHash()); + } + } + + return Result; +} + +/** + * An access token provider that holds a token that will never change. + */ +class StaticTokenProvider final : public CloudCacheTokenProvider +{ +public: + StaticTokenProvider(CloudCacheAccessToken Token) : m_Token(std::move(Token)) {} + + virtual ~StaticTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Token; } + +private: + CloudCacheAccessToken m_Token; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromStaticToken(CloudCacheAccessToken Token) +{ + return std::make_unique<StaticTokenProvider>(std::move(Token)); +} + +class OAuthClientCredentialsTokenProvider final : public CloudCacheTokenProvider +{ +public: + OAuthClientCredentialsTokenProvider(const CloudCacheTokenProvider::OAuthClientCredentialsParams& Params) + { + m_Url = std::string(Params.Url); + m_ClientId = std::string(Params.ClientId); + m_ClientSecret = std::string(Params.ClientSecret); + } + + virtual ~OAuthClientCredentialsTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override + { + using namespace std::chrono; + + std::string Body = + fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}", m_ClientId, m_ClientSecret); + + cpr::Response Response = + cpr::Post(cpr::Url{m_Url}, cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}, cpr::Body{std::move(Body)}); + + if (Response.error || Response.status_code != 200) + { + return {}; + } + + std::string JsonError; + json11::Json Json = json11::Json::parse(Response.text, JsonError); + + if (JsonError.empty() == false) + { + return {}; + } + + std::string Token = Json["access_token"].string_value(); + int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()); + CloudCacheAccessToken::TimePoint ExpireTime = CloudCacheAccessToken::Clock::now() + seconds(ExpiresInSeconds); + + return {.Value = fmt::format("Bearer {}", Token), .ExpireTime = ExpireTime}; + } + +private: + std::string m_Url; + std::string m_ClientId; + std::string m_ClientSecret; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params) +{ + return std::make_unique<OAuthClientCredentialsTokenProvider>(Params); +} + +class CallbackTokenProvider final : public CloudCacheTokenProvider +{ +public: + CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {} + + virtual ~CallbackTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); } + +private: + std::function<CloudCacheAccessToken()> m_Callback; +}; + +std::unique_ptr<CloudCacheTokenProvider> +CloudCacheTokenProvider::CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback) +{ + return std::make_unique<CallbackTokenProvider>(std::move(Callback)); +} + +CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider) +: m_Log(zen::logging::Get("jupiter")) +, m_ServiceUrl(Options.ServiceUrl) +, m_DefaultDdcNamespace(Options.DdcNamespace) +, m_DefaultBlobStoreNamespace(Options.BlobStoreNamespace) +, m_ComputeCluster(Options.ComputeCluster) +, m_ConnectTimeout(Options.ConnectTimeout) +, m_Timeout(Options.Timeout) +, m_TokenProvider(std::move(TokenProvider)) +{ + ZEN_ASSERT(m_TokenProvider.get() != nullptr); +} + +CloudCacheClient::~CloudCacheClient() +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + + for (auto State : m_SessionStateCache) + { + delete State; + } +} + +CloudCacheAccessToken +CloudCacheClient::AcquireAccessToken() +{ + ZEN_TRACE_CPU("HordeClient::AcquireAccessToken"); + + return m_TokenProvider->AcquireAccessToken(); +} + +detail::CloudCacheSessionState* +CloudCacheClient::AllocSessionState() +{ + detail::CloudCacheSessionState* State = nullptr; + + bool IsTokenValid = false; + + { + RwLock::ExclusiveLockScope _(m_SessionStateLock); + + if (m_SessionStateCache.empty() == false) + { + State = m_SessionStateCache.front(); + IsTokenValid = State->m_AccessToken.IsValid(); + + m_SessionStateCache.pop_front(); + } + } + + if (State == nullptr) + { + State = new detail::CloudCacheSessionState(*this); + } + + State->Reset(m_ConnectTimeout, m_Timeout); + + if (IsTokenValid == false) + { + State->m_AccessToken = m_TokenProvider->AcquireAccessToken(); + } + + return State; +} + +void +CloudCacheClient::FreeSessionState(detail::CloudCacheSessionState* State) +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + m_SessionStateCache.push_front(State); +} + +} // namespace zen diff --git a/src/zenserver/upstream/jupiter.h b/src/zenserver/upstream/jupiter.h new file mode 100644 index 000000000..99e5c530f --- /dev/null +++ b/src/zenserver/upstream/jupiter.h @@ -0,0 +1,217 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/thread.h> +#include <zenhttp/httpserver.h> + +#include <atomic> +#include <chrono> +#include <list> +#include <memory> +#include <set> +#include <vector> + +struct ZenCacheValue; + +namespace cpr { +class Session; +} + +namespace zen { +namespace detail { + struct CloudCacheSessionState; +} + +class CbObjectView; +class CloudCacheClient; +class IoBuffer; +struct IoHash; + +/** + * Cached access token, for use with `Authorization:` header + */ +struct CloudCacheAccessToken +{ + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + + static constexpr int64_t ExpireMarginInSeconds = 30; + + std::string Value; + TimePoint ExpireTime; + + bool IsValid() const + { + return Value.empty() == false && + ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count(); + } +}; + +struct CloudCacheResult +{ + IoBuffer Response; + int64_t Bytes{}; + double ElapsedSeconds{}; + int32_t ErrorCode{}; + std::string Reason; + bool Success = false; +}; + +struct PutRefResult : CloudCacheResult +{ + std::vector<IoHash> Needs; +}; + +struct FinalizeRefResult : CloudCacheResult +{ + std::vector<IoHash> Needs; +}; + +struct CloudCacheExistsResult : CloudCacheResult +{ + std::set<IoHash> Needs; +}; + +struct GetObjectReferencesResult : CloudCacheResult +{ + std::set<IoHash> References; +}; + +/** + * Context for performing Jupiter operations + * + * Maintains an HTTP connection so that subsequent operations don't need to go + * through the whole connection setup process + * + */ +class CloudCacheSession +{ +public: + CloudCacheSession(CloudCacheClient* CacheClient); + ~CloudCacheSession(); + + CloudCacheResult Authenticate(); + CloudCacheResult GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType); + CloudCacheResult GetBlob(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetCompressedBlob(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetObject(std::string_view Namespace, const IoHash& Key); + CloudCacheResult GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash); + + PutRefResult PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType); + CloudCacheResult PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob); + CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob); + CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Blob); + CloudCacheResult PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object); + + FinalizeRefResult FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHah); + + CloudCacheResult RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key); + + GetObjectReferencesResult GetObjectReferences(std::string_view Namespace, const IoHash& Key); + + CloudCacheResult BlobExists(std::string_view Namespace, const IoHash& Key); + CloudCacheResult CompressedBlobExists(std::string_view Namespace, const IoHash& Key); + CloudCacheResult ObjectExists(std::string_view Namespace, const IoHash& Key); + + CloudCacheExistsResult BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys); + CloudCacheExistsResult CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys); + CloudCacheExistsResult ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys); + + CloudCacheResult PostComputeTasks(IoBuffer TasksData); + CloudCacheResult GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds = 0); + + std::vector<IoHash> Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes); + + CloudCacheClient& Client() { return *m_CacheClient; }; + +private: + inline spdlog::logger& Log() { return m_Log; } + cpr::Session& GetSession(); + CloudCacheAccessToken GetAccessToken(bool RefreshToken = false); + bool VerifyAccessToken(long StatusCode); + + CloudCacheResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key); + + CloudCacheExistsResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys); + + spdlog::logger& m_Log; + RefPtr<CloudCacheClient> m_CacheClient; + detail::CloudCacheSessionState* m_SessionState; +}; + +/** + * Access token provider interface + */ +class CloudCacheTokenProvider +{ +public: + virtual ~CloudCacheTokenProvider() = default; + + virtual CloudCacheAccessToken AcquireAccessToken() = 0; + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromStaticToken(CloudCacheAccessToken Token); + + struct OAuthClientCredentialsParams + { + std::string_view Url; + std::string_view ClientId; + std::string_view ClientSecret; + }; + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params); + + static std::unique_ptr<CloudCacheTokenProvider> CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback); +}; + +struct CloudCacheClientOptions +{ + std::string_view Name; + std::string_view ServiceUrl; + std::string_view DdcNamespace; + std::string_view BlobStoreNamespace; + std::string_view ComputeCluster; + std::chrono::milliseconds ConnectTimeout{5000}; + std::chrono::milliseconds Timeout{}; +}; + +/** + * Jupiter upstream cache client + */ +class CloudCacheClient : public RefCounted +{ +public: + CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider); + ~CloudCacheClient(); + + CloudCacheAccessToken AcquireAccessToken(); + std::string_view DefaultDdcNamespace() const { return m_DefaultDdcNamespace; } + std::string_view DefaultBlobStoreNamespace() const { return m_DefaultBlobStoreNamespace; } + std::string_view ComputeCluster() const { return m_ComputeCluster; } + std::string_view ServiceUrl() const { return m_ServiceUrl; } + + spdlog::logger& Logger() { return m_Log; } + +private: + spdlog::logger& m_Log; + std::string m_ServiceUrl; + std::string m_DefaultDdcNamespace; + std::string m_DefaultBlobStoreNamespace; + std::string m_ComputeCluster; + std::chrono::milliseconds m_ConnectTimeout{}; + std::chrono::milliseconds m_Timeout{}; + std::unique_ptr<CloudCacheTokenProvider> m_TokenProvider; + + RwLock m_SessionStateLock; + std::list<detail::CloudCacheSessionState*> m_SessionStateCache; + + detail::CloudCacheSessionState* AllocSessionState(); + void FreeSessionState(detail::CloudCacheSessionState*); + + friend class CloudCacheSession; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/upstream.h b/src/zenserver/upstream/upstream.h new file mode 100644 index 000000000..a57301206 --- /dev/null +++ b/src/zenserver/upstream/upstream.h @@ -0,0 +1,8 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <upstream/jupiter.h> +#include <upstream/upstreamcache.h> +#include <upstream/upstreamservice.h> +#include <upstream/zen.h> diff --git a/src/zenserver/upstream/upstreamapply.cpp b/src/zenserver/upstream/upstreamapply.cpp new file mode 100644 index 000000000..c719b225d --- /dev/null +++ b/src/zenserver/upstream/upstreamapply.cpp @@ -0,0 +1,459 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "upstreamapply.h" + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinary.h> +# include <zencore/compactbinarybuilder.h> +# include <zencore/fmtutils.h> +# include <zencore/stream.h> +# include <zencore/timer.h> +# include <zencore/workthreadpool.h> + +# include <zenstore/cidstore.h> + +# include "diag/logging.h" + +# include <fmt/format.h> + +# include <atomic> + +namespace zen { + +using namespace std::literals; + +struct UpstreamApplyStats +{ + static constexpr uint64_t MaxSampleCount = 1000ull; + + UpstreamApplyStats(bool Enabled) : m_Enabled(Enabled) {} + + void Add(UpstreamApplyEndpoint& Endpoint, const PostUpstreamApplyResult& Result) + { + UpstreamApplyEndpointStats& Stats = Endpoint.Stats(); + + if (Result.Error) + { + Stats.ErrorCount.Increment(1); + } + else if (Result.Success) + { + Stats.PostCount.Increment(1); + Stats.UpBytes.Increment(Result.Bytes / 1024 / 1024); + } + } + + void Add(UpstreamApplyEndpoint& Endpoint, const GetUpstreamApplyUpdatesResult& Result) + { + UpstreamApplyEndpointStats& Stats = Endpoint.Stats(); + + if (Result.Error) + { + Stats.ErrorCount.Increment(1); + } + else if (Result.Success) + { + Stats.UpdateCount.Increment(1); + Stats.DownBytes.Increment(Result.Bytes / 1024 / 1024); + if (!Result.Completed.empty()) + { + uint64_t Completed = 0; + for (auto& It : Result.Completed) + { + Completed += It.second.size(); + } + Stats.CompleteCount.Increment(Completed); + } + } + } + + bool m_Enabled; +}; + +////////////////////////////////////////////////////////////////////////// + +class UpstreamApplyImpl final : public UpstreamApply +{ +public: + UpstreamApplyImpl(const UpstreamApplyOptions& Options, CidStore& CidStore) + : m_Log(logging::Get("upstream-apply")) + , m_Options(Options) + , m_CidStore(CidStore) + , m_Stats(Options.StatsEnabled) + , m_UpstreamAsyncWorkPool(Options.UpstreamThreadCount) + , m_DownstreamAsyncWorkPool(Options.DownstreamThreadCount) + { + } + + virtual ~UpstreamApplyImpl() { Shutdown(); } + + virtual bool Initialize() override + { + for (auto& Endpoint : m_Endpoints) + { + const UpstreamEndpointHealth Health = Endpoint->Initialize(); + if (Health.Ok) + { + Log().info("initialize endpoint '{}' OK", Endpoint->DisplayName()); + } + else + { + Log().warn("initialize endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason); + } + } + + m_RunState.IsRunning = !m_Endpoints.empty(); + + if (m_RunState.IsRunning) + { + m_ShutdownEvent.Reset(); + + m_UpstreamUpdatesThread = std::thread(&UpstreamApplyImpl::ProcessUpstreamUpdates, this); + + m_EndpointMonitorThread = std::thread(&UpstreamApplyImpl::MonitorEndpoints, this); + } + + return m_RunState.IsRunning; + } + + virtual bool IsHealthy() const override + { + if (m_RunState.IsRunning) + { + for (const auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + return true; + } + } + } + + return false; + } + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) override + { + m_Endpoints.emplace_back(std::move(Endpoint)); + } + + virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) override + { + if (m_RunState.IsRunning) + { + const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); + const IoHash ActionId = ApplyRecord.Action.GetHash(); + const uint32_t TimeoutSeconds = ApplyRecord.WorkerDescriptor["timeout"sv].AsInt32(300); + + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + // Already in progress + return {.ApplyId = ActionId, .Success = true}; + } + + std::chrono::steady_clock::time_point ExpireTime = + TimeoutSeconds > 0 ? std::chrono::steady_clock::now() + std::chrono::seconds(TimeoutSeconds) + : std::chrono::steady_clock::time_point::max(); + + m_ApplyTasks[WorkerId][ActionId] = {.State = UpstreamApplyState::Queued, .Result{}, .ExpireTime = std::move(ExpireTime)}; + } + + ApplyRecord.Timepoints["zen-queue-added"] = DateTime::NowTicks(); + m_UpstreamAsyncWorkPool.ScheduleWork( + [this, ApplyRecord = std::move(ApplyRecord)]() { ProcessApplyRecord(std::move(ApplyRecord)); }); + + return {.ApplyId = ActionId, .Success = true}; + } + + return {}; + } + + virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) override + { + if (m_RunState.IsRunning) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + return {.Status = *Status, .Success = true}; + } + } + + return {}; + } + + virtual void GetStatus(CbObjectWriter& Status) override + { + Status << "upstream_worker_threads" << m_Options.UpstreamThreadCount; + Status << "upstream_queue_count" << m_UpstreamAsyncWorkPool.PendingWork(); + Status << "downstream_worker_threads" << m_Options.DownstreamThreadCount; + Status << "downstream_queue_count" << m_DownstreamAsyncWorkPool.PendingWork(); + + Status.BeginArray("endpoints"); + for (const auto& Ep : m_Endpoints) + { + Status.BeginObject(); + Status << "name" << Ep->DisplayName(); + Status << "health" << (Ep->IsHealthy() ? "ok"sv : "inactive"sv); + + UpstreamApplyEndpointStats& Stats = Ep->Stats(); + const uint64_t PostCount = Stats.PostCount.Value(); + const uint64_t CompleteCount = Stats.CompleteCount.Value(); + // const uint64_t UpdateCount = Stats.UpdateCount; + const double CompleteRate = CompleteCount > 0 ? (double(PostCount) / double(CompleteCount)) : 0.0; + + Status << "post_count" << PostCount; + Status << "complete_count" << PostCount; + Status << "update_count" << Stats.UpdateCount.Value(); + + Status << "complete_ratio" << CompleteRate; + Status << "downloaded_mb" << Stats.DownBytes.Value(); + Status << "uploaded_mb" << Stats.UpBytes.Value(); + Status << "error_count" << Stats.ErrorCount.Value(); + + Status.EndObject(); + } + Status.EndArray(); + } + +private: + // The caller is responsible for locking if required + UpstreamApplyStatus* FindStatus(const IoHash& WorkerId, const IoHash& ActionId) + { + if (auto It = m_ApplyTasks.find(WorkerId); It != m_ApplyTasks.end()) + { + if (auto It2 = It->second.find(ActionId); It2 != It->second.end()) + { + return &It2->second; + } + } + return nullptr; + } + + void ProcessApplyRecord(UpstreamApplyRecord ApplyRecord) + { + const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash(); + const IoHash ActionId = ApplyRecord.Action.GetHash(); + try + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + ApplyRecord.Timepoints["zen-queue-dispatched"] = DateTime::NowTicks(); + PostUpstreamApplyResult Result = Endpoint->PostApply(std::move(ApplyRecord)); + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->Timepoints.merge(Result.Timepoints); + + if (Result.Success) + { + Status->State = UpstreamApplyState::Executing; + } + else + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error = std::move(Result.Error), + .Bytes = Result.Bytes, + .ElapsedSeconds = Result.ElapsedSeconds}; + } + } + } + m_Stats.Add(*Endpoint, Result); + return; + } + } + + Log().warn("process upstream apply ({}/{}) FAILED 'No available endpoint'", WorkerId, ActionId); + + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error{.ErrorCode = -1, .Reason = "No available endpoint"}}; + } + } + } + catch (std::exception& e) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = {.Error{.ErrorCode = -1, .Reason = e.what()}}; + } + Log().warn("process upstream apply ({}/{}) FAILED '{}'", WorkerId, ActionId, e.what()); + } + } + + void ProcessApplyUpdates() + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->IsHealthy()) + { + GetUpstreamApplyUpdatesResult Result = Endpoint->GetUpdates(m_DownstreamAsyncWorkPool); + m_Stats.Add(*Endpoint, Result); + + if (!Result.Success) + { + Log().warn("process upstream apply updates FAILED '{}'", Result.Error.Reason); + } + + if (!Result.Completed.empty()) + { + for (auto& It : Result.Completed) + { + for (auto& It2 : It.second) + { + std::scoped_lock Lock(m_ApplyTasksMutex); + if (auto Status = FindStatus(It.first, It2.first); Status != nullptr) + { + Status->State = UpstreamApplyState::Complete; + Status->Result = std::move(It2.second); + Status->Result.Timepoints.merge(Status->Timepoints); + Status->Result.Timepoints["zen-queue-complete"] = DateTime::NowTicks(); + Status->Timepoints.clear(); + } + } + } + } + } + } + } + + void ProcessUpstreamUpdates() + { + const auto& UpdateSleep = std::chrono::milliseconds(m_Options.UpdatesInterval); + while (!m_ShutdownEvent.Wait(uint32_t(UpdateSleep.count()))) + { + if (!m_RunState.IsRunning) + { + break; + } + + ProcessApplyUpdates(); + + // Remove any expired tasks, regardless of state + { + std::scoped_lock Lock(m_ApplyTasksMutex); + for (auto& WorkerIt : m_ApplyTasks) + { + const auto Count = std::erase_if(WorkerIt.second, [](const auto& Item) { + return Item.second.ExpireTime < std::chrono::steady_clock::now(); + }); + if (Count > 0) + { + Log().debug("Removed '{}' expired tasks", Count); + } + } + const auto Count = std::erase_if(m_ApplyTasks, [](const auto& Item) { return Item.second.empty(); }); + if (Count > 0) + { + Log().debug("Removed '{}' empty task lists", Count); + } + } + } + } + + void MonitorEndpoints() + { + for (;;) + { + { + std::unique_lock Lock(m_RunState.Mutex); + if (m_RunState.ExitSignal.wait_for(Lock, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); })) + { + break; + } + } + + for (auto& Endpoint : m_Endpoints) + { + if (!Endpoint->IsHealthy()) + { + if (const UpstreamEndpointHealth Health = Endpoint->CheckHealth(); Health.Ok) + { + Log().warn("health check endpoint '{}' OK", Endpoint->DisplayName(), Health.Reason); + } + else + { + Log().warn("health check endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason); + } + } + } + } + } + + void Shutdown() + { + if (m_RunState.Stop()) + { + m_ShutdownEvent.Set(); + m_EndpointMonitorThread.join(); + m_UpstreamUpdatesThread.join(); + m_Endpoints.clear(); + } + } + + spdlog::logger& Log() { return m_Log; } + + struct RunState + { + std::mutex Mutex; + std::condition_variable ExitSignal; + std::atomic_bool IsRunning{false}; + + bool Stop() + { + bool Stopped = false; + { + std::scoped_lock Lock(Mutex); + Stopped = IsRunning.exchange(false); + } + if (Stopped) + { + ExitSignal.notify_all(); + } + return Stopped; + } + }; + + spdlog::logger& m_Log; + UpstreamApplyOptions m_Options; + CidStore& m_CidStore; + UpstreamApplyStats m_Stats; + UpstreamApplyTasks m_ApplyTasks; + std::mutex m_ApplyTasksMutex; + std::vector<std::unique_ptr<UpstreamApplyEndpoint>> m_Endpoints; + Event m_ShutdownEvent; + WorkerThreadPool m_UpstreamAsyncWorkPool; + WorkerThreadPool m_DownstreamAsyncWorkPool; + std::thread m_UpstreamUpdatesThread; + std::thread m_EndpointMonitorThread; + RunState m_RunState; +}; + +////////////////////////////////////////////////////////////////////////// + +bool +UpstreamApply::IsHealthy() const +{ + return false; +} + +std::unique_ptr<UpstreamApply> +UpstreamApply::Create(const UpstreamApplyOptions& Options, CidStore& CidStore) +{ + return std::make_unique<UpstreamApplyImpl>(Options, CidStore); +} + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/upstreamapply.h b/src/zenserver/upstream/upstreamapply.h new file mode 100644 index 000000000..4a095be6c --- /dev/null +++ b/src/zenserver/upstream/upstreamapply.h @@ -0,0 +1,192 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#if ZEN_WITH_COMPUTE_SERVICES + +# include <zencore/compactbinarypackage.h> +# include <zencore/iobuffer.h> +# include <zencore/iohash.h> +# include <zencore/stats.h> +# include <zencore/zencore.h> + +# include <chrono> +# include <map> +# include <unordered_map> +# include <unordered_set> + +namespace zen { + +class AuthMgr; +class CbObjectWriter; +class CidStore; +class CloudCacheTokenProvider; +class WorkerThreadPool; +class ZenCacheNamespace; +struct CloudCacheClientOptions; +struct UpstreamAuthConfig; + +enum class UpstreamApplyState : int32_t +{ + Queued = 0, + Executing = 1, + Complete = 2, +}; + +enum class UpstreamApplyType +{ + Simple = 0, + Asset = 1, +}; + +struct UpstreamApplyRecord +{ + CbObject WorkerDescriptor; + CbObject Action; + UpstreamApplyType Type; + std::map<std::string, uint64_t> Timepoints{}; +}; + +struct UpstreamApplyOptions +{ + std::chrono::seconds HealthCheckInterval{5}; + std::chrono::seconds UpdatesInterval{5}; + uint32_t UpstreamThreadCount = 4; + uint32_t DownstreamThreadCount = 4; + bool StatsEnabled = false; +}; + +struct UpstreamApplyError +{ + int32_t ErrorCode{}; + std::string Reason{}; + + explicit operator bool() const { return ErrorCode != 0; } +}; + +struct PostUpstreamApplyResult +{ + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + std::map<std::string, uint64_t> Timepoints{}; + bool Success = false; +}; + +struct GetUpstreamApplyResult +{ + // UpstreamApplyType::Simple + std::map<std::filesystem::path, IoHash> OutputFiles{}; + std::map<IoHash, IoBuffer> FileData{}; + + // UpstreamApplyType::Asset + CbPackage OutputPackage{}; + int64_t TotalAttachmentBytes{}; + int64_t TotalRawAttachmentBytes{}; + + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + std::string StdOut{}; + std::string StdErr{}; + std::string Agent{}; + std::string Detail{}; + std::map<std::string, uint64_t> Timepoints{}; + bool Success = false; +}; + +using UpstreamApplyCompleted = std::unordered_map<IoHash, std::unordered_map<IoHash, GetUpstreamApplyResult>>; + +struct GetUpstreamApplyUpdatesResult +{ + UpstreamApplyError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + UpstreamApplyCompleted Completed{}; + bool Success = false; +}; + +struct UpstreamApplyStatus +{ + UpstreamApplyState State{}; + GetUpstreamApplyResult Result{}; + std::chrono::steady_clock::time_point ExpireTime{}; + std::map<std::string, uint64_t> Timepoints{}; +}; + +using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>; + +struct UpstreamEndpointHealth +{ + std::string Reason; + bool Ok = false; +}; + +struct UpstreamApplyEndpointStats +{ + metrics::Counter PostCount; + metrics::Counter CompleteCount; + metrics::Counter UpdateCount; + metrics::Counter ErrorCount; + metrics::Counter UpBytes; + metrics::Counter DownBytes; +}; + +/** + * The upstream apply endpoint is responsible for handling remote execution. + */ +class UpstreamApplyEndpoint +{ +public: + virtual ~UpstreamApplyEndpoint() = default; + + virtual UpstreamEndpointHealth Initialize() = 0; + virtual bool IsHealthy() const = 0; + virtual UpstreamEndpointHealth CheckHealth() = 0; + virtual std::string_view DisplayName() const = 0; + virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) = 0; + virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) = 0; + virtual UpstreamApplyEndpointStats& Stats() = 0; + + static std::unique_ptr<UpstreamApplyEndpoint> CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions, + const UpstreamAuthConfig& ComputeAuthConfig, + const CloudCacheClientOptions& StorageOptions, + const UpstreamAuthConfig& StorageAuthConfig, + CidStore& CidStore, + AuthMgr& Mgr); +}; + +/** + * Manages one or more upstream compute endpoints. + */ +class UpstreamApply +{ +public: + virtual ~UpstreamApply() = default; + + virtual bool Initialize() = 0; + virtual bool IsHealthy() const = 0; + virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) = 0; + + struct EnqueueResult + { + IoHash ApplyId{}; + bool Success = false; + }; + + struct StatusResult + { + UpstreamApplyStatus Status{}; + bool Success = false; + }; + + virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0; + virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0; + virtual void GetStatus(CbObjectWriter& CbO) = 0; + + static std::unique_ptr<UpstreamApply> Create(const UpstreamApplyOptions& Options, CidStore& CidStore); +}; + +} // namespace zen + +#endif // ZEN_WITH_COMPUTE_SERVICES diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp new file mode 100644 index 000000000..e838b5fe2 --- /dev/null +++ b/src/zenserver/upstream/upstreamcache.cpp @@ -0,0 +1,2112 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "upstreamcache.h" +#include "jupiter.h" +#include "zen.h" + +#include <zencore/blockingqueue.h> +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/stats.h> +#include <zencore/stream.h> +#include <zencore/timer.h> +#include <zencore/trace.h> + +#include <zenhttp/httpshared.h> + +#include <zenstore/cidstore.h> + +#include <auth/authmgr.h> +#include "cache/structuredcache.h" +#include "cache/structuredcachestore.h" +#include "diag/logging.h" + +#include <fmt/format.h> + +#include <algorithm> +#include <atomic> +#include <shared_mutex> +#include <thread> +#include <unordered_map> + +namespace zen { + +using namespace std::literals; + +namespace detail { + + class UpstreamStatus + { + public: + UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); } + + UpstreamEndpointStatus EndpointStatus() const + { + const UpstreamEndpointState State = EndpointState(); + { + std::unique_lock _(m_Mutex); + return {.Reason = m_ErrorText, .State = State}; + } + } + + void Set(UpstreamEndpointState NewState) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText.clear(); + } + } + + void Set(UpstreamEndpointState NewState, std::string ErrorText) + { + m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed); + { + std::unique_lock _(m_Mutex); + m_ErrorText = std::move(ErrorText); + } + } + + void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText) + { + if (ErrorCode != 0) + { + Set(ErrorCode == 401 ? UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText)); + } + } + + private: + mutable std::mutex m_Mutex; + std::string m_ErrorText; + std::atomic_uint32_t m_State; + }; + + class JupiterUpstreamEndpoint final : public UpstreamEndpoint + { + public: + JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) + : m_AuthMgr(Mgr) + , m_Log(zen::logging::Get("upstream")) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + m_Info.Url = Options.ServiceUrl; + + std::unique_ptr<CloudCacheTokenProvider> TokenProvider; + + if (AuthConfig.OAuthUrl.empty() == false) + { + TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials( + {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret}); + } + else if (AuthConfig.OpenIdProvider.empty() == false) + { + TokenProvider = + CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() { + AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName); + return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime}; + }); + } + else + { + CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken), + .ExpireTime = CloudCacheAccessToken::TimePoint::max()}; + + TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken); + } + + m_Client = new CloudCacheClient(Options, std::move(TokenProvider)); + } + + virtual ~JupiterUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + CloudCacheSession Session(m_Client); + const CloudCacheResult Result = Session.Authenticate(); + + if (Result.Success) + { + m_Status.Set(UpstreamEndpointState::kOk); + } + else if (Result.ErrorCode != 0) + { + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + } + else + { + m_Status.Set(UpstreamEndpointState::kUnauthorized); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultDdcNamespace(); + } + return Namespace; + } + + std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace) + { + if (Namespace == ZenCacheStore::DefaultNamespace) + { + return Session.Client().DefaultBlobStoreNamespace(); + } + return Namespace; + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord"); + + try + { + CloudCacheSession Session(m_Client); + CloudCacheResult Result; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + + if (Type == ZenContentType::kCompressedBinary) + { + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + + if (Result.Success) + { + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + IoBuffer ContentBuffer; + int NumAttachments = 0; + + CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize)) + { + Result.Response = AttachmentResult.Response; + ++NumAttachments; + } + else + { + Result.Success = false; + } + }); + if (NumAttachments != 1) + { + Result.Success = false; + } + } + } + } + else + { + const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type; + Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType); + + if (Result.Success && Type == ZenContentType::kCbPackage) + { + CbPackage Package; + + const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All); + if (Result.Success = ValidationResult == CbValidateError::None; Result.Success) + { + CbObject CacheRecord = LoadCompactBinaryObject(Result.Response); + + CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + Result.Bytes += AttachmentResult.Bytes; + Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds; + Result.ErrorCode = AttachmentResult.ErrorCode; + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash())); + } + else + { + Result.Success = false; + } + }); + + Package.SetObject(CacheRecord); + } + + if (Result.Success) + { + BinaryWriter MemStream; + Package.Save(MemStream); + + Result.Response = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size()); + } + } + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheKeyRequest* Request : Requests) + { + const CacheKey& CacheKey = Request->Key; + CbPackage Package; + CbObject Record; + + double ElapsedSeconds = 0.0; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + CloudCacheResult RefResult = + Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject); + AppendResult(RefResult, Result); + ElapsedSeconds = RefResult.ElapsedSeconds; + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (RefResult.ErrorCode == 0) + { + const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All); + if (ValidationResult == CbValidateError::None) + { + Record = LoadCompactBinaryObject(RefResult.Response); + Record.IterateAttachments([&](CbFieldView AttachmentHash) { + CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash()); + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + + if (BlobResult.ErrorCode == 0) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer Chunk = + CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize)) + { + if (RawHash == AttachmentHash.AsHash()) + { + Package.AddAttachment(CbAttachment(Chunk, RawHash)); + } + } + } + }); + } + } + } + + OnComplete( + {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info}); + } + + return Result; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey&, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk"); + + try + { + CloudCacheSession Session(m_Client); + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + CacheChunkRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const CloudCacheResult BlobResult = + Request.ChunkId == IoHash::Zero + ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId) + : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload && IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize); + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues"); + + CloudCacheSession Session(m_Client); + GetUpstreamCacheResult Result; + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + CacheValueRequest& Request = *RequestPtr; + IoBuffer Payload; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + + double ElapsedSeconds = 0.0; + bool IsCompressed = false; + if (!Result.Error) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + IoHash PayloadHash; + const CloudCacheResult BlobResult = + Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash); + ElapsedSeconds = BlobResult.ElapsedSeconds; + Payload = BlobResult.Response; + + AppendResult(BlobResult, Result); + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + if (Payload) + { + if (IsCompressedBinary(Payload.GetContentType())) + { + IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash; + } + else + { + CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload)); + RawHash = Compressed.DecodeRawHash(); + if (RawHash == PayloadHash) + { + IsCompressed = true; + } + else + { + ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header", + Namespace, + Request.Key.Bucket, + Request.Key.Hash.ToHexString(), + RawHash.ToHexString(), + PayloadHash.ToHexString()); + } + } + } + } + + if (IsCompressed) + { + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = Payload, + .ElapsedSeconds = ElapsedSeconds, + .Source = &m_Info}); + } + else + { + OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + return Result; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + CloudCacheSession Session(m_Client); + + if (CacheRecord.Type == ZenContentType::kBinary) + { + CloudCacheResult Result; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace); + Result = Session.PutRef(BlobStoreNamespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + ZenContentType::kBinary); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + return {.Reason = std::move(Result.Reason), + .Bytes = Result.Bytes, + .ElapsedSeconds = Result.ElapsedSeconds, + .Success = Result.Success}; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize)) + { + return {.Reason = std::string("Invalid compressed value buffer"), .Success = false}; + } + + CbObjectWriter ReferencingObject; + ReferencingObject.AddBinaryAttachment("RawHash", RawHash); + ReferencingObject.AddInteger("RawSize", RawSize); + + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + ReferencingObject.Save().GetBuffer().AsIoBuffer(), + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + if (ValueContentId != RawHash) + { + OutReason = + fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash); + return false; + } + + OutBuffer = RecordValue; + return true; + }); + } + else + { + return PerformStructuredPut( + Session, + CacheRecord.Namespace, + CacheRecord.Key, + RecordValue, + MaxAttempts, + [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) { + const auto It = + std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId); + + if (It == std::end(CacheRecord.ValueContentIds)) + { + OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId); + return false; + } + + const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It); + + OutBuffer = Values[Idx]; + return true; + }); + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out) + { + Out.Success &= Result.Success; + Out.Bytes += Result.Bytes; + Out.ElapsedSeconds += Result.ElapsedSeconds; + + if (Result.ErrorCode) + { + Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}; + } + }; + + PutUpstreamCacheResult PerformStructuredPut( + CloudCacheSession& Session, + std::string_view Namespace, + const CacheKey& Key, + IoBuffer ObjectBuffer, + const int32_t MaxAttempts, + std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn) + { + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace); + const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool { + for (const IoHash& ValueContentId : ValueContentIds) + { + IoBuffer BlobBuffer; + if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason)) + { + return false; + } + + CloudCacheResult BlobResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++) + { + BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer); + } + + m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason); + + if (!BlobResult.Success) + { + OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason); + return false; + } + + TotalBytes += BlobResult.Bytes; + TotalElapsedSeconds += BlobResult.ElapsedSeconds; + } + + return true; + }; + + PutRefResult RefResult; + for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++) + { + RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject); + } + + m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason); + + if (!RefResult.Success) + { + return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason), + .Success = false}; + } + + TotalBytes += RefResult.Bytes; + TotalElapsedSeconds += RefResult.ElapsedSeconds; + + std::string Reason; + if (!PutBlobs(RefResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer); + FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return { + .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + if (!PutBlobs(FinalizeResult.Needs, Reason)) + { + return {.Reason = std::move(Reason), .Success = false}; + } + + FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash); + + m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason); + + if (!FinalizeResult.Success) + { + return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason), + .Success = false}; + } + + if (!FinalizeResult.Needs.empty()) + { + ExtendableStringBuilder<256> Sb; + for (const IoHash& MissingHash : FinalizeResult.Needs) + { + Sb << MissingHash.ToHexString() << ","; + } + + return { + .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()), + .Success = false}; + } + } + + TotalBytes += FinalizeResult.Bytes; + TotalElapsedSeconds += FinalizeResult.ElapsedSeconds; + + return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true}; + } + + spdlog::logger& Log() { return m_Log; } + + AuthMgr& m_AuthMgr; + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + RefPtr<CloudCacheClient> m_Client; + }; + + class ZenUpstreamEndpoint final : public UpstreamEndpoint + { + struct ZenEndpoint + { + std::string Url; + std::string Reason; + double Latency{}; + bool Ok = false; + + bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; } + }; + + public: + ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options) + : m_Log(zen::logging::Get("upstream")) + , m_ConnectTimeout(Options.ConnectTimeout) + , m_Timeout(Options.Timeout) + { + ZEN_ASSERT(!Options.Name.empty()); + m_Info.Name = Options.Name; + + for (const auto& Url : Options.Urls) + { + m_Endpoints.push_back({.Url = Url}); + } + } + + ~ZenUpstreamEndpoint() = default; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; } + + virtual UpstreamEndpointStatus Initialize() override + { + try + { + if (m_Status.EndpointState() == UpstreamEndpointState::kOk) + { + return {.State = UpstreamEndpointState::kOk}; + } + + const ZenEndpoint& Ep = GetEndpoint(); + + if (m_Info.Url != Ep.Url) + { + ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url); + m_Info.Url = Ep.Url; + } + + if (Ep.Ok) + { + RwLock::ExclusiveLockScope _(m_ClientLock); + m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout}); + m_Status.Set(UpstreamEndpointState::kOk); + } + else + { + m_Status.Set(UpstreamEndpointState::kError, Ep.Reason); + } + + return m_Status.EndpointStatus(); + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = Err.what(), .State = GetState()}; + } + } + + virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); } + + virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, + const CacheKey& CacheKey, + ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord"); + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords"); + ZEN_ASSERT(Requests.size() > 0); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheRecords"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy(); + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy); + + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + for (CacheKeyRequest* Request : Requests) + { + BatchRequest.BeginObject(); + { + const CacheKey& Key = Request->Key; + BatchRequest.BeginObject("Key"sv); + { + BatchRequest << "Bucket"sv << Key.Bucket; + BatchRequest << "Hash"sv << Key.Hash; + } + BatchRequest.EndObject(); + if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy) + { + BatchRequest.SetName("Policy"sv); + Request->Policy.Save(BatchRequest); + } + } + BatchRequest.EndObject(); + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (Results.Num() != Requests.size()) + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream."); + } + else + { + for (size_t Index = 0; CbFieldView Record : Results) + { + CacheKeyRequest* Request = Requests[Index++]; + OnComplete({.Request = *Request, + .Record = Record.AsObjectView(), + .Package = BatchResponse, + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream."); + } + } + + for (CacheKeyRequest* Request : Requests) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk"); + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId); + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.ErrorCode == 0) + { + return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success}, + .Value = Result.Response, + .Source = &m_Info}; + } + else + { + return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}}; + } + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}}; + } + } + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues"); + ZEN_ASSERT(!CacheValueRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheValues"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("Requests"sv); + { + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + const CacheValueRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheValueRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) + { + CacheValueRequest& Request = *CacheValueRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream."); + } + } + + for (CacheValueRequest* RequestPtr : CacheValueRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks"); + ZEN_ASSERT(!CacheChunkRequests.empty()); + + CbObjectWriter BatchRequest; + BatchRequest << "Method"sv + << "GetCacheChunks"sv; + BatchRequest << "Accept"sv << kCbPkgMagic; + + BatchRequest.BeginObject("Params"sv); + { + CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy; + BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView(); + BatchRequest << "Namespace"sv << Namespace; + + BatchRequest.BeginArray("ChunkRequests"sv); + { + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + const CacheChunkRequest& Request = *RequestPtr; + + BatchRequest.BeginObject(); + { + BatchRequest.BeginObject("Key"sv); + BatchRequest << "Bucket"sv << Request.Key.Bucket; + BatchRequest << "Hash"sv << Request.Key.Hash; + BatchRequest.EndObject(); + if (Request.ValueId) + { + BatchRequest.AddObjectId("ValueId"sv, Request.ValueId); + } + if (Request.ChunkId != Request.ChunkId.Zero) + { + BatchRequest << "ChunkId"sv << Request.ChunkId; + } + if (Request.RawOffset != 0) + { + BatchRequest << "RawOffset"sv << Request.RawOffset; + } + if (Request.RawSize != UINT64_MAX) + { + BatchRequest << "RawSize"sv << Request.RawSize; + } + if (Request.Policy != DefaultPolicy) + { + BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView(); + } + } + BatchRequest.EndObject(); + } + } + BatchRequest.EndArray(); + } + BatchRequest.EndObject(); + + ZenCacheResult Result; + + { + ZenStructuredCacheSession Session(GetClientRef()); + Result = Session.InvokeRpc(BatchRequest.Save()); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + if (Result.Success) + { + CbPackage BatchResponse; + if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse)) + { + CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView(); + if (CacheChunkRequests.size() != Results.Num()) + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream."); + } + else + { + for (size_t RequestIndex = 0; CbFieldView ChunkField : Results) + { + CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++]; + CbObjectView ChunkObject = ChunkField.AsObjectView(); + IoHash RawHash = ChunkObject["RawHash"sv].AsHash(); + IoBuffer Payload; + uint64_t RawSize = 0; + if (RawHash != IoHash::Zero) + { + bool Success = false; + const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash); + if (Attachment) + { + if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary()) + { + Payload = Compressed.GetCompressed().Flatten().AsIoBuffer(); + Payload.SetContentType(ZenContentType::kCompressedBinary); + RawSize = Compressed.DecodeRawSize(); + Success = true; + } + } + if (!Success) + { + CbFieldView RawSizeField = ChunkObject["RawSize"sv]; + RawSize = RawSizeField.AsUInt64(); + Success = !RawSizeField.HasError(); + } + if (!Success) + { + RawHash = IoHash::Zero; + } + } + OnComplete({.Request = Request, + .RawHash = RawHash, + .RawSize = RawSize, + .Value = std::move(Payload), + .ElapsedSeconds = Result.ElapsedSeconds, + .Source = &m_Info}); + } + + return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true}; + } + } + else + { + ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream."); + } + } + + for (CacheChunkRequest* RequestPtr : CacheChunkRequests) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + + return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}; + } + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Values) override + { + ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord"); + + ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size()); + const int32_t MaxAttempts = 3; + + try + { + ZenStructuredCacheSession Session(GetClientRef()); + ZenCacheResult Result; + int64_t TotalBytes = 0ull; + double TotalElapsedSeconds = 0.0; + + if (CacheRecord.Type == ZenContentType::kCbPackage) + { + CbPackage Package; + Package.SetObject(CbObject(SharedBuffer(RecordValue))); + + for (const IoBuffer& Value : Values) + { + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize)) + { + Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash)); + } + else + { + return {.Reason = std::string("Invalid value buffer"), .Success = false}; + } + } + + BinaryWriter MemStream; + Package.Save(MemStream); + IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size()); + + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + PackagePayload, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes = Result.Bytes; + TotalElapsedSeconds = Result.ElapsedSeconds; + } + else if (CacheRecord.Type == ZenContentType::kCompressedBinary) + { + IoHash RawHash; + uint64_t RawSize; + CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize); + if (!Compressed) + { + return {.Reason = std::string("Invalid value compressed buffer"), .Success = false}; + } + + CbPackage BatchPackage; + CbObjectWriter BatchWriter; + BatchWriter << "Method"sv + << "PutCacheValues"sv; + BatchWriter << "Accept"sv << kCbPkgMagic; + + BatchWriter.BeginObject("Params"sv); + { + // DefaultPolicy unspecified and expected to be Default + + BatchWriter << "Namespace"sv << CacheRecord.Namespace; + + BatchWriter.BeginArray("Requests"sv); + { + BatchWriter.BeginObject(); + { + const CacheKey& Key = CacheRecord.Key; + BatchWriter.BeginObject("Key"sv); + { + BatchWriter << "Bucket"sv << Key.Bucket; + BatchWriter << "Hash"sv << Key.Hash; + } + BatchWriter.EndObject(); + // Policy unspecified and expected to be Default + BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash); + BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash)); + } + BatchWriter.EndObject(); + } + BatchWriter.EndArray(); + } + BatchWriter.EndObject(); + BatchPackage.SetObject(BatchWriter.Save()); + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.InvokeRpc(BatchPackage); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + else + { + for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++) + { + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheValue(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + CacheRecord.ValueContentIds[Idx], + Values[Idx]); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + + if (!Result.Success) + { + return {.Reason = "Failed to upload value", + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = false}; + } + } + + Result.Success = false; + for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++) + { + Result = Session.PutCacheRecord(CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + RecordValue, + CacheRecord.Type); + } + + m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason); + + TotalBytes += Result.Bytes; + TotalElapsedSeconds += Result.ElapsedSeconds; + } + + return {.Reason = std::move(Result.Reason), + .Bytes = TotalBytes, + .ElapsedSeconds = TotalElapsedSeconds, + .Success = Result.Success}; + } + catch (std::exception& Err) + { + m_Status.Set(UpstreamEndpointState::kError, Err.what()); + + return {.Reason = std::string(Err.what()), .Success = false}; + } + } + + virtual UpstreamEndpointStats& Stats() override { return m_Stats; } + + private: + Ref<ZenStructuredCacheClient> GetClientRef() + { + // m_Client can be modified at any time by a different thread. + // Make sure we safely bump the refcount inside a scope lock + RwLock::SharedLockScope _(m_ClientLock); + ZEN_ASSERT(m_Client); + Ref<ZenStructuredCacheClient> ClientRef(m_Client); + _.ReleaseNow(); + return ClientRef; + } + + const ZenEndpoint& GetEndpoint() + { + for (ZenEndpoint& Ep : m_Endpoints) + { + Ref<ZenStructuredCacheClient> Client( + new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)})); + ZenStructuredCacheSession Session(std::move(Client)); + const int32_t SampleCount = 2; + + Ep.Ok = false; + Ep.Latency = {}; + + for (int32_t Sample = 0; Sample < SampleCount; ++Sample) + { + ZenCacheResult Result = Session.CheckHealth(); + Ep.Ok = Result.Success; + Ep.Reason = std::move(Result.Reason); + Ep.Latency += Result.ElapsedSeconds; + } + Ep.Latency /= double(SampleCount); + } + + std::sort(std::begin(m_Endpoints), std::end(m_Endpoints)); + + for (const auto& Ep : m_Endpoints) + { + ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? "OK" : Ep.Reason); + } + + return m_Endpoints.front(); + } + + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + UpstreamEndpointInfo m_Info; + UpstreamStatus m_Status; + UpstreamEndpointStats m_Stats; + std::vector<ZenEndpoint> m_Endpoints; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + RwLock m_ClientLock; + RefPtr<ZenStructuredCacheClient> m_Client; + }; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +class UpstreamCacheImpl final : public UpstreamCache +{ +public: + UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) + : m_Log(logging::Get("upstream")) + , m_Options(Options) + , m_CacheStore(CacheStore) + , m_CidStore(CidStore) + { + } + + virtual ~UpstreamCacheImpl() { Shutdown(); } + + virtual void Initialize() override + { + for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++) + { + m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this); + } + + m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this); + m_RunState.IsRunning = true; + } + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override + { + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + else + { + ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State)); + } + + // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s) + std::unique_lock<std::shared_mutex> _(m_EndpointsMutex); + m_Endpoints.emplace_back(std::move(Endpoint)); + } + + virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override + { + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Ep : m_Endpoints) + { + if (!Fn(*Ep)) + { + break; + } + } + } + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override + { + ZEN_TRACE_CPU("Upstream::GetCacheRecord"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheRecords"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheKeyRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + + Result = Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) { + if (Params.Record) + { + OnComplete(std::forward<CacheRecordGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheKeyRequest* Request : RemainingKeys) + { + OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()}); + } + } + + virtual void GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheChunks"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheChunkRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheChunkGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheChunkRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) override + { + ZEN_TRACE_CPU("Upstream::GetCacheChunk"); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming); + GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId); + Scope.Stop(); + + Stats.CacheGetCount.Increment(1); + Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes); + + if (Result.Status.Success) + { + Stats.CacheHitCount.Increment(1); + + return Result; + } + + if (Result.Status.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Status.Error.Reason, + Result.Status.Error.ErrorCode); + } + } + } + + return {}; + } + + virtual void GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) override final + { + ZEN_TRACE_CPU("Upstream::GetCacheValues"); + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end()); + + if (m_Options.ReadUpstream) + { + for (auto& Endpoint : m_Endpoints) + { + if (RemainingKeys.empty()) + { + break; + } + + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + std::vector<CacheValueRequest*> Missing; + GetUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming); + + Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) { + if (Params.RawHash != Params.RawHash.Zero) + { + OnComplete(std::forward<CacheValueGetCompleteParams>(Params)); + + Stats.CacheHitCount.Increment(1); + } + else + { + Missing.push_back(&Params.Request); + } + }); + } + + Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size())); + Stats.CacheGetTotalBytes.Increment(Result.Bytes); + + if (Result.Error) + { + Stats.CacheErrorCount.Increment(1); + + ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'", + Endpoint->GetEndpointInfo().Url, + Result.Error.Reason, + Result.Error.ErrorCode); + } + + RemainingKeys = std::move(Missing); + } + } + + const UpstreamEndpointInfo Info; + for (CacheValueRequest* RequestPtr : RemainingKeys) + { + OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()}); + } + } + + virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override + { + if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0) + { + if (!m_UpstreamThreads.empty()) + { + m_UpstreamQueue.Enqueue(std::move(CacheRecord)); + } + else + { + ProcessCacheRecord(std::move(CacheRecord)); + } + } + } + + virtual void GetStatus(CbObjectWriter& Status) override + { + Status << "reading" << m_Options.ReadUpstream; + Status << "writing" << m_Options.WriteUpstream; + Status << "worker_threads" << m_Options.ThreadCount; + Status << "queue_count" << m_UpstreamQueue.Size(); + + Status.BeginArray("endpoints"); + for (const auto& Ep : m_Endpoints) + { + const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo(); + const UpstreamEndpointStatus EpStatus = Ep->GetStatus(); + UpstreamEndpointStats& EpStats = Ep->Stats(); + + Status.BeginObject(); + Status << "name" << EpInfo.Name; + Status << "url" << EpInfo.Url; + Status << "state" << ToString(EpStatus.State); + Status << "reason" << EpStatus.Reason; + + Status.BeginObject("cache"sv); + { + const int64_t GetCount = EpStats.CacheGetCount.Value(); + const int64_t HitCount = EpStats.CacheHitCount.Value(); + const int64_t ErrorCount = EpStats.CacheErrorCount.Value(); + const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0; + const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0; + + metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status); + Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value(); + Status << "get_count" << GetCount; + Status << "hit_count" << HitCount; + Status << "hit_ratio" << HitRatio; + Status << "error_count" << ErrorCount; + Status << "error_ratio" << ErrorRatio; + metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status); + Status << "put_bytes" << EpStats.CachePutTotalBytes.Value(); + } + Status.EndObject(); + + Status.EndObject(); + } + Status.EndArray(); + } + +private: + void ProcessCacheRecord(UpstreamCacheRecord CacheRecord) + { + ZEN_TRACE_CPU("Upstream::ProcessCacheRecord"); + + ZenCacheValue CacheValue; + std::vector<IoBuffer> Payloads; + + if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue)) + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash); + return; + } + + for (const IoHash& ValueContentId : CacheRecord.ValueContentIds) + { + if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId)) + { + Payloads.push_back(Payload); + } + else + { + ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + ValueContentId); + return; + } + } + + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + if (Endpoint->GetState() != UpstreamEndpointState::kOk) + { + continue; + } + + UpstreamEndpointStats& Stats = Endpoint->Stats(); + PutUpstreamCacheResult Result; + { + metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming); + Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads)); + } + + Stats.CachePutTotalBytes.Increment(Result.Bytes); + + if (!Result.Success) + { + ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Endpoint->GetEndpointInfo().Url, + Result.Reason); + } + } + } + + void ProcessUpstreamQueue() + { + for (;;) + { + UpstreamCacheRecord CacheRecord; + if (m_UpstreamQueue.WaitAndDequeue(CacheRecord)) + { + try + { + ProcessCacheRecord(std::move(CacheRecord)); + } + catch (std::exception& Err) + { + ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'", + CacheRecord.Namespace, + CacheRecord.Key.Bucket, + CacheRecord.Key.Hash, + Err.what()); + } + } + + if (!m_RunState.IsRunning) + { + break; + } + } + } + + void MonitorEndpoints() + { + for (;;) + { + { + std::unique_lock lk(m_RunState.Mutex); + if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); })) + { + break; + } + } + + try + { + std::vector<UpstreamEndpoint*> Endpoints; + + { + std::shared_lock<std::shared_mutex> _(m_EndpointsMutex); + + for (auto& Endpoint : m_Endpoints) + { + UpstreamEndpointState State = Endpoint->GetState(); + if (State == UpstreamEndpointState::kError) + { + Endpoints.push_back(Endpoint.get()); + ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'", + Endpoint->GetEndpointInfo().Name, + Endpoint->GetEndpointInfo().Url, + Endpoint->GetStatus().Reason); + } + if (State == UpstreamEndpointState::kUnauthorized) + { + Endpoints.push_back(Endpoint.get()); + } + } + } + + for (auto& Endpoint : Endpoints) + { + const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo(); + const UpstreamEndpointStatus Status = Endpoint->Initialize(); + + if (Status.State == UpstreamEndpointState::kOk) + { + ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url); + } + else + { + const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason); + ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason); + } + } + } + catch (std::exception& Err) + { + ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what()); + } + } + } + + void Shutdown() + { + if (m_RunState.Stop()) + { + m_UpstreamQueue.CompleteAdding(); + for (std::thread& Thread : m_UpstreamThreads) + { + Thread.join(); + } + + m_EndpointMonitorThread.join(); + m_UpstreamThreads.clear(); + m_Endpoints.clear(); + } + } + + spdlog::logger& Log() { return m_Log; } + + using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>; + + struct RunState + { + std::mutex Mutex; + std::condition_variable ExitSignal; + std::atomic_bool IsRunning{false}; + + bool Stop() + { + bool Stopped = false; + { + std::lock_guard _(Mutex); + Stopped = IsRunning.exchange(false); + } + if (Stopped) + { + ExitSignal.notify_all(); + } + return Stopped; + } + }; + + spdlog::logger& m_Log; + UpstreamCacheOptions m_Options; + ZenCacheStore& m_CacheStore; + CidStore& m_CidStore; + UpstreamQueue m_UpstreamQueue; + std::shared_mutex m_EndpointsMutex; + std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints; + std::vector<std::thread> m_UpstreamThreads; + std::thread m_EndpointMonitorThread; + RunState m_RunState; +}; + +////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options) +{ + return std::make_unique<detail::ZenUpstreamEndpoint>(Options); +} + +std::unique_ptr<UpstreamEndpoint> +UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr) +{ + return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr); +} + +std::unique_ptr<UpstreamCache> +UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore) +{ + return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore); +} + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamcache.h b/src/zenserver/upstream/upstreamcache.h new file mode 100644 index 000000000..695c06b32 --- /dev/null +++ b/src/zenserver/upstream/upstreamcache.h @@ -0,0 +1,252 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/compress.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/stats.h> +#include <zencore/zencore.h> +#include <zenutil/cache/cache.h> + +#include <atomic> +#include <chrono> +#include <functional> +#include <memory> +#include <vector> + +namespace zen { + +class CbObjectView; +class AuthMgr; +class CbObjectView; +class CbPackage; +class CbObjectWriter; +class CidStore; +class ZenCacheStore; +struct CloudCacheClientOptions; +class CloudCacheTokenProvider; +struct ZenStructuredCacheClientOptions; + +struct UpstreamCacheRecord +{ + ZenContentType Type = ZenContentType::kBinary; + std::string Namespace; + CacheKey Key; + std::vector<IoHash> ValueContentIds; +}; + +struct UpstreamCacheOptions +{ + std::chrono::seconds HealthCheckInterval{5}; + uint32_t ThreadCount = 4; + bool ReadUpstream = true; + bool WriteUpstream = true; +}; + +struct UpstreamError +{ + int32_t ErrorCode{}; + std::string Reason{}; + + explicit operator bool() const { return ErrorCode != 0; } +}; + +struct UpstreamEndpointInfo +{ + std::string Name; + std::string Url; +}; + +struct GetUpstreamCacheResult +{ + UpstreamError Error{}; + int64_t Bytes{}; + double ElapsedSeconds{}; + bool Success = false; +}; + +struct GetUpstreamCacheSingleResult +{ + GetUpstreamCacheResult Status; + IoBuffer Value; + const UpstreamEndpointInfo* Source = nullptr; +}; + +struct PutUpstreamCacheResult +{ + std::string Reason; + int64_t Bytes{}; + double ElapsedSeconds{}; + bool Success = false; +}; + +struct CacheRecordGetCompleteParams +{ + CacheKeyRequest& Request; + const CbObjectView& Record; + const CbPackage& Package; + double ElapsedSeconds{}; + const UpstreamEndpointInfo* Source = nullptr; +}; + +using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>; + +struct CacheValueGetCompleteParams +{ + CacheValueRequest& Request; + IoHash RawHash; + uint64_t RawSize; + IoBuffer Value; + double ElapsedSeconds{}; + const UpstreamEndpointInfo* Source = nullptr; +}; + +using OnCacheValueGetComplete = std::function<void(CacheValueGetCompleteParams&&)>; + +struct CacheChunkGetCompleteParams +{ + CacheChunkRequest& Request; + IoHash RawHash; + uint64_t RawSize; + IoBuffer Value; + double ElapsedSeconds{}; + const UpstreamEndpointInfo* Source = nullptr; +}; + +using OnCacheChunksGetComplete = std::function<void(CacheChunkGetCompleteParams&&)>; + +struct UpstreamEndpointStats +{ + metrics::OperationTiming CacheGetRequestTiming; + metrics::OperationTiming CachePutRequestTiming; + metrics::Counter CacheGetTotalBytes; + metrics::Counter CachePutTotalBytes; + metrics::Counter CacheGetCount; + metrics::Counter CacheHitCount; + metrics::Counter CacheErrorCount; +}; + +enum class UpstreamEndpointState : uint32_t +{ + kDisabled, + kUnauthorized, + kError, + kOk +}; + +inline std::string_view +ToString(UpstreamEndpointState State) +{ + using namespace std::literals; + + switch (State) + { + case UpstreamEndpointState::kDisabled: + return "Disabled"sv; + case UpstreamEndpointState::kUnauthorized: + return "Unauthorized"sv; + case UpstreamEndpointState::kError: + return "Error"sv; + case UpstreamEndpointState::kOk: + return "Ok"sv; + default: + return "Unknown"sv; + } +} + +struct UpstreamAuthConfig +{ + std::string_view OAuthUrl; + std::string_view OAuthClientId; + std::string_view OAuthClientSecret; + std::string_view OpenIdProvider; + std::string_view AccessToken; +}; + +struct UpstreamEndpointStatus +{ + std::string Reason; + UpstreamEndpointState State; +}; + +/** + * The upstream endpoint is responsible for handling upload/downloading of cache records. + */ +class UpstreamEndpoint +{ +public: + virtual ~UpstreamEndpoint() = default; + + virtual UpstreamEndpointStatus Initialize() = 0; + + virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0; + + virtual UpstreamEndpointState GetState() = 0; + virtual UpstreamEndpointStatus GetStatus() = 0; + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0; + virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, const CacheKey& CacheKey, const IoHash& PayloadId) = 0; + virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) = 0; + + virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord, + IoBuffer RecordValue, + std::span<IoBuffer const> Payloads) = 0; + + virtual UpstreamEndpointStats& Stats() = 0; + + static std::unique_ptr<UpstreamEndpoint> CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options); + + static std::unique_ptr<UpstreamEndpoint> CreateJupiterEndpoint(const CloudCacheClientOptions& Options, + const UpstreamAuthConfig& AuthConfig, + AuthMgr& Mgr); +}; + +/** + * Manages one or more upstream cache endpoints. + */ +class UpstreamCache +{ +public: + virtual ~UpstreamCache() = default; + + virtual void Initialize() = 0; + + virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0; + virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0; + virtual void GetCacheRecords(std::string_view Namespace, + std::span<CacheKeyRequest*> Requests, + OnCacheRecordGetComplete&& OnComplete) = 0; + + virtual void GetCacheValues(std::string_view Namespace, + std::span<CacheValueRequest*> CacheValueRequests, + OnCacheValueGetComplete&& OnComplete) = 0; + + virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, + const CacheKey& CacheKey, + const IoHash& ValueContentId) = 0; + virtual void GetCacheChunks(std::string_view Namespace, + std::span<CacheChunkRequest*> CacheChunkRequests, + OnCacheChunksGetComplete&& OnComplete) = 0; + + virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0; + + virtual void GetStatus(CbObjectWriter& CbO) = 0; + + static std::unique_ptr<UpstreamCache> Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore); +}; + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamservice.cpp b/src/zenserver/upstream/upstreamservice.cpp new file mode 100644 index 000000000..6db1357c5 --- /dev/null +++ b/src/zenserver/upstream/upstreamservice.cpp @@ -0,0 +1,56 @@ +// Copyright Epic Games, Inc. All Rights Reserved. +#include <upstream/upstreamservice.h> + +#include <auth/authmgr.h> +#include <upstream/upstreamcache.h> + +#include <zencore/compactbinarybuilder.h> +#include <zencore/string.h> + +namespace zen { + +using namespace std::literals; + +HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr) : m_Upstream(Upstream), m_AuthMgr(Mgr) +{ + m_Router.RegisterRoute( + "endpoints", + [this](HttpRouterRequest& Req) { + CbObjectWriter Writer; + Writer.BeginArray("Endpoints"sv); + m_Upstream.IterateEndpoints([&Writer](UpstreamEndpoint& Ep) { + UpstreamEndpointInfo Info = Ep.GetEndpointInfo(); + UpstreamEndpointStatus Status = Ep.GetStatus(); + + Writer.BeginObject(); + Writer << "Name"sv << Info.Name; + Writer << "Url"sv << Info.Url; + Writer << "State"sv << ToString(Status.State); + Writer << "Reason"sv << Status.Reason; + Writer.EndObject(); + + return true; + }); + Writer.EndArray(); + Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Writer.Save()); + }, + HttpVerb::kGet); +} + +HttpUpstreamService::~HttpUpstreamService() +{ +} + +const char* +HttpUpstreamService::BaseUri() const +{ + return "/upstream/"; +} + +void +HttpUpstreamService::HandleRequest(zen::HttpServerRequest& Request) +{ + m_Router.HandleRequest(Request); +} + +} // namespace zen diff --git a/src/zenserver/upstream/upstreamservice.h b/src/zenserver/upstream/upstreamservice.h new file mode 100644 index 000000000..f1da03c8c --- /dev/null +++ b/src/zenserver/upstream/upstreamservice.h @@ -0,0 +1,27 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpserver.h> + +namespace zen { + +class AuthMgr; +class UpstreamCache; + +class HttpUpstreamService final : public zen::HttpService +{ +public: + HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr); + virtual ~HttpUpstreamService(); + + virtual const char* BaseUri() const override; + virtual void HandleRequest(zen::HttpServerRequest& Request) override; + +private: + UpstreamCache& m_Upstream; + AuthMgr& m_AuthMgr; + HttpRequestRouter m_Router; +}; + +} // namespace zen diff --git a/src/zenserver/upstream/zen.cpp b/src/zenserver/upstream/zen.cpp new file mode 100644 index 000000000..9e1212834 --- /dev/null +++ b/src/zenserver/upstream/zen.cpp @@ -0,0 +1,326 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zen.h" + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/fmtutils.h> +#include <zencore/session.h> +#include <zencore/stream.h> +#include <zenhttp/httpcommon.h> +#include <zenhttp/httpshared.h> + +#include "cache/structuredcachestore.h" +#include "diag/formatters.h" +#include "diag/logging.h" + +ZEN_THIRD_PARTY_INCLUDES_START +#include <cpr/cpr.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <xxhash.h> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +namespace detail { + struct ZenCacheSessionState + { + ZenCacheSessionState(ZenStructuredCacheClient& Client) : OwnerClient(Client) {} + ~ZenCacheSessionState() {} + + void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout) + { + Session.SetBody({}); + Session.SetHeader({}); + Session.SetConnectTimeout(ConnectTimeout); + Session.SetTimeout(Timeout); + } + + cpr::Session& GetSession() { return Session; } + + private: + ZenStructuredCacheClient& OwnerClient; + cpr::Session Session; + }; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////// + +ZenStructuredCacheClient::ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options) +: m_Log(logging::Get(std::string_view("zenclient"))) +, m_ServiceUrl(Options.Url) +, m_ConnectTimeout(Options.ConnectTimeout) +, m_Timeout(Options.Timeout) +{ +} + +ZenStructuredCacheClient::~ZenStructuredCacheClient() +{ +} + +detail::ZenCacheSessionState* +ZenStructuredCacheClient::AllocSessionState() +{ + detail::ZenCacheSessionState* State = nullptr; + + if (RwLock::ExclusiveLockScope _(m_SessionStateLock); !m_SessionStateCache.empty()) + { + State = m_SessionStateCache.front(); + m_SessionStateCache.pop_front(); + } + + if (State == nullptr) + { + State = new detail::ZenCacheSessionState(*this); + } + + State->Reset(m_ConnectTimeout, m_Timeout); + + return State; +} + +void +ZenStructuredCacheClient::FreeSessionState(detail::ZenCacheSessionState* State) +{ + RwLock::ExclusiveLockScope _(m_SessionStateLock); + m_SessionStateCache.push_front(State); +} + +////////////////////////////////////////////////////////////////////////// + +using namespace std::literals; + +ZenStructuredCacheSession::ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient) +: m_Log(OuterClient->Log()) +, m_Client(std::move(OuterClient)) +{ + m_SessionState = m_Client->AllocSessionState(); +} + +ZenStructuredCacheSession::~ZenStructuredCacheSession() +{ + m_Client->FreeSessionState(m_SessionState); +} + +ZenCacheResult +ZenStructuredCacheSession::CheckHealth() +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/health/check"; + + cpr::Session& Session = m_SessionState->GetSession(); + Session.SetOption(cpr::Url{Uri.c_str()}); + cpr::Response Response = Session.Get(); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + return {.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}; +} + +ZenCacheResult +ZenStructuredCacheSession::GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Accept", std::string{MapContentTypeToString(Type)}}}); + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::GetCacheChunk(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Accept", "application/x-ue-comp"}}); + + cpr::Response Response = Session.Get(); + ZEN_DEBUG("GET {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = Buffer, + .Bytes = Response.downloaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::PutCacheRecord(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + IoBuffer Value, + ZenContentType Type) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", + Type == ZenContentType::kCbPackage ? "application/x-ue-cbpkg" + : Type == ZenContentType::kCbObject ? "application/x-ue-cb" + : "application/octet-stream"}}); + Session.SetBody(cpr::Body{static_cast<const char*>(Value.Data()), Value.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200 || Response.status_code == 201; + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::PutCacheValue(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId, + IoBuffer Payload) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/"; + if (Namespace != ZenCacheStore::DefaultNamespace) + { + Uri << Namespace << "/"; + } + Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-comp"}}); + Session.SetBody(cpr::Body{static_cast<const char*>(Payload.Data()), Payload.Size()}); + + cpr::Response Response = Session.Put(); + ZEN_DEBUG("PUT {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200 || Response.status_code == 201; + return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::InvokeRpc(const CbObjectView& Request) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/$rpc"; + + BinaryWriter Body; + Request.CopyTo(Body); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}}); + Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = std::move(Buffer), + .Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +ZenCacheResult +ZenStructuredCacheSession::InvokeRpc(const CbPackage& Request) +{ + ExtendableStringBuilder<256> Uri; + Uri << m_Client->ServiceUrl() << "/z$/$rpc"; + + SharedBuffer Message = FormatPackageMessageBuffer(Request).Flatten(); + + cpr::Session& Session = m_SessionState->GetSession(); + + Session.SetOption(cpr::Url{Uri.c_str()}); + Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}}); + Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Message.GetData()), Message.GetSize()}); + + cpr::Response Response = Session.Post(); + ZEN_DEBUG("POST {}", Response); + + if (Response.error) + { + return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)}; + } + + const bool Success = Response.status_code == 200; + const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer(); + + return {.Response = std::move(Buffer), + .Bytes = Response.uploaded_bytes, + .ElapsedSeconds = Response.elapsed, + .Reason = Response.reason, + .Success = Success}; +} + +} // namespace zen diff --git a/src/zenserver/upstream/zen.h b/src/zenserver/upstream/zen.h new file mode 100644 index 000000000..bfba8fa98 --- /dev/null +++ b/src/zenserver/upstream/zen.h @@ -0,0 +1,125 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/thread.h> +#include <zencore/uid.h> +#include <zencore/zencore.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +#include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <chrono> +#include <list> + +struct ZenCacheValue; + +namespace spdlog { +class logger; +} + +namespace zen { + +class CbObjectWriter; +class CbObjectView; +class CbPackage; +class ZenStructuredCacheClient; + +////////////////////////////////////////////////////////////////////////// + +namespace detail { + struct ZenCacheSessionState; +} + +struct ZenCacheResult +{ + IoBuffer Response; + int64_t Bytes = {}; + double ElapsedSeconds = {}; + int32_t ErrorCode = {}; + std::string Reason; + bool Success = false; +}; + +struct ZenStructuredCacheClientOptions +{ + std::string_view Name; + std::string_view Url; + std::span<std::string const> Urls; + std::chrono::milliseconds ConnectTimeout{}; + std::chrono::milliseconds Timeout{}; +}; + +/** Zen Structured Cache session + * + * This provides a context in which cache queries can be performed + * + * These are currently all synchronous. Will need to be made asynchronous + */ +class ZenStructuredCacheSession +{ +public: + ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient); + ~ZenStructuredCacheSession(); + + ZenCacheResult CheckHealth(); + ZenCacheResult GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type); + ZenCacheResult GetCacheChunk(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId); + ZenCacheResult PutCacheRecord(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + IoBuffer Value, + ZenContentType Type); + ZenCacheResult PutCacheValue(std::string_view Namespace, + std::string_view BucketId, + const IoHash& Key, + const IoHash& ValueContentId, + IoBuffer Payload); + ZenCacheResult InvokeRpc(const CbObjectView& Request); + ZenCacheResult InvokeRpc(const CbPackage& Package); + +private: + inline spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + Ref<ZenStructuredCacheClient> m_Client; + detail::ZenCacheSessionState* m_SessionState; +}; + +/** Zen Structured Cache client + * + * This represents an endpoint to query -- actual queries should be done via + * ZenStructuredCacheSession + */ +class ZenStructuredCacheClient : public RefCounted +{ +public: + ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options); + ~ZenStructuredCacheClient(); + + std::string_view ServiceUrl() const { return m_ServiceUrl; } + + inline spdlog::logger& Log() { return m_Log; } + +private: + spdlog::logger& m_Log; + std::string m_ServiceUrl; + std::chrono::milliseconds m_ConnectTimeout; + std::chrono::milliseconds m_Timeout; + + RwLock m_SessionStateLock; + std::list<detail::ZenCacheSessionState*> m_SessionStateCache; + + detail::ZenCacheSessionState* AllocSessionState(); + void FreeSessionState(detail::ZenCacheSessionState*); + + friend class ZenStructuredCacheSession; +}; + +} // namespace zen diff --git a/src/zenserver/windows/service.cpp b/src/zenserver/windows/service.cpp new file mode 100644 index 000000000..89bacab0b --- /dev/null +++ b/src/zenserver/windows/service.cpp @@ -0,0 +1,646 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "service.h" + +#include <zencore/zencore.h> + +#if ZEN_PLATFORM_WINDOWS + +# include <zencore/except.h> +# include <zencore/zencore.h> + +# include <stdio.h> +# include <tchar.h> +# include <zencore/windows.h> + +# define SVCNAME L"Zen Store" + +SERVICE_STATUS gSvcStatus; +SERVICE_STATUS_HANDLE gSvcStatusHandle; +HANDLE ghSvcStopEvent = NULL; + +void SvcInstall(void); + +void ReportSvcStatus(DWORD, DWORD, DWORD); +void SvcReportEvent(LPTSTR); + +WindowsService::WindowsService() +{ +} + +WindowsService::~WindowsService() +{ +} + +// +// Purpose: +// Installs a service in the SCM database +// +// Parameters: +// None +// +// Return value: +// None +// +VOID +WindowsService::Install() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + TCHAR szPath[MAX_PATH]; + + if (!GetModuleFileName(NULL, szPath, MAX_PATH)) + { + printf("Cannot install service (%d)\n", GetLastError()); + return; + } + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Create the service + + schService = CreateService(schSCManager, // SCM database + SVCNAME, // name of service + SVCNAME, // service name to display + SERVICE_ALL_ACCESS, // desired access + SERVICE_WIN32_OWN_PROCESS, // service type + SERVICE_DEMAND_START, // start type + SERVICE_ERROR_NORMAL, // error control type + szPath, // path to service's binary + NULL, // no load ordering group + NULL, // no tag identifier + NULL, // no dependencies + NULL, // LocalSystem account + NULL); // no password + + if (schService == NULL) + { + printf("CreateService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + else + printf("Service installed successfully\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +void +WindowsService::Delete() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + DELETE); // need delete access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Delete the service. + + if (!DeleteService(schService)) + { + printf("DeleteService failed (%d)\n", GetLastError()); + } + else + printf("Service deleted successfully\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +WindowsService* gSvc; + +void WINAPI +CallMain(DWORD, LPSTR*) +{ + gSvc->SvcMain(); +} + +int +WindowsService::ServiceMain() +{ + gSvc = this; + + SERVICE_TABLE_ENTRY DispatchTable[] = {{(LPWSTR)SVCNAME, (LPSERVICE_MAIN_FUNCTION)&CallMain}, {NULL, NULL}}; + + // This call returns when the service has stopped. + // The process should simply terminate when the call returns. + + if (!StartServiceCtrlDispatcher(DispatchTable)) + { + const DWORD dwError = zen::GetLastError(); + + if (dwError == ERROR_FAILED_SERVICE_CONTROLLER_CONNECT) + { + // Not actually running as a service + gSvc = nullptr; + + zen::SetIsInteractiveSession(true); + + return Run(); + } + else + { + zen::ThrowSystemError(dwError, "StartServiceCtrlDispatcher failed"); + } + } + + zen::SetIsInteractiveSession(false); + + return 0; +} + +int +WindowsService::SvcMain() +{ + // Register the handler function for the service + + gSvcStatusHandle = RegisterServiceCtrlHandler(SVCNAME, SvcCtrlHandler); + + if (!gSvcStatusHandle) + { + SvcReportEvent((LPTSTR)TEXT("RegisterServiceCtrlHandler")); + + return 1; + } + + // These SERVICE_STATUS members remain as set here + + gSvcStatus.dwServiceType = SERVICE_WIN32_OWN_PROCESS; + gSvcStatus.dwServiceSpecificExitCode = 0; + + // Report initial status to the SCM + + ReportSvcStatus(SERVICE_START_PENDING, NO_ERROR, 3000); + + // Create an event. The control handler function, SvcCtrlHandler, + // signals this event when it receives the stop control code. + + ghSvcStopEvent = CreateEvent(NULL, // default security attributes + TRUE, // manual reset event + FALSE, // not signaled + NULL); // no name + + if (ghSvcStopEvent == NULL) + { + ReportSvcStatus(SERVICE_STOPPED, GetLastError(), 0); + + return 1; + } + + // Report running status when initialization is complete. + + ReportSvcStatus(SERVICE_RUNNING, NO_ERROR, 0); + + int ReturnCode = Run(); + + ReportSvcStatus(SERVICE_STOPPED, NO_ERROR, 0); + + return ReturnCode; +} + +// +// Purpose: +// Retrieves and displays the current service configuration. +// +// Parameters: +// None +// +// Return value: +// None +// +void +DoQuerySvc() +{ + SC_HANDLE schSCManager{}; + SC_HANDLE schService{}; + LPQUERY_SERVICE_CONFIG lpsc{}; + LPSERVICE_DESCRIPTION lpsd{}; + DWORD dwBytesNeeded{}, cbBufSize{}, dwError{}; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_QUERY_CONFIG); // need query config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Get the configuration information. + + if (!QueryServiceConfig(schService, NULL, 0, &dwBytesNeeded)) + { + dwError = GetLastError(); + if (ERROR_INSUFFICIENT_BUFFER == dwError) + { + cbBufSize = dwBytesNeeded; + lpsc = (LPQUERY_SERVICE_CONFIG)LocalAlloc(LMEM_FIXED, cbBufSize); + } + else + { + printf("QueryServiceConfig failed (%d)", dwError); + goto cleanup; + } + } + + if (!QueryServiceConfig(schService, lpsc, cbBufSize, &dwBytesNeeded)) + { + printf("QueryServiceConfig failed (%d)", GetLastError()); + goto cleanup; + } + + if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, NULL, 0, &dwBytesNeeded)) + { + dwError = GetLastError(); + if (ERROR_INSUFFICIENT_BUFFER == dwError) + { + cbBufSize = dwBytesNeeded; + lpsd = (LPSERVICE_DESCRIPTION)LocalAlloc(LMEM_FIXED, cbBufSize); + } + else + { + printf("QueryServiceConfig2 failed (%d)", dwError); + goto cleanup; + } + } + + if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, (LPBYTE)lpsd, cbBufSize, &dwBytesNeeded)) + { + printf("QueryServiceConfig2 failed (%d)", GetLastError()); + goto cleanup; + } + + // Print the configuration information. + + _tprintf(TEXT("%s configuration: \n"), SVCNAME); + _tprintf(TEXT(" Type: 0x%x\n"), lpsc->dwServiceType); + _tprintf(TEXT(" Start Type: 0x%x\n"), lpsc->dwStartType); + _tprintf(TEXT(" Error Control: 0x%x\n"), lpsc->dwErrorControl); + _tprintf(TEXT(" Binary path: %s\n"), lpsc->lpBinaryPathName); + _tprintf(TEXT(" Account: %s\n"), lpsc->lpServiceStartName); + + if (lpsd->lpDescription != NULL && lstrcmp(lpsd->lpDescription, TEXT("")) != 0) + _tprintf(TEXT(" Description: %s\n"), lpsd->lpDescription); + if (lpsc->lpLoadOrderGroup != NULL && lstrcmp(lpsc->lpLoadOrderGroup, TEXT("")) != 0) + _tprintf(TEXT(" Load order group: %s\n"), lpsc->lpLoadOrderGroup); + if (lpsc->dwTagId != 0) + _tprintf(TEXT(" Tag ID: %d\n"), lpsc->dwTagId); + if (lpsc->lpDependencies != NULL && lstrcmp(lpsc->lpDependencies, TEXT("")) != 0) + _tprintf(TEXT(" Dependencies: %s\n"), lpsc->lpDependencies); + + LocalFree(lpsc); + LocalFree(lpsd); + +cleanup: + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +// +// Purpose: +// Disables the service. +// +// Parameters: +// None +// +// Return value: +// None +// +void +DoDisableSvc() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_CHANGE_CONFIG); // need change config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Change the service start type. + + if (!ChangeServiceConfig(schService, // handle of service + SERVICE_NO_CHANGE, // service type: no change + SERVICE_DISABLED, // service start type + SERVICE_NO_CHANGE, // error control: no change + NULL, // binary path: no change + NULL, // load order group: no change + NULL, // tag ID: no change + NULL, // dependencies: no change + NULL, // account name: no change + NULL, // password: no change + NULL)) // display name: no change + { + printf("ChangeServiceConfig failed (%d)\n", GetLastError()); + } + else + printf("Service disabled successfully.\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +// +// Purpose: +// Enables the service. +// +// Parameters: +// None +// +// Return value: +// None +// +VOID __stdcall DoEnableSvc() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_CHANGE_CONFIG); // need change config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Change the service start type. + + if (!ChangeServiceConfig(schService, // handle of service + SERVICE_NO_CHANGE, // service type: no change + SERVICE_DEMAND_START, // service start type + SERVICE_NO_CHANGE, // error control: no change + NULL, // binary path: no change + NULL, // load order group: no change + NULL, // tag ID: no change + NULL, // dependencies: no change + NULL, // account name: no change + NULL, // password: no change + NULL)) // display name: no change + { + printf("ChangeServiceConfig failed (%d)\n", GetLastError()); + } + else + printf("Service enabled successfully.\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} +// +// Purpose: +// Updates the service description to "This is a test description". +// +// Parameters: +// None +// +// Return value: +// None +// +void +DoUpdateSvcDesc() +{ + SC_HANDLE schSCManager; + SC_HANDLE schService; + SERVICE_DESCRIPTION sd; + TCHAR szDesc[] = TEXT("This is a test description"); + + // Get a handle to the SCM database. + + schSCManager = OpenSCManager(NULL, // local computer + NULL, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (NULL == schSCManager) + { + printf("OpenSCManager failed (%d)\n", GetLastError()); + return; + } + + // Get a handle to the service. + + schService = OpenService(schSCManager, // SCM database + SVCNAME, // name of service + SERVICE_CHANGE_CONFIG); // need change config access + + if (schService == NULL) + { + printf("OpenService failed (%d)\n", GetLastError()); + CloseServiceHandle(schSCManager); + return; + } + + // Change the service description. + + sd.lpDescription = szDesc; + + if (!ChangeServiceConfig2(schService, // handle to service + SERVICE_CONFIG_DESCRIPTION, // change: description + &sd)) // new description + { + printf("ChangeServiceConfig2 failed\n"); + } + else + printf("Service description updated successfully.\n"); + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); +} + +// +// Purpose: +// Sets the current service status and reports it to the SCM. +// +// Parameters: +// dwCurrentState - The current state (see SERVICE_STATUS) +// dwWin32ExitCode - The system error code +// dwWaitHint - Estimated time for pending operation, +// in milliseconds +// +// Return value: +// None +// +VOID +ReportSvcStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode, DWORD dwWaitHint) +{ + static DWORD dwCheckPoint = 1; + + // Fill in the SERVICE_STATUS structure. + + gSvcStatus.dwCurrentState = dwCurrentState; + gSvcStatus.dwWin32ExitCode = dwWin32ExitCode; + gSvcStatus.dwWaitHint = dwWaitHint; + + if (dwCurrentState == SERVICE_START_PENDING) + gSvcStatus.dwControlsAccepted = 0; + else + gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP; + + if ((dwCurrentState == SERVICE_RUNNING) || (dwCurrentState == SERVICE_STOPPED)) + gSvcStatus.dwCheckPoint = 0; + else + gSvcStatus.dwCheckPoint = dwCheckPoint++; + + // Report the status of the service to the SCM. + SetServiceStatus(gSvcStatusHandle, &gSvcStatus); +} + +void +WindowsService::SvcCtrlHandler(DWORD dwCtrl) +{ + // Handle the requested control code. + // + // Called by SCM whenever a control code is sent to the service + // using the ControlService function. + + switch (dwCtrl) + { + case SERVICE_CONTROL_STOP: + ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 0); + + // Signal the service to stop. + + SetEvent(ghSvcStopEvent); + zen::RequestApplicationExit(0); + + ReportSvcStatus(gSvcStatus.dwCurrentState, NO_ERROR, 0); + return; + + case SERVICE_CONTROL_INTERROGATE: + break; + + default: + break; + } +} + +// +// Purpose: +// Logs messages to the event log +// +// Parameters: +// szFunction - name of function that failed +// +// Return value: +// None +// +// Remarks: +// The service must have an entry in the Application event log. +// +VOID +SvcReportEvent(LPTSTR szFunction) +{ + ZEN_UNUSED(szFunction); + + // HANDLE hEventSource; + // LPCTSTR lpszStrings[2]; + // TCHAR Buffer[80]; + + // hEventSource = RegisterEventSource(NULL, SVCNAME); + + // if (NULL != hEventSource) + //{ + // StringCchPrintf(Buffer, 80, TEXT("%s failed with %d"), szFunction, GetLastError()); + + // lpszStrings[0] = SVCNAME; + // lpszStrings[1] = Buffer; + + // ReportEvent(hEventSource, // event log handle + // EVENTLOG_ERROR_TYPE, // event type + // 0, // event category + // SVC_ERROR, // event identifier + // NULL, // no security identifier + // 2, // size of lpszStrings array + // 0, // no binary data + // lpszStrings, // array of strings + // NULL); // no binary data + + // DeregisterEventSource(hEventSource); + //} +} + +#endif // ZEN_PLATFORM_WINDOWS diff --git a/src/zenserver/windows/service.h b/src/zenserver/windows/service.h new file mode 100644 index 000000000..7c9610983 --- /dev/null +++ b/src/zenserver/windows/service.h @@ -0,0 +1,20 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +class WindowsService +{ +public: + WindowsService(); + ~WindowsService(); + + virtual int Run() = 0; + + int ServiceMain(); + + static void Install(); + static void Delete(); + + int SvcMain(); + static void __stdcall SvcCtrlHandler(unsigned long); +}; diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua new file mode 100644 index 000000000..23bfb9535 --- /dev/null +++ b/src/zenserver/xmake.lua @@ -0,0 +1,60 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zenserver") + set_kind("binary") + add_deps("zencore", "zenhttp", "zenstore", "zenutil") + add_headerfiles("**.h") + add_files("**.cpp") + add_files("zenserver.cpp", {unity_ignored = true }) + add_includedirs(".") + set_symbols("debug") + + if is_mode("release") then + set_optimize("fastest") + end + + if is_plat("windows") then + add_ldflags("/subsystem:console,5.02") + add_ldflags("/MANIFEST:EMBED") + add_ldflags("/LTCG") + add_files("zenserver.rc") + add_cxxflags("/bigobj") + else + remove_files("windows/**") + end + + if is_plat("macosx") then + add_ldflags("-framework CoreFoundation") + add_ldflags("-framework CoreGraphics") + add_ldflags("-framework CoreText") + add_ldflags("-framework Foundation") + add_ldflags("-framework Security") + add_ldflags("-framework SystemConfiguration") + add_syslinks("bsm") + end + + add_options("compute") + add_options("exec") + + add_packages( + "vcpkg::asio", + "vcpkg::cxxopts", + "vcpkg::http-parser", + "vcpkg::json11", + "vcpkg::lua", + "vcpkg::mimalloc", + "vcpkg::rocksdb", + "vcpkg::sentry-native", + "vcpkg::sol2" + ) + + -- Only applicable to later versions of sentry-native + --[[ + if is_plat("linux") then + -- As sentry_native uses symbols from breakpad_client, the latter must + -- be specified after the former with GCC-like toolchains. xmake however + -- is unaware of this and simply globs files from vcpkg's output. The + -- line below forces breakpad_client to be to the right of sentry_native + add_syslinks("breakpad_client") + end + ]]-- diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp new file mode 100644 index 000000000..635fd04e0 --- /dev/null +++ b/src/zenserver/zenserver.cpp @@ -0,0 +1,1261 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/config.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/iobuffer.h> +#include <zencore/logging.h> +#include <zencore/refcount.h> +#include <zencore/scopeguard.h> +#include <zencore/session.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/trace.h> +#include <zenhttp/httpserver.h> +#include <zenhttp/websocket.h> +#include <zenstore/cidstore.h> +#include <zenstore/scrubcontext.h> +#include <zenutil/basicfile.h> +#include <zenutil/zenserverprocess.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#endif + +#if ZEN_USE_MIMALLOC +ZEN_THIRD_PARTY_INCLUDES_START +# include <mimalloc-new-delete.h> +# include <mimalloc.h> +ZEN_THIRD_PARTY_INCLUDES_END +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <asio.hpp> +#include <lua.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <exception> +#include <list> +#include <optional> +#include <regex> +#include <set> +#include <unordered_map> + +////////////////////////////////////////////////////////////////////////// +// We don't have any doctest code in this file but this is needed to bring +// in some shared code into the executable + +#if ZEN_WITH_TESTS +# define ZEN_TEST_WITH_RUNNER 1 +# include <zencore/testing.h> +#endif + +////////////////////////////////////////////////////////////////////////// + +#include "config.h" +#include "diag/logging.h" + +#if ZEN_PLATFORM_WINDOWS +# include "windows/service.h" +#endif + +////////////////////////////////////////////////////////////////////////// +// Sentry +// + +#if !defined(ZEN_USE_SENTRY) +# if ZEN_PLATFORM_MAC && ZEN_ARCH_ARM64 +// vcpkg's sentry-native port does not support Arm on Mac. +# define ZEN_USE_SENTRY 0 +# else +# define ZEN_USE_SENTRY 1 +# endif +#endif + +#if ZEN_USE_SENTRY +# define SENTRY_BUILD_STATIC 1 +ZEN_THIRD_PARTY_INCLUDES_START +# include <sentry.h> +# include <spdlog/sinks/base_sink.h> +ZEN_THIRD_PARTY_INCLUDES_END + +// Sentry currently does not automatically add all required Windows +// libraries to the linker when consumed via vcpkg + +# if ZEN_PLATFORM_WINDOWS +# pragma comment(lib, "sentry.lib") +# pragma comment(lib, "dbghelp.lib") +# pragma comment(lib, "winhttp.lib") +# pragma comment(lib, "version.lib") +# endif +#endif + +////////////////////////////////////////////////////////////////////////// +// Services +// + +#include "admin/admin.h" +#include "auth/authmgr.h" +#include "auth/authservice.h" +#include "cache/structuredcache.h" +#include "cache/structuredcachestore.h" +#include "cidstore.h" +#include "compute/function.h" +#include "diag/diagsvcs.h" +#include "frontend/frontend.h" +#include "monitoring/httpstats.h" +#include "monitoring/httpstatus.h" +#include "objectstore/objectstore.h" +#include "projectstore/projectstore.h" +#include "testing/httptest.h" +#include "upstream/upstream.h" +#include "zenstore/gc.h" + +#define ZEN_APP_NAME "Zen store" + +namespace zen { + +using namespace std::literals; + +namespace utils { +#if ZEN_USE_SENTRY + class sentry_sink final : public spdlog::sinks::base_sink<spdlog::details::null_mutex> + { + public: + sentry_sink() {} + + protected: + static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_DEBUG, + SENTRY_LEVEL_INFO, + SENTRY_LEVEL_WARNING, + SENTRY_LEVEL_ERROR, + SENTRY_LEVEL_FATAL, + SENTRY_LEVEL_DEBUG}; + + void sink_it_(const spdlog::details::log_msg& msg) override + { + std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname); + sentry_value_t event = sentry_value_new_message_event( + /* level */ MapToSentryLevel[msg.level], + /* logger */ nullptr, + /* message */ Message.c_str()); + sentry_event_value_add_stacktrace(event, NULL, 0); + sentry_capture_event(event); + } + void flush_() override {} + }; +#endif + + asio::error_code ResolveHostname(asio::io_context& Ctx, + std::string_view Host, + std::string_view DefaultPort, + std::vector<std::string>& OutEndpoints) + { + std::string_view Port = DefaultPort; + + if (const size_t Idx = Host.find(":"); Idx != std::string_view::npos) + { + Port = Host.substr(Idx + 1); + Host = Host.substr(0, Idx); + } + + asio::ip::tcp::resolver Resolver(Ctx); + + asio::error_code ErrorCode; + asio::ip::tcp::resolver::results_type Endpoints = Resolver.resolve(Host, Port, ErrorCode); + + if (!ErrorCode) + { + for (const asio::ip::tcp::endpoint Ep : Endpoints) + { + OutEndpoints.push_back(fmt::format("http://{}:{}", Ep.address().to_string(), Ep.port())); + } + } + + return ErrorCode; + } +} // namespace utils + +class ZenServer : public IHttpStatusProvider +{ +public: + int Initialize(const ZenServerOptions& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry) + { + m_UseSentry = ServerOptions.NoSentry == false; + m_ServerEntry = ServerEntry; + m_DebugOptionForcedCrash = ServerOptions.ShouldCrash; + const int ParentPid = ServerOptions.OwnerPid; + + if (ParentPid) + { + zen::ProcessHandle OwnerProcess; + OwnerProcess.Initialize(ParentPid); + + if (!OwnerProcess.IsValid()) + { + ZEN_WARN("Unable to initialize process handle for specified parent pid #{}", ParentPid); + + // If the pid is not reachable should we just shut down immediately? the intended owner process + // could have been killed or somehow crashed already + } + else + { + ZEN_INFO("Using parent pid #{} to control process lifetime", ParentPid); + } + + m_ProcessMonitor.AddPid(ParentPid); + } + + // Initialize/check mutex based on base port + + std::string MutexName = fmt::format("zen_{}", ServerOptions.BasePort); + + if (zen::NamedMutex::Exists(MutexName) || ((m_ServerMutex.Create(MutexName) == false))) + { + throw std::runtime_error(fmt::format("Failed to create mutex '{}' - is another instance already running?", MutexName).c_str()); + } + + InitializeState(ServerOptions); + + m_HealthService.SetHealthInfo({.DataRoot = m_DataRoot, + .AbsLogPath = ServerOptions.AbsLogFile, + .HttpServerClass = std::string(ServerOptions.HttpServerClass), + .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL)}); + + // Ok so now we're configured, let's kick things off + + m_Http = zen::CreateHttpServer(ServerOptions.HttpServerClass); + int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort); + + if (ServerOptions.WebSocketPort != 0) + { + const uint32 ThreadCount = + ServerOptions.WebSocketThreads > 0 ? uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency(); + + m_WebSocket = zen::WebSocketServer::Create( + {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))}); + } + + // Setup authentication manager + { + std::string EncryptionKey = ServerOptions.EncryptionKey; + + if (EncryptionKey.empty()) + { + EncryptionKey = "abcdefghijklmnopqrstuvxyz0123456"; + + ZEN_WARN("using default encryption key"); + } + + std::string EncryptionIV = ServerOptions.EncryptionIV; + + if (EncryptionIV.empty()) + { + EncryptionIV = "0123456789abcdef"; + + ZEN_WARN("using default encryption initialization vector"); + } + + m_AuthMgr = AuthMgr::Create({.RootDirectory = m_DataRoot / "auth", + .EncryptionKey = AesKey256Bit::FromString(EncryptionKey), + .EncryptionIV = AesIV128Bit::FromString(EncryptionIV)}); + + for (const ZenOpenIdProviderConfig& OpenIdProvider : ServerOptions.AuthConfig.OpenIdProviders) + { + m_AuthMgr->AddOpenIdProvider({.Name = OpenIdProvider.Name, .Url = OpenIdProvider.Url, .ClientId = OpenIdProvider.ClientId}); + } + } + + m_AuthService = std::make_unique<zen::HttpAuthService>(*m_AuthMgr); + m_Http->RegisterService(*m_AuthService); + + m_Http->RegisterService(m_HealthService); + m_Http->RegisterService(m_StatsService); + m_Http->RegisterService(m_StatusService); + m_StatusService.RegisterHandler("status", *this); + + // Initialize storage and services + + ZEN_INFO("initializing storage"); + + zen::CidStoreConfiguration Config; + Config.RootDirectory = m_DataRoot / "cas"; + + m_CidStore = std::make_unique<zen::CidStore>(m_GcManager); + m_CidStore->Initialize(Config); + m_CidService.reset(new zen::HttpCidService{*m_CidStore}); + + ZEN_INFO("instantiating project service"); + + m_ProjectStore = new zen::ProjectStore(*m_CidStore, m_DataRoot / "projects", m_GcManager); + m_HttpProjectService.reset(new zen::HttpProjectService{*m_CidStore, m_ProjectStore, m_StatsService, *m_AuthMgr}); + +#if ZEN_WITH_COMPUTE_SERVICES + if (ServerOptions.ComputeServiceEnabled) + { + InitializeCompute(ServerOptions); + } + else + { + ZEN_INFO("NOT instantiating compute services"); + } +#endif // ZEN_WITH_COMPUTE_SERVICES + + if (ServerOptions.StructuredCacheEnabled) + { + InitializeStructuredCache(ServerOptions); + } + else + { + ZEN_INFO("NOT instantiating structured cache service"); + } + + m_Http->RegisterService(m_TestService); // NOTE: this is intentionally not limited to test mode as it's useful for diagnostics + m_Http->RegisterService(m_TestingService); + m_Http->RegisterService(m_AdminService); + + if (m_WebSocket) + { + m_WebSocket->RegisterService(m_TestingService); + } + + if (m_HttpProjectService) + { + m_Http->RegisterService(*m_HttpProjectService); + } + + m_Http->RegisterService(*m_CidService); + +#if ZEN_WITH_COMPUTE_SERVICES + if (ServerOptions.ComputeServiceEnabled) + { + if (m_HttpFunctionService != nullptr) + { + m_Http->RegisterService(*m_HttpFunctionService); + } + } +#endif // ZEN_WITH_COMPUTE_SERVICES + + m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot); + + if (m_FrontendService) + { + m_Http->RegisterService(*m_FrontendService); + } + + if (ServerOptions.ObjectStoreEnabled) + { + ObjectStoreConfig ObjCfg; + ObjCfg.RootDirectory = m_DataRoot / "obj"; + ObjCfg.ServerPort = static_cast<uint16_t>(EffectiveBasePort); + + for (const auto& Bucket : ServerOptions.ObjectStoreConfig.Buckets) + { + ObjectStoreConfig::BucketConfig NewBucket{.Name = Bucket.Name}; + NewBucket.Directory = Bucket.Directory.empty() ? (ObjCfg.RootDirectory / Bucket.Name) : Bucket.Directory; + ObjCfg.Buckets.push_back(std::move(NewBucket)); + } + + m_ObjStoreService = std::make_unique<HttpObjectStoreService>(std::move(ObjCfg)); + m_Http->RegisterService(*m_ObjStoreService); + } + + ZEN_INFO("initializing GC, enabled '{}', interval {}s", ServerOptions.GcConfig.Enabled, ServerOptions.GcConfig.IntervalSeconds); + zen::GcSchedulerConfig GcConfig{.RootDirectory = m_DataRoot / "gc", + .MonitorInterval = std::chrono::seconds(ServerOptions.GcConfig.MonitorIntervalSeconds), + .Interval = std::chrono::seconds(ServerOptions.GcConfig.IntervalSeconds), + .MaxCacheDuration = std::chrono::seconds(ServerOptions.GcConfig.Cache.MaxDurationSeconds), + .CollectSmallObjects = ServerOptions.GcConfig.CollectSmallObjects, + .Enabled = ServerOptions.GcConfig.Enabled, + .DiskReserveSize = ServerOptions.GcConfig.DiskReserveSize, + .DiskSizeSoftLimit = ServerOptions.GcConfig.Cache.DiskSizeSoftLimit}; + m_GcScheduler.Initialize(GcConfig); + + return EffectiveBasePort; + } + + void InitializeState(const ZenServerOptions& ServerOptions); + void InitializeStructuredCache(const ZenServerOptions& ServerOptions); + void InitializeCompute(const ZenServerOptions& ServerOptions); + + void Run() + { + // This is disabled for now, awaiting better scheduling + // + // Scrub(); + + if (m_ProcessMonitor.IsActive()) + { + EnqueueTimer(); + } + + if (!m_TestMode) + { + ZEN_INFO("__________ _________ __ "); + ZEN_INFO("\\____ /____ ____ / _____// |_ ___________ ____ "); + ZEN_INFO(" / // __ \\ / \\ \\_____ \\\\ __\\/ _ \\_ __ \\_/ __ \\ "); + ZEN_INFO(" / /\\ ___/| | \\ / \\| | ( <_> ) | \\/\\ ___/ "); + ZEN_INFO("/_______ \\___ >___| / /_______ /|__| \\____/|__| \\___ >"); + ZEN_INFO(" \\/ \\/ \\/ \\/ \\/ "); + } + + ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", zen::GetCurrentProcessId()); + +#if ZEN_USE_SENTRY + ZEN_INFO("sentry crash handler {}", m_UseSentry ? "ENABLED" : "DISABLED"); + if (m_UseSentry) + { + sentry_clear_modulecache(); + } +#endif + + if (m_DebugOptionForcedCrash) + { + ZEN_DEBUG_BREAK(); + } + + const bool IsInteractiveMode = zen::IsInteractiveSession() && !m_TestMode; + + SetNewState(kRunning); + + OnReady(); + + if (m_WebSocket) + { + m_WebSocket->Run(); + } + + m_Http->Run(IsInteractiveMode); + + SetNewState(kShuttingDown); + + ZEN_INFO(ZEN_APP_NAME " exiting"); + + m_IoContext.stop(); + if (m_IoRunner.joinable()) + { + m_IoRunner.join(); + } + + Flush(); + } + + void RequestExit(int ExitCode) + { + RequestApplicationExit(ExitCode); + m_Http->RequestExit(); + } + + void Cleanup() + { + ZEN_INFO(ZEN_APP_NAME " cleaning up"); + m_GcScheduler.Shutdown(); + } + + void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; } + void SetTestMode(bool State) { m_TestMode = State; } + void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; } + void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; } + + std::function<void()> m_IsReadyFunc; + void SetIsReadyFunc(std::function<void()>&& IsReadyFunc) { m_IsReadyFunc = std::move(IsReadyFunc); } + void OnReady(); + + void EnsureIoRunner() + { + if (!m_IoRunner.joinable()) + { + m_IoRunner = std::thread{[this] { m_IoContext.run(); }}; + } + } + + void EnqueueTimer() + { + m_PidCheckTimer.expires_after(std::chrono::seconds(1)); + m_PidCheckTimer.async_wait([this](const asio::error_code&) { CheckOwnerPid(); }); + + EnsureIoRunner(); + } + + void CheckOwnerPid() + { + // Pick up any new "owner" processes + + std::set<uint32_t> AddedPids; + + for (auto& PidEntry : m_ServerEntry->SponsorPids) + { + if (uint32_t ThisPid = PidEntry.load(std::memory_order_relaxed)) + { + if (PidEntry.compare_exchange_strong(ThisPid, 0)) + { + if (AddedPids.insert(ThisPid).second) + { + m_ProcessMonitor.AddPid(ThisPid); + + ZEN_INFO("added process with pid #{} as a sponsor process", ThisPid); + } + } + } + } + + if (m_ProcessMonitor.IsRunning()) + { + EnqueueTimer(); + } + else + { + ZEN_INFO(ZEN_APP_NAME " exiting since sponsor processes are all gone"); + + RequestExit(0); + } + } + + void Scrub() + { + Stopwatch Timer; + ZEN_INFO("Storage validation STARTING"); + + ScrubContext Ctx; + m_CidStore->Scrub(Ctx); + m_ProjectStore->Scrub(Ctx); + m_StructuredCacheService->Scrub(Ctx); + + const uint64_t ElapsedTimeMs = Timer.GetElapsedTimeMs(); + + ZEN_INFO("Storage validation DONE in {}, ({} in {} chunks - {})", + NiceTimeSpanMs(ElapsedTimeMs), + NiceBytes(Ctx.ScrubbedBytes()), + Ctx.ScrubbedChunks(), + NiceByteRate(Ctx.ScrubbedBytes(), ElapsedTimeMs)); + } + + void Flush() + { + if (m_CidStore) + m_CidStore->Flush(); + + if (m_StructuredCacheService) + m_StructuredCacheService->Flush(); + + if (m_ProjectStore) + m_ProjectStore->Flush(); + } + + virtual void HandleStatusRequest(HttpServerRequest& Request) override + { + CbObjectWriter Cbo; + Cbo << "ok" << true; + Cbo << "state" << ToString(m_CurrentState); + Request.WriteResponse(HttpResponseCode::OK, Cbo.Save()); + } + +private: + ZenServerState::ZenServerEntry* m_ServerEntry = nullptr; + bool m_IsDedicatedMode = false; + bool m_TestMode = false; + CbObject m_RootManifest; + std::filesystem::path m_DataRoot; + std::filesystem::path m_ContentRoot; + std::thread m_IoRunner; + asio::io_context m_IoContext; + asio::steady_timer m_PidCheckTimer{m_IoContext}; + zen::ProcessMonitor m_ProcessMonitor; + zen::NamedMutex m_ServerMutex; + + enum ServerState + { + kInitializing, + kRunning, + kShuttingDown + } m_CurrentState = kInitializing; + + inline void SetNewState(ServerState NewState) { m_CurrentState = NewState; } + + std::string_view ToString(ServerState Value) + { + switch (Value) + { + case kInitializing: + return "initializing"sv; + case kRunning: + return "running"sv; + case kShuttingDown: + return "shutdown"sv; + default: + return "unknown"sv; + } + } + + zen::Ref<zen::HttpServer> m_Http; + std::unique_ptr<zen::WebSocketServer> m_WebSocket; + std::unique_ptr<zen::AuthMgr> m_AuthMgr; + std::unique_ptr<zen::HttpAuthService> m_AuthService; + zen::HttpStatusService m_StatusService; + zen::HttpStatsService m_StatsService; + zen::GcManager m_GcManager; + zen::GcScheduler m_GcScheduler{m_GcManager}; + std::unique_ptr<zen::CidStore> m_CidStore; + std::unique_ptr<zen::ZenCacheStore> m_CacheStore; + zen::HttpTestService m_TestService; + zen::HttpTestingService m_TestingService; + std::unique_ptr<zen::HttpCidService> m_CidService; + zen::RefPtr<zen::ProjectStore> m_ProjectStore; + std::unique_ptr<zen::HttpProjectService> m_HttpProjectService; + std::unique_ptr<zen::UpstreamCache> m_UpstreamCache; + std::unique_ptr<zen::HttpUpstreamService> m_UpstreamService; + std::unique_ptr<zen::HttpStructuredCacheService> m_StructuredCacheService; + zen::HttpAdminService m_AdminService{m_GcScheduler}; + zen::HttpHealthService m_HealthService; +#if ZEN_WITH_COMPUTE_SERVICES + std::unique_ptr<zen::HttpFunctionService> m_HttpFunctionService; +#endif // ZEN_WITH_COMPUTE_SERVICES + std::unique_ptr<zen::HttpFrontendService> m_FrontendService; + std::unique_ptr<zen::HttpObjectStoreService> m_ObjStoreService; + + bool m_DebugOptionForcedCrash = false; + bool m_UseSentry = false; +}; + +void +ZenServer::OnReady() +{ + m_ServerEntry->SignalReady(); + + if (m_IsReadyFunc) + { + m_IsReadyFunc(); + } +} + +void +ZenServer::InitializeState(const ZenServerOptions& ServerOptions) +{ + // Check root manifest to deal with schema versioning + + bool WipeState = false; + std::string WipeReason = "Unspecified"; + + bool UpdateManifest = false; + std::filesystem::path ManifestPath = m_DataRoot / "root_manifest"; + FileContents ManifestData = zen::ReadFile(ManifestPath); + + if (ManifestData.ErrorCode) + { + if (ServerOptions.IsFirstRun) + { + ZEN_INFO("Initializing state at '{}'", m_DataRoot); + + UpdateManifest = true; + } + else + { + WipeState = true; + WipeReason = fmt::format("No manifest present at '{}'", ManifestPath); + } + } + else + { + IoBuffer Manifest = ManifestData.Flatten(); + + if (CbValidateError ValidationResult = ValidateCompactBinary(Manifest, CbValidateMode::All); + ValidationResult != CbValidateError::None) + { + ZEN_WARN("Manifest validation failed: {}, state will be wiped", uint32_t(ValidationResult)); + + WipeState = true; + WipeReason = fmt::format("Validation of manifest at '{}' failed: {}", ManifestPath, uint32_t(ValidationResult)); + } + else + { + m_RootManifest = LoadCompactBinaryObject(Manifest); + + const int32_t ManifestVersion = m_RootManifest["schema_version"].AsInt32(0); + + if (ManifestVersion != ZEN_CFG_SCHEMA_VERSION) + { + WipeState = true; + WipeReason = fmt::format("Manifest schema version: {}, differs from required: {}", ManifestVersion, ZEN_CFG_SCHEMA_VERSION); + } + } + } + + // Release any open handles so we can overwrite the manifest + ManifestData = {}; + + // Handle any state wipe + + if (WipeState) + { + ZEN_WARN("Wiping state at '{}' - reason: '{}'", m_DataRoot, WipeReason); + + std::error_code Ec; + for (const std::filesystem::directory_entry& DirEntry : std::filesystem::directory_iterator{m_DataRoot, Ec}) + { + if (DirEntry.is_directory() && (DirEntry.path().filename() != "logs")) + { + ZEN_INFO("Deleting '{}'", DirEntry.path()); + + std::filesystem::remove_all(DirEntry.path(), Ec); + + if (Ec) + { + ZEN_WARN("Delete of '{}' returned error: '{}'", DirEntry.path(), Ec.message()); + } + } + } + + ZEN_INFO("Wiped all directories in data root"); + + UpdateManifest = true; + } + + if (UpdateManifest) + { + // Write new manifest + + const DateTime Now = DateTime::Now(); + + CbObjectWriter Cbo; + Cbo << "schema_version" << ZEN_CFG_SCHEMA_VERSION << "created" << Now << "updated" << Now << "state_id" << Oid::NewOid(); + + m_RootManifest = Cbo.Save(); + + WriteFile(ManifestPath, m_RootManifest.GetBuffer().AsIoBuffer()); + } +} + +void +ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions) +{ + using namespace std::literals; + + ZEN_INFO("instantiating structured cache service"); + m_CacheStore = std::make_unique<ZenCacheStore>( + m_GcManager, + ZenCacheStore::Configuration{.BasePath = m_DataRoot / "cache", .AllowAutomaticCreationOfNamespaces = true}); + + const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig; + + zen::UpstreamCacheOptions UpstreamOptions; + UpstreamOptions.ReadUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Read)) != 0; + UpstreamOptions.WriteUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Write)) != 0; + + if (UpstreamConfig.UpstreamThreadCount < 32) + { + UpstreamOptions.ThreadCount = static_cast<uint32_t>(UpstreamConfig.UpstreamThreadCount); + } + + m_UpstreamCache = zen::UpstreamCache::Create(UpstreamOptions, *m_CacheStore, *m_CidStore); + m_UpstreamService = std::make_unique<HttpUpstreamService>(*m_UpstreamCache, *m_AuthMgr); + m_UpstreamCache->Initialize(); + + if (ServerOptions.UpstreamCacheConfig.CachePolicy != UpstreamCachePolicy::Disabled) + { + // Zen upstream + { + std::vector<std::string> ZenUrls = UpstreamConfig.ZenConfig.Urls; + if (!UpstreamConfig.ZenConfig.Dns.empty()) + { + for (const std::string& Dns : UpstreamConfig.ZenConfig.Dns) + { + if (!Dns.empty()) + { + const asio::error_code Err = zen::utils::ResolveHostname(m_IoContext, Dns, "1337"sv, ZenUrls); + if (Err) + { + ZEN_ERROR("resolve FAILED, reason '{}'", Err.message()); + } + } + } + } + + std::erase_if(ZenUrls, [](const auto& Url) { return Url.empty(); }); + + if (!ZenUrls.empty()) + { + const auto ZenEndpointName = UpstreamConfig.ZenConfig.Name.empty() ? "Zen"sv : UpstreamConfig.ZenConfig.Name; + + std::unique_ptr<zen::UpstreamEndpoint> ZenEndpoint = zen::UpstreamEndpoint::CreateZenEndpoint( + {.Name = ZenEndpointName, + .Urls = ZenUrls, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}); + + m_UpstreamCache->RegisterEndpoint(std::move(ZenEndpoint)); + } + } + + // Jupiter upstream + if (UpstreamConfig.JupiterConfig.Url.empty() == false) + { + std::string_view EndpointName = UpstreamConfig.JupiterConfig.Name.empty() ? "Jupiter"sv : UpstreamConfig.JupiterConfig.Name; + + auto Options = + zen::CloudCacheClientOptions{.Name = EndpointName, + .ServiceUrl = UpstreamConfig.JupiterConfig.Url, + .DdcNamespace = UpstreamConfig.JupiterConfig.DdcNamespace, + .BlobStoreNamespace = UpstreamConfig.JupiterConfig.Namespace, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}; + + auto AuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.JupiterConfig.OAuthUrl, + .OAuthClientId = UpstreamConfig.JupiterConfig.OAuthClientId, + .OAuthClientSecret = UpstreamConfig.JupiterConfig.OAuthClientSecret, + .OpenIdProvider = UpstreamConfig.JupiterConfig.OpenIdProvider, + .AccessToken = UpstreamConfig.JupiterConfig.AccessToken}; + + std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint = + zen::UpstreamEndpoint::CreateJupiterEndpoint(Options, AuthConfig, *m_AuthMgr); + + m_UpstreamCache->RegisterEndpoint(std::move(JupiterEndpoint)); + } + } + + m_StructuredCacheService = + std::make_unique<HttpStructuredCacheService>(*m_CacheStore, *m_CidStore, m_StatsService, m_StatusService, *m_UpstreamCache); + + m_Http->RegisterService(*m_StructuredCacheService); + m_Http->RegisterService(*m_UpstreamService); +} + +#if ZEN_WITH_COMPUTE_SERVICES +void +ZenServer::InitializeCompute(const ZenServerOptions& ServerOptions) +{ + ServerOptions; + const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig; + + // Horde compute upstream + if (UpstreamConfig.HordeConfig.Url.empty() == false && UpstreamConfig.HordeConfig.StorageUrl.empty() == false) + { + ZEN_INFO("instantiating compute service"); + + std::string_view EndpointName = UpstreamConfig.HordeConfig.Name.empty() ? "Horde"sv : UpstreamConfig.HordeConfig.Name; + + auto ComputeOptions = + zen::CloudCacheClientOptions{.Name = EndpointName, + .ServiceUrl = UpstreamConfig.HordeConfig.Url, + .ComputeCluster = UpstreamConfig.HordeConfig.Cluster, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}; + + auto ComputeAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.OAuthUrl, + .OAuthClientId = UpstreamConfig.HordeConfig.OAuthClientId, + .OAuthClientSecret = UpstreamConfig.HordeConfig.OAuthClientSecret, + .OpenIdProvider = UpstreamConfig.HordeConfig.OpenIdProvider, + .AccessToken = UpstreamConfig.HordeConfig.AccessToken}; + + auto StorageOptions = + zen::CloudCacheClientOptions{.Name = EndpointName, + .ServiceUrl = UpstreamConfig.HordeConfig.StorageUrl, + .BlobStoreNamespace = UpstreamConfig.HordeConfig.Namespace, + .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds), + .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)}; + + auto StorageAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.StorageOAuthUrl, + .OAuthClientId = UpstreamConfig.HordeConfig.StorageOAuthClientId, + .OAuthClientSecret = UpstreamConfig.HordeConfig.StorageOAuthClientSecret, + .OpenIdProvider = UpstreamConfig.HordeConfig.StorageOpenIdProvider, + .AccessToken = UpstreamConfig.HordeConfig.StorageAccessToken}; + + m_HttpFunctionService = std::make_unique<zen::HttpFunctionService>(*m_CidStore, + ComputeOptions, + StorageOptions, + ComputeAuthConfig, + StorageAuthConfig, + *m_AuthMgr); + } + else + { + ZEN_INFO("NOT instantiating compute service (missing Horde or Storage config)"); + } +} +#endif // ZEN_WITH_COMPUTE_SERVICES + +//////////////////////////////////////////////////////////////////////////////// + +class ZenEntryPoint +{ +public: + ZenEntryPoint(ZenServerOptions& ServerOptions); + ZenEntryPoint(const ZenEntryPoint&) = delete; + ZenEntryPoint& operator=(const ZenEntryPoint&) = delete; + int Run(); + +private: + ZenServerOptions& m_ServerOptions; + zen::LockFile m_LockFile; +}; + +ZenEntryPoint::ZenEntryPoint(ZenServerOptions& ServerOptions) : m_ServerOptions(ServerOptions) +{ +} + +#if ZEN_USE_SENTRY +static void +SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata) +{ + char LogMessageBuffer[160]; + std::string LogMessage; + const char* MessagePtr = LogMessageBuffer; + + int n = vsnprintf(LogMessageBuffer, sizeof LogMessageBuffer, Message, Args); + + if (n >= int(sizeof LogMessageBuffer)) + { + LogMessage.resize(n + 1); + + n = vsnprintf(LogMessage.data(), LogMessage.size(), Message, Args); + + MessagePtr = LogMessage.c_str(); + } + + switch (Level) + { + case SENTRY_LEVEL_DEBUG: + ConsoleLog().debug("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_INFO: + ConsoleLog().info("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_WARNING: + ConsoleLog().warn("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_ERROR: + ConsoleLog().error("sentry: {}", MessagePtr); + break; + + case SENTRY_LEVEL_FATAL: + ConsoleLog().critical("sentry: {}", MessagePtr); + break; + } +} +#endif + +int +ZenEntryPoint::Run() +{ +#if ZEN_USE_SENTRY + std::string SentryDatabasePath = PathToUtf8(m_ServerOptions.DataDir / ".sentry-native"); + int SentryErrorCode = 0; + if (m_ServerOptions.NoSentry == false) + { + sentry_options_t* SentryOptions = sentry_options_new(); + sentry_options_set_dsn(SentryOptions, "https://[email protected]/5919284"); + if (SentryDatabasePath.starts_with("\\\\?\\")) + { + SentryDatabasePath = SentryDatabasePath.substr(4); + } + sentry_options_set_database_path(SentryOptions, SentryDatabasePath.c_str()); + sentry_options_set_logger(SentryOptions, SentryLogFunction, this); + std::string SentryAttachmentPath = m_ServerOptions.AbsLogFile.string(); + if (SentryAttachmentPath.starts_with("\\\\?\\")) + { + SentryAttachmentPath = SentryAttachmentPath.substr(4); + } + sentry_options_add_attachment(SentryOptions, SentryAttachmentPath.c_str()); + sentry_options_set_release(SentryOptions, ZEN_CFG_VERSION); + // sentry_options_set_debug(SentryOptions, 1); + + SentryErrorCode = sentry_init(SentryOptions); + + auto SentrySink = spdlog::create<utils::sentry_sink>("sentry"); + zen::logging::SetErrorLog(std::move(SentrySink)); + } + + auto _ = zen::MakeGuard([] { + zen::logging::SetErrorLog(std::shared_ptr<spdlog::logger>()); + sentry_close(); + }); +#endif + + auto& ServerOptions = m_ServerOptions; + + try + { + // Mutual exclusion and synchronization + ZenServerState ServerState; + ServerState.Initialize(); + ServerState.Sweep(); + + ZenServerState::ZenServerEntry* Entry = ServerState.Lookup(ServerOptions.BasePort); + + if (Entry) + { + if (ServerOptions.OwnerPid) + { + ConsoleLog().info( + "Looks like there is already a process listening to this port {} (pid: {}), attaching owner pid {} to running instance", + ServerOptions.BasePort, + Entry->Pid, + ServerOptions.OwnerPid); + + Entry->AddSponsorProcess(ServerOptions.OwnerPid); + + std::exit(0); + } + else + { + ConsoleLog().warn("Exiting since there is already a process listening to port {} (pid: {})", + ServerOptions.BasePort, + Entry->Pid); + std::exit(1); + } + } + + std::error_code Ec; + + std::filesystem::path LockFilePath = ServerOptions.DataDir / ".lock"; + + bool IsReady = false; + + auto MakeLockData = [&] { + CbObjectWriter Cbo; + Cbo << "pid" << zen::GetCurrentProcessId() << "data" << PathToUtf8(ServerOptions.DataDir) << "port" << ServerOptions.BasePort + << "session_id" << GetSessionId() << "ready" << IsReady; + return Cbo.Save(); + }; + + m_LockFile.Create(LockFilePath, MakeLockData(), Ec); + + if (Ec) + { + ConsoleLog().warn("ERROR: Unable to grab lock at '{}' (error: '{}')", LockFilePath, Ec.message()); + + std::exit(99); + } + + InitializeLogging(ServerOptions); + +#if ZEN_USE_SENTRY + if (m_ServerOptions.NoSentry == false) + { + if (SentryErrorCode == 0) + { + ZEN_INFO("sentry initialized"); + } + else + { + ZEN_WARN("sentry_init returned failure! (error code: {})", SentryErrorCode); + } + } +#endif + + MaximizeOpenFileCount(); + + ZEN_INFO(ZEN_APP_NAME " - using lock file at '{}'", LockFilePath); + + ZEN_INFO(ZEN_APP_NAME " - starting on port {}, version '{}'", ServerOptions.BasePort, ZEN_CFG_VERSION_BUILD_STRING_FULL); + + Entry = ServerState.Register(ServerOptions.BasePort); + + if (ServerOptions.OwnerPid) + { + Entry->AddSponsorProcess(ServerOptions.OwnerPid); + } + + ZenServer Server; + Server.SetDataRoot(ServerOptions.DataDir); + Server.SetContentRoot(ServerOptions.ContentDir); + Server.SetTestMode(ServerOptions.IsTest); + Server.SetDedicatedMode(ServerOptions.IsDedicated); + + int EffectiveBasePort = Server.Initialize(ServerOptions, Entry); + + Entry->EffectiveListenPort = uint16_t(EffectiveBasePort); + if (EffectiveBasePort != ServerOptions.BasePort) + { + ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort); + ServerOptions.BasePort = EffectiveBasePort; + } + + std::unique_ptr<std::thread> ShutdownThread; + std::unique_ptr<zen::NamedEvent> ShutdownEvent; + + zen::ExtendableStringBuilder<64> ShutdownEventName; + ShutdownEventName << "Zen_" << ServerOptions.BasePort << "_Shutdown"; + ShutdownEvent.reset(new zen::NamedEvent{ShutdownEventName}); + + // Monitor shutdown signals + + ShutdownThread.reset(new std::thread{[&] { + ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}'", ShutdownEventName); + if (ShutdownEvent->Wait()) + { + ZEN_INFO("shutdown signal received"); + Server.RequestExit(0); + } + else + { + ZEN_INFO("shutdown signal wait() failed"); + } + }}); + + // If we have a parent process, establish the mechanisms we need + // to be able to communicate readiness with the parent + + Server.SetIsReadyFunc([&] { + IsReady = true; + + m_LockFile.Update(MakeLockData(), Ec); + + if (!ServerOptions.ChildId.empty()) + { + zen::NamedEvent ParentEvent{ServerOptions.ChildId}; + ParentEvent.Set(); + } + }); + + Server.Run(); + Server.Cleanup(); + + ShutdownEvent->Set(); + ShutdownThread->join(); + } + catch (std::exception& e) + { + SPDLOG_CRITICAL("Caught exception in main: {}", e.what()); + } + + ShutdownLogging(); + + return 0; +} + +} // namespace zen + +//////////////////////////////////////////////////////////////////////////////// + +#if ZEN_PLATFORM_WINDOWS + +class ZenWindowsService : public WindowsService +{ +public: + ZenWindowsService(ZenServerOptions& ServerOptions) : m_EntryPoint(ServerOptions) {} + + ZenWindowsService(const ZenWindowsService&) = delete; + ZenWindowsService& operator=(const ZenWindowsService&) = delete; + + virtual int Run() override; + +private: + zen::ZenEntryPoint m_EntryPoint; +}; + +int +ZenWindowsService::Run() +{ + return m_EntryPoint.Run(); +} + +#endif // ZEN_PLATFORM_WINDOWS + +//////////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS +int +test_main(int argc, char** argv) +{ + zen::zencore_forcelinktests(); + zen::zenhttp_forcelinktests(); + zen::zenstore_forcelinktests(); + zen::z$_forcelink(); + zen::z$service_forcelink(); + + zen::logging::InitializeLogging(); + spdlog::set_level(spdlog::level::debug); + + zen::MaximizeOpenFileCount(); + + return ZEN_RUN_TESTS(argc, argv); +} +#endif + +int +main(int argc, char* argv[]) +{ + using namespace zen; + +#if ZEN_USE_MIMALLOC + mi_version(); +#endif + +#if ZEN_WITH_TESTS + if (argc >= 2) + { + if (argv[1] == "test"sv) + { + return test_main(argc, argv); + } + } +#endif + + try + { + ZenServerOptions ServerOptions; + ParseCliOptions(argc, argv, ServerOptions); + + if (!std::filesystem::exists(ServerOptions.DataDir)) + { + ServerOptions.IsFirstRun = true; + std::filesystem::create_directories(ServerOptions.DataDir); + } + +#if ZEN_WITH_TRACE + if (ServerOptions.TraceHost.size()) + { + TraceInit(ServerOptions.TraceHost.c_str(), TraceType::Network); + } + else if (ServerOptions.TraceFile.size()) + { + TraceInit(ServerOptions.TraceFile.c_str(), TraceType::File); + } + else + { + TraceInit(nullptr, TraceType::None); + } +#endif // ZEN_WITH_TRACE + +#if ZEN_PLATFORM_WINDOWS + if (ServerOptions.InstallService) + { + WindowsService::Install(); + + std::exit(0); + } + + if (ServerOptions.UninstallService) + { + WindowsService::Delete(); + + std::exit(0); + } + + ZenWindowsService App(ServerOptions); + return App.ServiceMain(); +#else + if (ServerOptions.InstallService || ServerOptions.UninstallService) + { + throw std::runtime_error("Service mode is not supported on this platform"); + } + + ZenEntryPoint App(ServerOptions); + return App.Run(); +#endif // ZEN_PLATFORM_WINDOWS + } + catch (std::exception& Ex) + { + fprintf(stderr, "ERROR: Caught exception in main: '%s'", Ex.what()); + + return 1; + } +} diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc new file mode 100644 index 000000000..6d31e2c6e --- /dev/null +++ b/src/zenserver/zenserver.rc @@ -0,0 +1,105 @@ +// Microsoft Visual C++ generated resource script. +// +#include "resource.h" + +#include "zencore/config.h" + +#define APSTUDIO_READONLY_SYMBOLS +///////////////////////////////////////////////////////////////////////////// +// +// Generated from the TEXTINCLUDE 2 resource. +// +#include "winres.h" + +///////////////////////////////////////////////////////////////////////////// +#undef APSTUDIO_READONLY_SYMBOLS + +///////////////////////////////////////////////////////////////////////////// +// English (United States) resources + +#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENU) +LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US +#pragma code_page(1252) + +///////////////////////////////////////////////////////////////////////////// +// +// Icon +// + +// Icon with lowest ID value placed first to ensure application icon +// remains consistent on all systems. +IDI_ICON1 ICON "..\\UnrealEngine.ico" + +#endif // English (United States) resources +///////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////// +// English (United Kingdom) resources + +#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENG) +LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_UK +#pragma code_page(1252) + +#ifdef APSTUDIO_INVOKED +///////////////////////////////////////////////////////////////////////////// +// +// TEXTINCLUDE +// + +1 TEXTINCLUDE +BEGIN + "resource.h\0" +END + +2 TEXTINCLUDE +BEGIN + "#include ""winres.h""\r\n" + "\0" +END + +3 TEXTINCLUDE +BEGIN + "\r\n" + "\0" +END + +#endif // APSTUDIO_INVOKED + +#endif // English (United Kingdom) resources +///////////////////////////////////////////////////////////////////////////// + + + +#ifndef APSTUDIO_INVOKED +///////////////////////////////////////////////////////////////////////////// +// +// Generated from the TEXTINCLUDE 3 resource. +// + + +///////////////////////////////////////////////////////////////////////////// +#endif // not APSTUDIO_INVOKED + +VS_VERSION_INFO VERSIONINFO +FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 +PRODUCTVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0 +{ + BLOCK "StringFileInfo" + { + BLOCK "040904b0" + { + VALUE "CompanyName", "Epic Games Inc\0" + VALUE "FileDescription", "Local Storage Service for Unreal Engine\0" + VALUE "FileVersion", ZEN_CFG_VERSION "\0" + VALUE "LegalCopyright", "Copyright Epic Games Inc. All Rights Reserved\0" + VALUE "OriginalFilename", "zenserver.exe\0" + VALUE "ProductName", "Zen Storage Server\0" + VALUE "ProductVersion", ZEN_CFG_VERSION_BUILD_STRING_FULL "\0" + } + } + BLOCK "VarFileInfo" + { + VALUE "Translation", 0x409, 1200 + } +} diff --git a/src/zenstore-test/xmake.lua b/src/zenstore-test/xmake.lua new file mode 100644 index 000000000..5dbcafa3c --- /dev/null +++ b/src/zenstore-test/xmake.lua @@ -0,0 +1,8 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zenstore-test") + set_kind("binary") + add_headerfiles("**.h") + add_files("*.cpp") + add_deps("zenstore", "zencore") + add_packages("vcpkg::doctest") diff --git a/src/zenstore-test/zenstore-test.cpp b/src/zenstore-test/zenstore-test.cpp new file mode 100644 index 000000000..00c1136b6 --- /dev/null +++ b/src/zenstore-test/zenstore-test.cpp @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zencore/filesystem.h> +#include <zencore/logging.h> +#include <zencore/zencore.h> +#include <zenstore/zenstore.h> + +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC +# include <sys/time.h> +# include <sys/resource.h> +# include <zencore/except.h> +#endif + +#if ZEN_WITH_TESTS +# define ZEN_TEST_WITH_RUNNER 1 +# include <zencore/testing.h> +#endif + +int +main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) +{ +#if ZEN_WITH_TESTS + zen::zenstore_forcelinktests(); + + zen::logging::InitializeLogging(); + zen::MaximizeOpenFileCount(); + + return ZEN_RUN_TESTS(argc, argv); +#else + return 0; +#endif +} diff --git a/src/zenstore/blockstore.cpp b/src/zenstore/blockstore.cpp new file mode 100644 index 000000000..5dfa10c91 --- /dev/null +++ b/src/zenstore/blockstore.cpp @@ -0,0 +1,1312 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenstore/blockstore.h> + +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/timer.h> + +#include <algorithm> + +#if ZEN_WITH_TESTS +# include <zencore/compactbinarybuilder.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zencore/workthreadpool.h> +# include <random> +#endif + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +BlockStoreFile::BlockStoreFile(const std::filesystem::path& BlockPath) : m_Path(BlockPath) +{ +} + +BlockStoreFile::~BlockStoreFile() +{ + m_IoBuffer = IoBuffer(); + m_File.Detach(); +} + +const std::filesystem::path& +BlockStoreFile::GetPath() const +{ + return m_Path; +} + +void +BlockStoreFile::Open() +{ + m_File.Open(m_Path, BasicFile::Mode::kDelete); + void* FileHandle = m_File.Handle(); + m_IoBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, m_File.FileSize()); +} + +void +BlockStoreFile::Create(uint64_t InitialSize) +{ + auto ParentPath = m_Path.parent_path(); + if (!std::filesystem::is_directory(ParentPath)) + { + CreateDirectories(ParentPath); + } + + m_File.Open(m_Path, BasicFile::Mode::kTruncateDelete); + void* FileHandle = m_File.Handle(); + + // We map our m_IoBuffer beyond the file size as we will grow it over time and want + // to be able to create sub-buffers of all the written range later + m_IoBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, InitialSize); +} + +uint64_t +BlockStoreFile::FileSize() +{ + return m_File.FileSize(); +} + +void +BlockStoreFile::MarkAsDeleteOnClose() +{ + m_IoBuffer.MarkAsDeleteOnClose(); +} + +IoBuffer +BlockStoreFile::GetChunk(uint64_t Offset, uint64_t Size) +{ + return IoBuffer(m_IoBuffer, Offset, Size); +} + +void +BlockStoreFile::Read(void* Data, uint64_t Size, uint64_t FileOffset) +{ + m_File.Read(Data, Size, FileOffset); +} + +void +BlockStoreFile::Write(const void* Data, uint64_t Size, uint64_t FileOffset) +{ + m_File.Write(Data, Size, FileOffset); +} + +void +BlockStoreFile::Flush() +{ + m_File.Flush(); +} + +BasicFile& +BlockStoreFile::GetBasicFile() +{ + return m_File; +} + +void +BlockStoreFile::StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun) +{ + m_File.StreamByteRange(FileOffset, Size, std::move(ChunkFun)); +} + +constexpr uint64_t ScrubSmallChunkWindowSize = 4 * 1024 * 1024; + +void +BlockStore::Initialize(const std::filesystem::path& BlocksBasePath, + uint64_t MaxBlockSize, + uint64_t MaxBlockCount, + const std::vector<BlockStoreLocation>& KnownLocations) +{ + ZEN_ASSERT(MaxBlockSize > 0); + ZEN_ASSERT(MaxBlockCount > 0); + ZEN_ASSERT(IsPow2(MaxBlockCount)); + + m_TotalSize = 0; + m_BlocksBasePath = BlocksBasePath; + m_MaxBlockSize = MaxBlockSize; + + m_ChunkBlocks.clear(); + + std::unordered_set<uint32_t> KnownBlocks; + for (const auto& Entry : KnownLocations) + { + KnownBlocks.insert(Entry.BlockIndex); + } + + if (std::filesystem::is_directory(m_BlocksBasePath)) + { + std::vector<std::filesystem::path> FoldersToScan; + FoldersToScan.push_back(m_BlocksBasePath); + size_t FolderOffset = 0; + while (FolderOffset < FoldersToScan.size()) + { + for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(FoldersToScan[FolderOffset])) + { + if (Entry.is_directory()) + { + FoldersToScan.push_back(Entry.path()); + continue; + } + if (Entry.is_regular_file()) + { + const std::filesystem::path Path = Entry.path(); + if (Path.extension() != GetBlockFileExtension()) + { + continue; + } + std::string FileName = PathToUtf8(Path.stem()); + uint32_t BlockIndex; + bool OK = ParseHexNumber(FileName, BlockIndex); + if (!OK) + { + continue; + } + if (!KnownBlocks.contains(BlockIndex)) + { + // Log removing unreferenced block + // Clear out unused blocks + ZEN_DEBUG("removing unused block at '{}'", Path); + std::error_code Ec; + std::filesystem::remove(Path, Ec); + if (Ec) + { + ZEN_WARN("Failed to delete file '{}' reason: '{}'", Path, Ec.message()); + } + continue; + } + Ref<BlockStoreFile> BlockFile{new BlockStoreFile(Path)}; + BlockFile->Open(); + m_TotalSize.fetch_add(BlockFile->FileSize(), std::memory_order::relaxed); + m_ChunkBlocks[BlockIndex] = BlockFile; + } + } + ++FolderOffset; + } + } + else + { + CreateDirectories(m_BlocksBasePath); + } +} + +void +BlockStore::Close() +{ + RwLock::ExclusiveLockScope InsertLock(m_InsertLock); + m_WriteBlock = nullptr; + m_CurrentInsertOffset = 0; + m_WriteBlockIndex = 0; + + m_ChunkBlocks.clear(); + m_BlocksBasePath.clear(); +} + +void +BlockStore::WriteChunk(const void* Data, uint64_t Size, uint64_t Alignment, const WriteChunkCallback& Callback) +{ + ZEN_ASSERT(Data != nullptr); + ZEN_ASSERT(Size > 0u); + ZEN_ASSERT(Size <= m_MaxBlockSize); + ZEN_ASSERT(Alignment > 0u); + + RwLock::ExclusiveLockScope InsertLock(m_InsertLock); + + uint32_t WriteBlockIndex = m_WriteBlockIndex.load(std::memory_order_acquire); + bool IsWriting = !!m_WriteBlock; + if (!IsWriting || (m_CurrentInsertOffset + Size) > m_MaxBlockSize) + { + if (m_WriteBlock) + { + m_WriteBlock = nullptr; + } + + if (m_ChunkBlocks.size() == m_MaxBlockCount) + { + throw std::runtime_error(fmt::format("unable to allocate a new block in '{}'", m_BlocksBasePath)); + } + + WriteBlockIndex += IsWriting ? 1 : 0; + while (m_ChunkBlocks.contains(WriteBlockIndex)) + { + WriteBlockIndex = (WriteBlockIndex + 1) & (m_MaxBlockCount - 1); + } + + std::filesystem::path BlockPath = GetBlockPath(m_BlocksBasePath, WriteBlockIndex); + + Ref<BlockStoreFile> NewBlockFile(new BlockStoreFile(BlockPath)); + NewBlockFile->Create(m_MaxBlockSize); + + m_ChunkBlocks[WriteBlockIndex] = NewBlockFile; + m_WriteBlock = NewBlockFile; + m_WriteBlockIndex.store(WriteBlockIndex, std::memory_order_release); + m_CurrentInsertOffset = 0; + } + uint64_t InsertOffset = m_CurrentInsertOffset; + m_CurrentInsertOffset = RoundUp(InsertOffset + Size, Alignment); + uint64_t AlignedWriteSize = m_CurrentInsertOffset - InsertOffset; + Ref<BlockStoreFile> WriteBlock = m_WriteBlock; + m_ActiveWriteBlocks.push_back(WriteBlockIndex); + InsertLock.ReleaseNow(); + + WriteBlock->Write(Data, Size, InsertOffset); + m_TotalSize.fetch_add(AlignedWriteSize, std::memory_order::relaxed); + + Callback({.BlockIndex = WriteBlockIndex, .Offset = InsertOffset, .Size = Size}); + + { + RwLock::ExclusiveLockScope _(m_InsertLock); + m_ActiveWriteBlocks.erase(std::find(m_ActiveWriteBlocks.begin(), m_ActiveWriteBlocks.end(), WriteBlockIndex)); + } +} + +BlockStore::ReclaimSnapshotState +BlockStore::GetReclaimSnapshotState() +{ + ReclaimSnapshotState State; + RwLock::SharedLockScope _(m_InsertLock); + for (uint32_t BlockIndex : m_ActiveWriteBlocks) + { + State.m_ActiveWriteBlocks.insert(BlockIndex); + } + if (m_WriteBlock) + { + State.m_ActiveWriteBlocks.insert(m_WriteBlockIndex); + } + State.BlockCount = m_ChunkBlocks.size(); + return State; +} + +IoBuffer +BlockStore::TryGetChunk(const BlockStoreLocation& Location) const +{ + RwLock::SharedLockScope InsertLock(m_InsertLock); + if (auto BlockIt = m_ChunkBlocks.find(Location.BlockIndex); BlockIt != m_ChunkBlocks.end()) + { + if (const Ref<BlockStoreFile>& Block = BlockIt->second; Block) + { + return Block->GetChunk(Location.Offset, Location.Size); + } + } + return IoBuffer(); +} + +void +BlockStore::Flush() +{ + RwLock::ExclusiveLockScope _(m_InsertLock); + if (m_CurrentInsertOffset > 0) + { + uint32_t WriteBlockIndex = m_WriteBlockIndex.load(std::memory_order_acquire); + WriteBlockIndex = (WriteBlockIndex + 1) & (m_MaxBlockCount - 1); + m_WriteBlock = nullptr; + m_WriteBlockIndex.store(WriteBlockIndex, std::memory_order_release); + m_CurrentInsertOffset = 0; + } +} + +void +BlockStore::ReclaimSpace(const ReclaimSnapshotState& Snapshot, + const std::vector<BlockStoreLocation>& ChunkLocations, + const ChunkIndexArray& KeepChunkIndexes, + uint64_t PayloadAlignment, + bool DryRun, + const ReclaimCallback& ChangeCallback, + const ClaimDiskReserveCallback& DiskReserveCallback) +{ + if (ChunkLocations.empty()) + { + return; + } + uint64_t WriteBlockTimeUs = 0; + uint64_t WriteBlockLongestTimeUs = 0; + uint64_t ReadBlockTimeUs = 0; + uint64_t ReadBlockLongestTimeUs = 0; + uint64_t TotalChunkCount = ChunkLocations.size(); + uint64_t DeletedSize = 0; + uint64_t OldTotalSize = 0; + uint64_t NewTotalSize = 0; + + uint64_t MovedCount = 0; + uint64_t DeletedCount = 0; + + Stopwatch TotalTimer; + const auto _ = MakeGuard([&] { + ZEN_DEBUG( + "reclaim space for '{}' DONE after {}, write lock: {} ({}), read lock: {} ({}), collected {} bytes, deleted {} and moved " + "{} " + "of {} " + "chunks ({}).", + m_BlocksBasePath, + NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()), + NiceLatencyNs(WriteBlockTimeUs), + NiceLatencyNs(WriteBlockLongestTimeUs), + NiceLatencyNs(ReadBlockTimeUs), + NiceLatencyNs(ReadBlockLongestTimeUs), + NiceBytes(DeletedSize), + DeletedCount, + MovedCount, + TotalChunkCount, + NiceBytes(OldTotalSize)); + }); + + size_t BlockCount = Snapshot.BlockCount; + if (BlockCount == 0) + { + ZEN_DEBUG("garbage collect for '{}' SKIPPED, no blocks to process", m_BlocksBasePath); + return; + } + + std::unordered_set<size_t> KeepChunkMap; + KeepChunkMap.reserve(KeepChunkIndexes.size()); + for (size_t KeepChunkIndex : KeepChunkIndexes) + { + KeepChunkMap.insert(KeepChunkIndex); + } + + std::unordered_map<uint32_t, size_t> BlockIndexToChunkMapIndex; + std::vector<ChunkIndexArray> BlockKeepChunks; + std::vector<ChunkIndexArray> BlockDeleteChunks; + + BlockIndexToChunkMapIndex.reserve(BlockCount); + BlockKeepChunks.reserve(BlockCount); + BlockDeleteChunks.reserve(BlockCount); + size_t GuesstimateCountPerBlock = TotalChunkCount / BlockCount / 2; + + size_t DeleteCount = 0; + for (size_t Index = 0; Index < TotalChunkCount; ++Index) + { + const BlockStoreLocation& Location = ChunkLocations[Index]; + OldTotalSize += Location.Size; + if (Snapshot.m_ActiveWriteBlocks.contains(Location.BlockIndex)) + { + continue; + } + + auto BlockIndexPtr = BlockIndexToChunkMapIndex.find(Location.BlockIndex); + size_t ChunkMapIndex = 0; + if (BlockIndexPtr == BlockIndexToChunkMapIndex.end()) + { + ChunkMapIndex = BlockKeepChunks.size(); + BlockIndexToChunkMapIndex[Location.BlockIndex] = ChunkMapIndex; + BlockKeepChunks.resize(ChunkMapIndex + 1); + BlockKeepChunks.back().reserve(GuesstimateCountPerBlock); + BlockDeleteChunks.resize(ChunkMapIndex + 1); + BlockDeleteChunks.back().reserve(GuesstimateCountPerBlock); + } + else + { + ChunkMapIndex = BlockIndexPtr->second; + } + + if (KeepChunkMap.contains(Index)) + { + ChunkIndexArray& IndexMap = BlockKeepChunks[ChunkMapIndex]; + IndexMap.push_back(Index); + NewTotalSize += Location.Size; + continue; + } + ChunkIndexArray& IndexMap = BlockDeleteChunks[ChunkMapIndex]; + IndexMap.push_back(Index); + DeleteCount++; + } + + std::unordered_set<uint32_t> BlocksToReWrite; + BlocksToReWrite.reserve(BlockIndexToChunkMapIndex.size()); + for (const auto& Entry : BlockIndexToChunkMapIndex) + { + uint32_t BlockIndex = Entry.first; + size_t ChunkMapIndex = Entry.second; + const ChunkIndexArray& ChunkMap = BlockDeleteChunks[ChunkMapIndex]; + if (ChunkMap.empty()) + { + continue; + } + BlocksToReWrite.insert(BlockIndex); + } + + if (DryRun) + { + ZEN_DEBUG("garbage collect for '{}' DISABLED, found {} {} chunks of total {} {}", + m_BlocksBasePath, + DeleteCount, + NiceBytes(OldTotalSize - NewTotalSize), + TotalChunkCount, + OldTotalSize); + return; + } + + Ref<BlockStoreFile> NewBlockFile; + try + { + uint64_t WriteOffset = 0; + uint32_t NewBlockIndex = 0; + for (uint32_t BlockIndex : BlocksToReWrite) + { + const size_t ChunkMapIndex = BlockIndexToChunkMapIndex[BlockIndex]; + + Ref<BlockStoreFile> OldBlockFile; + { + RwLock::SharedLockScope _i(m_InsertLock); + Stopwatch Timer; + const auto __ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + OldBlockFile = m_ChunkBlocks[BlockIndex]; + } + + if (!OldBlockFile) + { + // If the block file pointed to does not exist, move them all to deleted list + BlockDeleteChunks[ChunkMapIndex].insert(BlockDeleteChunks[ChunkMapIndex].end(), + BlockKeepChunks[ChunkMapIndex].begin(), + BlockKeepChunks[ChunkMapIndex].end()); + BlockKeepChunks[ChunkMapIndex].clear(); + } + + const ChunkIndexArray& KeepMap = BlockKeepChunks[ChunkMapIndex]; + if (KeepMap.empty()) + { + const ChunkIndexArray& DeleteMap = BlockDeleteChunks[ChunkMapIndex]; + for (size_t DeleteIndex : DeleteMap) + { + DeletedSize += ChunkLocations[DeleteIndex].Size; + } + ChangeCallback({}, DeleteMap); + DeletedCount += DeleteMap.size(); + { + RwLock::ExclusiveLockScope _i(m_InsertLock); + Stopwatch Timer; + const auto __ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + if (OldBlockFile) + { + m_ChunkBlocks[BlockIndex] = nullptr; + ZEN_DEBUG("marking cas block store file '{}' for delete, block #{}", OldBlockFile->GetPath(), BlockIndex); + m_TotalSize.fetch_sub(OldBlockFile->FileSize(), std::memory_order::relaxed); + OldBlockFile->MarkAsDeleteOnClose(); + } + } + continue; + } + + ZEN_ASSERT(OldBlockFile); + + MovedChunksArray MovedChunks; + std::vector<uint8_t> Chunk; + for (const size_t& ChunkIndex : KeepMap) + { + const BlockStoreLocation ChunkLocation = ChunkLocations[ChunkIndex]; + Chunk.resize(ChunkLocation.Size); + OldBlockFile->Read(Chunk.data(), Chunk.size(), ChunkLocation.Offset); + + if (!NewBlockFile || (WriteOffset + Chunk.size() > m_MaxBlockSize)) + { + uint32_t NextBlockIndex = m_WriteBlockIndex.load(std::memory_order_relaxed); + + if (NewBlockFile) + { + NewBlockFile->Flush(); + NewBlockFile = nullptr; + } + { + ChangeCallback(MovedChunks, {}); + MovedCount += KeepMap.size(); + MovedChunks.clear(); + RwLock::ExclusiveLockScope __(m_InsertLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + if (m_ChunkBlocks.size() == m_MaxBlockCount) + { + ZEN_ERROR("unable to allocate a new block in '{}', count limit {} exeeded", + m_BlocksBasePath, + static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1); + return; + } + while (m_ChunkBlocks.contains(NextBlockIndex)) + { + NextBlockIndex = (NextBlockIndex + 1) & (m_MaxBlockCount - 1); + } + std::filesystem::path NewBlockPath = GetBlockPath(m_BlocksBasePath, NextBlockIndex); + NewBlockFile = new BlockStoreFile(NewBlockPath); + m_ChunkBlocks[NextBlockIndex] = NewBlockFile; + } + + std::error_code Error; + DiskSpace Space = DiskSpaceInfo(m_BlocksBasePath, Error); + if (Error) + { + ZEN_ERROR("get disk space in '{}' FAILED, reason: '{}'", m_BlocksBasePath, Error.message()); + return; + } + if (Space.Free < m_MaxBlockSize) + { + uint64_t ReclaimedSpace = DiskReserveCallback(); + if (Space.Free + ReclaimedSpace < m_MaxBlockSize) + { + ZEN_WARN("garbage collect for '{}' FAILED, required disk space {}, free {}", + m_BlocksBasePath, + m_MaxBlockSize, + NiceBytes(Space.Free + ReclaimedSpace)); + RwLock::ExclusiveLockScope _l(m_InsertLock); + Stopwatch Timer; + const auto __ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + m_ChunkBlocks.erase(NextBlockIndex); + return; + } + + ZEN_INFO("using gc reserve for '{}', reclaimed {}, disk free {}", + m_BlocksBasePath, + ReclaimedSpace, + NiceBytes(Space.Free + ReclaimedSpace)); + } + NewBlockFile->Create(m_MaxBlockSize); + NewBlockIndex = NextBlockIndex; + WriteOffset = 0; + } + + NewBlockFile->Write(Chunk.data(), Chunk.size(), WriteOffset); + MovedChunks.push_back({ChunkIndex, {.BlockIndex = NewBlockIndex, .Offset = WriteOffset, .Size = Chunk.size()}}); + uint64_t OldOffset = WriteOffset; + WriteOffset = RoundUp(WriteOffset + Chunk.size(), PayloadAlignment); + m_TotalSize.fetch_add(WriteOffset - OldOffset, std::memory_order::relaxed); + } + Chunk.clear(); + if (NewBlockFile) + { + NewBlockFile->Flush(); + NewBlockFile = nullptr; + } + + const ChunkIndexArray& DeleteMap = BlockDeleteChunks[ChunkMapIndex]; + for (size_t DeleteIndex : DeleteMap) + { + DeletedSize += ChunkLocations[DeleteIndex].Size; + } + + ChangeCallback(MovedChunks, DeleteMap); + MovedCount += KeepMap.size(); + DeletedCount += DeleteMap.size(); + MovedChunks.clear(); + { + RwLock::ExclusiveLockScope __(m_InsertLock); + Stopwatch Timer; + const auto ___ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + m_ChunkBlocks[BlockIndex] = nullptr; + ZEN_DEBUG("marking cas block store file '{}' for delete, block #{}", OldBlockFile->GetPath(), BlockIndex); + m_TotalSize.fetch_sub(OldBlockFile->FileSize(), std::memory_order::relaxed); + OldBlockFile->MarkAsDeleteOnClose(); + } + } + } + catch (std::exception& ex) + { + ZEN_ERROR("reclaiming space for '{}' failed with: '{}'", m_BlocksBasePath, ex.what()); + if (NewBlockFile) + { + ZEN_DEBUG("dropping incomplete cas block store file '{}'", NewBlockFile->GetPath()); + m_TotalSize.fetch_sub(NewBlockFile->FileSize(), std::memory_order::relaxed); + NewBlockFile->MarkAsDeleteOnClose(); + } + } +} + +void +BlockStore::IterateChunks(const std::vector<BlockStoreLocation>& ChunkLocations, + const IterateChunksSmallSizeCallback& SmallSizeCallback, + const IterateChunksLargeSizeCallback& LargeSizeCallback) +{ + std::vector<size_t> LocationIndexes; + LocationIndexes.reserve(ChunkLocations.size()); + for (size_t ChunkIndex = 0; ChunkIndex < ChunkLocations.size(); ++ChunkIndex) + { + LocationIndexes.push_back(ChunkIndex); + } + std::sort(LocationIndexes.begin(), LocationIndexes.end(), [&](size_t IndexA, size_t IndexB) -> bool { + const BlockStoreLocation& LocationA = ChunkLocations[IndexA]; + const BlockStoreLocation& LocationB = ChunkLocations[IndexB]; + if (LocationA.BlockIndex < LocationB.BlockIndex) + { + return true; + } + else if (LocationA.BlockIndex > LocationB.BlockIndex) + { + return false; + } + return LocationA.Offset < LocationB.Offset; + }); + + IoBuffer ReadBuffer{ScrubSmallChunkWindowSize}; + void* BufferBase = ReadBuffer.MutableData(); + + RwLock::SharedLockScope _(m_InsertLock); + + auto GetNextRange = [&](size_t StartIndexOffset) { + size_t ChunkCount = 0; + size_t StartIndex = LocationIndexes[StartIndexOffset]; + const BlockStoreLocation& StartLocation = ChunkLocations[StartIndex]; + uint64_t StartOffset = StartLocation.Offset; + while (StartIndexOffset + ChunkCount < LocationIndexes.size()) + { + size_t NextIndex = LocationIndexes[StartIndexOffset + ChunkCount]; + const BlockStoreLocation& Location = ChunkLocations[NextIndex]; + if (Location.BlockIndex != StartLocation.BlockIndex) + { + break; + } + if ((Location.Offset + Location.Size) - StartOffset > ScrubSmallChunkWindowSize) + { + break; + } + ++ChunkCount; + } + return ChunkCount; + }; + + size_t LocationIndexOffset = 0; + while (LocationIndexOffset < LocationIndexes.size()) + { + size_t ChunkIndex = LocationIndexes[LocationIndexOffset]; + const BlockStoreLocation& FirstLocation = ChunkLocations[ChunkIndex]; + + const Ref<BlockStoreFile>& BlockFile = m_ChunkBlocks[FirstLocation.BlockIndex]; + if (!BlockFile) + { + while (ChunkLocations[ChunkIndex].BlockIndex == FirstLocation.BlockIndex) + { + SmallSizeCallback(ChunkIndex, nullptr, 0); + LocationIndexOffset++; + if (LocationIndexOffset == LocationIndexes.size()) + { + break; + } + ChunkIndex = LocationIndexes[LocationIndexOffset]; + } + continue; + } + size_t BlockSize = BlockFile->FileSize(); + size_t RangeCount = GetNextRange(LocationIndexOffset); + if (RangeCount > 0) + { + size_t LastChunkIndex = LocationIndexes[LocationIndexOffset + RangeCount - 1]; + const BlockStoreLocation& LastLocation = ChunkLocations[LastChunkIndex]; + uint64_t Size = LastLocation.Offset + LastLocation.Size - FirstLocation.Offset; + BlockFile->Read(BufferBase, Size, FirstLocation.Offset); + for (size_t RangeIndex = 0; RangeIndex < RangeCount; ++RangeIndex) + { + size_t NextChunkIndex = LocationIndexes[LocationIndexOffset + RangeIndex]; + const BlockStoreLocation& ChunkLocation = ChunkLocations[NextChunkIndex]; + if (ChunkLocation.Size == 0 || (ChunkLocation.Offset + ChunkLocation.Size > BlockSize)) + { + SmallSizeCallback(NextChunkIndex, nullptr, 0); + continue; + } + void* BufferPtr = &((char*)BufferBase)[ChunkLocation.Offset - FirstLocation.Offset]; + SmallSizeCallback(NextChunkIndex, BufferPtr, ChunkLocation.Size); + } + LocationIndexOffset += RangeCount; + continue; + } + if (FirstLocation.Size == 0 || (FirstLocation.Offset + FirstLocation.Size > BlockSize)) + { + SmallSizeCallback(ChunkIndex, nullptr, 0); + LocationIndexOffset++; + continue; + } + LargeSizeCallback(ChunkIndex, *BlockFile.Get(), FirstLocation.Offset, FirstLocation.Size); + LocationIndexOffset++; + } +} + +const char* +BlockStore::GetBlockFileExtension() +{ + return ".ucas"; +} + +std::filesystem::path +BlockStore::GetBlockPath(const std::filesystem::path& BlocksBasePath, const uint32_t BlockIndex) +{ + ExtendablePathBuilder<256> Path; + + char BlockHexString[9]; + ToHexNumber(BlockIndex, BlockHexString); + + Path.Append(BlocksBasePath); + Path.AppendSeparator(); + Path.AppendAsciiRange(BlockHexString, BlockHexString + 4); + Path.AppendSeparator(); + Path.Append(BlockHexString); + Path.Append(GetBlockFileExtension()); + return Path.ToPath(); +} + +#if ZEN_WITH_TESTS + +TEST_CASE("blockstore.blockstoredisklocation") +{ + BlockStoreLocation Zero = BlockStoreLocation{.BlockIndex = 0, .Offset = 0, .Size = 0}; + CHECK(Zero == BlockStoreDiskLocation(Zero, 4).Get(4)); + + BlockStoreLocation MaxBlockIndex = BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex, .Offset = 0, .Size = 0}; + CHECK(MaxBlockIndex == BlockStoreDiskLocation(MaxBlockIndex, 4).Get(4)); + + BlockStoreLocation MaxOffset = BlockStoreLocation{.BlockIndex = 0, .Offset = BlockStoreDiskLocation::MaxOffset * 4, .Size = 0}; + CHECK(MaxOffset == BlockStoreDiskLocation(MaxOffset, 4).Get(4)); + + BlockStoreLocation MaxSize = BlockStoreLocation{.BlockIndex = 0, .Offset = 0, .Size = std::numeric_limits<uint32_t>::max()}; + CHECK(MaxSize == BlockStoreDiskLocation(MaxSize, 4).Get(4)); + + BlockStoreLocation MaxBlockIndexAndOffset = + BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex, .Offset = BlockStoreDiskLocation::MaxOffset * 4, .Size = 0}; + CHECK(MaxBlockIndexAndOffset == BlockStoreDiskLocation(MaxBlockIndexAndOffset, 4).Get(4)); + + BlockStoreLocation MaxAll = BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex, + .Offset = BlockStoreDiskLocation::MaxOffset * 4, + .Size = std::numeric_limits<uint32_t>::max()}; + CHECK(MaxAll == BlockStoreDiskLocation(MaxAll, 4).Get(4)); + + BlockStoreLocation MaxAll4096 = BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex, + .Offset = BlockStoreDiskLocation::MaxOffset * 4096, + .Size = std::numeric_limits<uint32_t>::max()}; + CHECK(MaxAll4096 == BlockStoreDiskLocation(MaxAll4096, 4096).Get(4096)); + + BlockStoreLocation Middle = BlockStoreLocation{.BlockIndex = (BlockStoreDiskLocation::MaxBlockIndex) / 2, + .Offset = ((BlockStoreDiskLocation::MaxOffset) / 2) * 4, + .Size = std::numeric_limits<uint32_t>::max() / 2}; + CHECK(Middle == BlockStoreDiskLocation(Middle, 4).Get(4)); +} + +TEST_CASE("blockstore.blockfile") +{ + ScopedTemporaryDirectory TempDir; + auto RootDirectory = TempDir.Path() / "blocks"; + CreateDirectories(RootDirectory); + + { + BlockStoreFile File1(RootDirectory / "1"); + File1.Create(16384); + CHECK(File1.FileSize() == 0); + File1.Write("data", 5, 0); + IoBuffer DataChunk = File1.GetChunk(0, 5); + File1.Write("boop", 5, 5); + IoBuffer BoopChunk = File1.GetChunk(5, 5); + const char* Data = static_cast<const char*>(DataChunk.GetData()); + CHECK(std::string(Data) == "data"); + const char* Boop = static_cast<const char*>(BoopChunk.GetData()); + CHECK(std::string(Boop) == "boop"); + File1.Flush(); + CHECK(File1.FileSize() == 10); + } + { + BlockStoreFile File1(RootDirectory / "1"); + File1.Open(); + + char DataRaw[5]; + File1.Read(DataRaw, 5, 0); + CHECK(std::string(DataRaw) == "data"); + IoBuffer DataChunk = File1.GetChunk(0, 5); + + char BoopRaw[5]; + File1.Read(BoopRaw, 5, 5); + CHECK(std::string(BoopRaw) == "boop"); + + IoBuffer BoopChunk = File1.GetChunk(5, 5); + const char* Data = static_cast<const char*>(DataChunk.GetData()); + CHECK(std::string(Data) == "data"); + const char* Boop = static_cast<const char*>(BoopChunk.GetData()); + CHECK(std::string(Boop) == "boop"); + } + + { + IoBuffer DataChunk; + IoBuffer BoopChunk; + + { + BlockStoreFile File1(RootDirectory / "1"); + File1.Open(); + DataChunk = File1.GetChunk(0, 5); + BoopChunk = File1.GetChunk(5, 5); + } + + CHECK(std::filesystem::exists(RootDirectory / "1")); + + const char* Data = static_cast<const char*>(DataChunk.GetData()); + CHECK(std::string(Data) == "data"); + const char* Boop = static_cast<const char*>(BoopChunk.GetData()); + CHECK(std::string(Boop) == "boop"); + } + CHECK(std::filesystem::exists(RootDirectory / "1")); + + { + IoBuffer DataChunk; + IoBuffer BoopChunk; + + { + BlockStoreFile File1(RootDirectory / "1"); + File1.Open(); + File1.MarkAsDeleteOnClose(); + DataChunk = File1.GetChunk(0, 5); + BoopChunk = File1.GetChunk(5, 5); + } + + const char* Data = static_cast<const char*>(DataChunk.GetData()); + CHECK(std::string(Data) == "data"); + const char* Boop = static_cast<const char*>(BoopChunk.GetData()); + CHECK(std::string(Boop) == "boop"); + } + CHECK(!std::filesystem::exists(RootDirectory / "1")); +} + +namespace blockstore::impl { + BlockStoreLocation WriteStringAsChunk(BlockStore& Store, std::string_view String, size_t PayloadAlignment) + { + BlockStoreLocation Location; + Store.WriteChunk(String.data(), String.length(), PayloadAlignment, [&](const BlockStoreLocation& L) { Location = L; }); + CHECK(Location.Size == String.length()); + return Location; + }; + + std::string ReadChunkAsString(BlockStore& Store, const BlockStoreLocation& Location) + { + IoBuffer ChunkData = Store.TryGetChunk(Location); + if (!ChunkData) + { + return ""; + } + std::string AsString((const char*)ChunkData.Data(), ChunkData.Size()); + return AsString; + }; + + std::vector<std::filesystem::path> GetDirectoryContent(std::filesystem::path RootDir, bool Files, bool Directories) + { + DirectoryContent DirectoryContent; + GetDirectoryContent(RootDir, + DirectoryContent::RecursiveFlag | (Files ? DirectoryContent::IncludeFilesFlag : 0) | + (Directories ? DirectoryContent::IncludeDirsFlag : 0), + DirectoryContent); + std::vector<std::filesystem::path> Result; + Result.insert(Result.end(), DirectoryContent.Directories.begin(), DirectoryContent.Directories.end()); + Result.insert(Result.end(), DirectoryContent.Files.begin(), DirectoryContent.Files.end()); + return Result; + }; + + static IoBuffer CreateChunk(uint64_t Size) + { + static std::random_device rd; + static std::mt19937 g(rd()); + + std::vector<uint8_t> Values; + Values.resize(Size); + for (size_t Idx = 0; Idx < Size; ++Idx) + { + Values[Idx] = static_cast<uint8_t>(Idx); + } + std::shuffle(Values.begin(), Values.end(), g); + + return IoBufferBuilder::MakeCloneFromMemory(Values.data(), Values.size()); + } +} // namespace blockstore::impl + +TEST_CASE("blockstore.chunks") +{ + using namespace blockstore::impl; + + ScopedTemporaryDirectory TempDir; + auto RootDirectory = TempDir.Path(); + + BlockStore Store; + Store.Initialize(RootDirectory, 128, 1024, {}); + IoBuffer BadChunk = Store.TryGetChunk({.BlockIndex = 0, .Offset = 0, .Size = 512}); + CHECK(!BadChunk); + + std::string FirstChunkData = "This is the data of the first chunk that we will write"; + BlockStoreLocation FirstChunkLocation = WriteStringAsChunk(Store, FirstChunkData, 4); + std::string SecondChunkData = "This is the data for the second chunk that we will write"; + BlockStoreLocation SecondChunkLocation = WriteStringAsChunk(Store, SecondChunkData, 4); + + CHECK(ReadChunkAsString(Store, FirstChunkLocation) == FirstChunkData); + CHECK(ReadChunkAsString(Store, SecondChunkLocation) == SecondChunkData); + + std::string ThirdChunkData = + "This is a much longer string that will not fit in the first block so it should be placed in the second block"; + BlockStoreLocation ThirdChunkLocation = WriteStringAsChunk(Store, ThirdChunkData, 4); + CHECK(ThirdChunkLocation.BlockIndex != FirstChunkLocation.BlockIndex); + + CHECK(ReadChunkAsString(Store, FirstChunkLocation) == FirstChunkData); + CHECK(ReadChunkAsString(Store, SecondChunkLocation) == SecondChunkData); + CHECK(ReadChunkAsString(Store, ThirdChunkLocation) == ThirdChunkData); +} + +TEST_CASE("blockstore.clean.stray.blocks") +{ + using namespace blockstore::impl; + + ScopedTemporaryDirectory TempDir; + auto RootDirectory = TempDir.Path(); + + BlockStore Store; + Store.Initialize(RootDirectory / "store", 128, 1024, {}); + + std::string FirstChunkData = "This is the data of the first chunk that we will write"; + BlockStoreLocation FirstChunkLocation = WriteStringAsChunk(Store, FirstChunkData, 4); + std::string SecondChunkData = "This is the data for the second chunk that we will write"; + BlockStoreLocation SecondChunkLocation = WriteStringAsChunk(Store, SecondChunkData, 4); + std::string ThirdChunkData = + "This is a much longer string that will not fit in the first block so it should be placed in the second block"; + WriteStringAsChunk(Store, ThirdChunkData, 4); + + Store.Close(); + + // Not referencing the second block means that we should be deleted + Store.Initialize(RootDirectory / "store", 128, 1024, {FirstChunkLocation, SecondChunkLocation}); + + CHECK(GetDirectoryContent(RootDirectory / "store", true, false).size() == 1); +} + +TEST_CASE("blockstore.flush.forces.new.block") +{ + using namespace blockstore::impl; + + ScopedTemporaryDirectory TempDir; + auto RootDirectory = TempDir.Path(); + + BlockStore Store; + Store.Initialize(RootDirectory / "store", 128, 1024, {}); + + std::string FirstChunkData = "This is the data of the first chunk that we will write"; + WriteStringAsChunk(Store, FirstChunkData, 4); + Store.Flush(); + std::string SecondChunkData = "This is the data for the second chunk that we will write"; + WriteStringAsChunk(Store, SecondChunkData, 4); + Store.Flush(); + std::string ThirdChunkData = + "This is a much longer string that will not fit in the first block so it should be placed in the second block"; + WriteStringAsChunk(Store, ThirdChunkData, 4); + + CHECK(GetDirectoryContent(RootDirectory / "store", true, false).size() == 3); +} + +TEST_CASE("blockstore.iterate.chunks") +{ + using namespace blockstore::impl; + + ScopedTemporaryDirectory TempDir; + auto RootDirectory = TempDir.Path(); + + BlockStore Store; + Store.Initialize(RootDirectory / "store", ScrubSmallChunkWindowSize * 2, 1024, {}); + IoBuffer BadChunk = Store.TryGetChunk({.BlockIndex = 0, .Offset = 0, .Size = 512}); + CHECK(!BadChunk); + + std::string FirstChunkData = "This is the data of the first chunk that we will write"; + BlockStoreLocation FirstChunkLocation = WriteStringAsChunk(Store, FirstChunkData, 4); + + std::string SecondChunkData = "This is the data for the second chunk that we will write"; + BlockStoreLocation SecondChunkLocation = WriteStringAsChunk(Store, SecondChunkData, 4); + Store.Flush(); + + std::string VeryLargeChunk(ScrubSmallChunkWindowSize * 2, 'L'); + BlockStoreLocation VeryLargeChunkLocation = WriteStringAsChunk(Store, VeryLargeChunk, 4); + + BlockStoreLocation BadLocationZeroSize = {.BlockIndex = 0, .Offset = 0, .Size = 0}; + BlockStoreLocation BadLocationOutOfRange = {.BlockIndex = 0, + .Offset = ScrubSmallChunkWindowSize, + .Size = ScrubSmallChunkWindowSize * 2}; + BlockStoreLocation BadBlockIndex = {.BlockIndex = 0xfffff, .Offset = 1024, .Size = 1024}; + + Store.IterateChunks( + {FirstChunkLocation, SecondChunkLocation, VeryLargeChunkLocation, BadLocationZeroSize, BadLocationOutOfRange, BadBlockIndex}, + [&](size_t ChunkIndex, const void* Data, uint64_t Size) { + switch (ChunkIndex) + { + case 0: + CHECK(Data); + CHECK(Size == FirstChunkData.size()); + CHECK(std::string((const char*)Data, Size) == FirstChunkData); + break; + case 1: + CHECK(Data); + CHECK(Size == SecondChunkData.size()); + CHECK(std::string((const char*)Data, Size) == SecondChunkData); + break; + case 2: + CHECK(false); + break; + case 3: + CHECK(!Data); + break; + case 4: + CHECK(!Data); + break; + case 5: + CHECK(!Data); + break; + default: + CHECK(false); + break; + } + }, + [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) { + switch (ChunkIndex) + { + case 0: + case 1: + CHECK(false); + break; + case 2: + { + CHECK(Size == VeryLargeChunk.size()); + char* Buffer = new char[Size]; + size_t HashOffset = 0; + File.StreamByteRange(Offset, Size, [&](const void* Data, uint64_t Size) { + memcpy(&Buffer[HashOffset], Data, Size); + HashOffset += Size; + }); + CHECK(memcmp(Buffer, VeryLargeChunk.data(), Size) == 0); + delete[] Buffer; + } + break; + case 3: + CHECK(false); + break; + case 4: + CHECK(false); + break; + case 5: + CHECK(false); + break; + default: + CHECK(false); + break; + } + }); +} + +TEST_CASE("blockstore.reclaim.space") +{ + using namespace blockstore::impl; + + ScopedTemporaryDirectory TempDir; + auto RootDirectory = TempDir.Path(); + + BlockStore Store; + Store.Initialize(RootDirectory / "store", 512, 1024, {}); + + constexpr size_t ChunkCount = 200; + constexpr size_t Alignment = 8; + std::vector<BlockStoreLocation> ChunkLocations; + std::vector<IoHash> ChunkHashes; + ChunkLocations.reserve(ChunkCount); + ChunkHashes.reserve(ChunkCount); + for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex) + { + IoBuffer Chunk = CreateChunk(57 + ChunkIndex); + + Store.WriteChunk(Chunk.Data(), Chunk.Size(), Alignment, [&](const BlockStoreLocation& L) { ChunkLocations.push_back(L); }); + ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size())); + } + + std::vector<size_t> ChunksToKeep; + ChunksToKeep.reserve(ChunkLocations.size()); + for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex) + { + ChunksToKeep.push_back(ChunkIndex); + } + + Store.Flush(); + BlockStore::ReclaimSnapshotState State1 = Store.GetReclaimSnapshotState(); + Store.ReclaimSpace(State1, ChunkLocations, ChunksToKeep, Alignment, true); + + // If we keep all the chunks we should not get any callbacks on moved/deleted stuff + Store.ReclaimSpace( + State1, + ChunkLocations, + ChunksToKeep, + Alignment, + false, + [](const BlockStore::MovedChunksArray&, const BlockStore::ChunkIndexArray&) { CHECK(false); }, + []() { + CHECK(false); + return 0; + }); + + size_t DeleteChunkCount = 38; + ChunksToKeep.clear(); + for (size_t ChunkIndex = DeleteChunkCount; ChunkIndex < ChunkCount; ++ChunkIndex) + { + ChunksToKeep.push_back(ChunkIndex); + } + + std::vector<BlockStoreLocation> NewChunkLocations = ChunkLocations; + size_t MovedChunkCount = 0; + size_t DeletedChunkCount = 0; + Store.ReclaimSpace( + State1, + ChunkLocations, + ChunksToKeep, + Alignment, + false, + [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& DeletedChunks) { + for (const auto& MovedChunk : MovedChunks) + { + CHECK(MovedChunk.first >= DeleteChunkCount); + NewChunkLocations[MovedChunk.first] = MovedChunk.second; + } + MovedChunkCount += MovedChunks.size(); + for (size_t DeletedIndex : DeletedChunks) + { + CHECK(DeletedIndex < DeleteChunkCount); + } + DeletedChunkCount += DeletedChunks.size(); + }, + []() { + CHECK(false); + return 0; + }); + CHECK(MovedChunkCount <= DeleteChunkCount); + CHECK(DeletedChunkCount == DeleteChunkCount); + ChunkLocations = std::vector<BlockStoreLocation>(NewChunkLocations.begin() + DeleteChunkCount, NewChunkLocations.end()); + + for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex) + { + IoBuffer ChunkBlock = Store.TryGetChunk(NewChunkLocations[ChunkIndex]); + if (ChunkIndex >= DeleteChunkCount) + { + IoBuffer VerifyChunk = Store.TryGetChunk(NewChunkLocations[ChunkIndex]); + CHECK(VerifyChunk); + IoHash VerifyHash = IoHash::HashBuffer(VerifyChunk.Data(), VerifyChunk.Size()); + CHECK(VerifyHash == ChunkHashes[ChunkIndex]); + } + } + + NewChunkLocations = ChunkLocations; + MovedChunkCount = 0; + DeletedChunkCount = 0; + Store.ReclaimSpace( + State1, + ChunkLocations, + {}, + Alignment, + false, + [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& DeletedChunks) { + CHECK(MovedChunks.empty()); + DeletedChunkCount += DeletedChunks.size(); + }, + []() { + CHECK(false); + return 0; + }); + CHECK(DeletedChunkCount == ChunkCount - DeleteChunkCount); +} + +TEST_CASE("blockstore.thread.read.write") +{ + using namespace blockstore::impl; + + ScopedTemporaryDirectory TempDir; + auto RootDirectory = TempDir.Path(); + + BlockStore Store; + Store.Initialize(RootDirectory / "store", 1088, 1024, {}); + + constexpr size_t ChunkCount = 1000; + constexpr size_t Alignment = 8; + std::vector<IoBuffer> Chunks; + std::vector<IoHash> ChunkHashes; + Chunks.reserve(ChunkCount); + ChunkHashes.reserve(ChunkCount); + for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex) + { + IoBuffer Chunk = CreateChunk(57 + ChunkIndex / 2); + Chunks.push_back(Chunk); + ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size())); + } + + std::vector<BlockStoreLocation> ChunkLocations; + ChunkLocations.resize(ChunkCount); + + WorkerThreadPool WorkerPool(8); + std::atomic<size_t> WorkCompleted = 0; + for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex) + { + WorkerPool.ScheduleWork([&Store, ChunkIndex, &Chunks, &ChunkLocations, &WorkCompleted]() { + IoBuffer& Chunk = Chunks[ChunkIndex]; + Store.WriteChunk(Chunk.Data(), Chunk.Size(), Alignment, [&](const BlockStoreLocation& L) { ChunkLocations[ChunkIndex] = L; }); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + + WorkCompleted = 0; + for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex) + { + WorkerPool.ScheduleWork([&Store, ChunkIndex, &ChunkLocations, &ChunkHashes, &WorkCompleted]() { + IoBuffer VerifyChunk = Store.TryGetChunk(ChunkLocations[ChunkIndex]); + CHECK(VerifyChunk); + IoHash VerifyHash = IoHash::HashBuffer(VerifyChunk.Data(), VerifyChunk.Size()); + CHECK(VerifyHash == ChunkHashes[ChunkIndex]); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + + std::vector<BlockStoreLocation> SecondChunkLocations; + SecondChunkLocations.resize(ChunkCount); + WorkCompleted = 0; + for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex) + { + WorkerPool.ScheduleWork([&Store, ChunkIndex, &Chunks, &SecondChunkLocations, &WorkCompleted]() { + IoBuffer& Chunk = Chunks[ChunkIndex]; + Store.WriteChunk(Chunk.Data(), Chunk.Size(), Alignment, [&](const BlockStoreLocation& L) { + SecondChunkLocations[ChunkIndex] = L; + }); + WorkCompleted.fetch_add(1); + }); + WorkerPool.ScheduleWork([&Store, ChunkIndex, &ChunkLocations, &ChunkHashes, &WorkCompleted]() { + IoBuffer VerifyChunk = Store.TryGetChunk(ChunkLocations[ChunkIndex]); + CHECK(VerifyChunk); + IoHash VerifyHash = IoHash::HashBuffer(VerifyChunk.Data(), VerifyChunk.Size()); + CHECK(VerifyHash == ChunkHashes[ChunkIndex]); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size() * 2) + { + Sleep(1); + } +} + +#endif + +void +blockstore_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp new file mode 100644 index 000000000..fdec78c60 --- /dev/null +++ b/src/zenstore/cas.cpp @@ -0,0 +1,355 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "cas.h" + +#include "compactcas.h" +#include "filecas.h" + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/except.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> +#include <zencore/thread.h> +#include <zencore/trace.h> +#include <zencore/uid.h> +#include <zenstore/cidstore.h> +#include <zenstore/gc.h> +#include <zenstore/scrubcontext.h> + +#include <gsl/gsl-lite.hpp> + +#include <filesystem> +#include <functional> +#include <unordered_map> + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +/** + * CAS store implementation + * + * Uses a basic strategy of splitting payloads by size, to improve ability to reclaim space + * quickly for unused large chunks and to maintain locality for small chunks which are + * frequently accessed together. + * + */ +class CasImpl : public CasStore +{ +public: + CasImpl(GcManager& Gc); + virtual ~CasImpl(); + + virtual void Initialize(const CidStoreConfiguration& InConfig) override; + virtual CasStore::InsertResult InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, InsertMode Mode) override; + virtual IoBuffer FindChunk(const IoHash& ChunkHash) override; + virtual bool ContainsChunk(const IoHash& ChunkHash) override; + virtual void FilterChunks(HashKeySet& InOutChunks) override; + virtual void Flush() override; + virtual void Scrub(ScrubContext& Ctx) override; + virtual void GarbageCollect(GcContext& GcCtx) override; + virtual CidStoreSize TotalSize() const override; + +private: + CasContainerStrategy m_TinyStrategy; + CasContainerStrategy m_SmallStrategy; + FileCasStrategy m_LargeStrategy; + CbObject m_ManifestObject; + + enum class StorageScheme + { + Legacy = 0, + WithCbManifest = 1 + }; + + StorageScheme m_StorageScheme = StorageScheme::Legacy; + + bool OpenOrCreateManifest(); + void UpdateManifest(); +}; + +CasImpl::CasImpl(GcManager& Gc) : m_TinyStrategy(Gc), m_SmallStrategy(Gc), m_LargeStrategy(Gc) +{ +} + +CasImpl::~CasImpl() +{ +} + +void +CasImpl::Initialize(const CidStoreConfiguration& InConfig) +{ + m_Config = InConfig; + + ZEN_INFO("initializing CAS pool at '{}'", m_Config.RootDirectory); + + // Ensure root directory exists - create if it doesn't exist already + + std::filesystem::create_directories(m_Config.RootDirectory); + + // Open or create manifest + + const bool IsNewStore = OpenOrCreateManifest(); + + // Initialize payload storage + + m_LargeStrategy.Initialize(m_Config.RootDirectory, IsNewStore); + m_TinyStrategy.Initialize(m_Config.RootDirectory, "tobs", 1u << 28, 16, IsNewStore); // 256 Mb per block + m_SmallStrategy.Initialize(m_Config.RootDirectory, "sobs", 1u << 30, 4096, IsNewStore); // 1 Gb per block +} + +bool +CasImpl::OpenOrCreateManifest() +{ + bool IsNewStore = false; + + std::filesystem::path ManifestPath = m_Config.RootDirectory; + ManifestPath /= ".ucas_root"; + + std::error_code Ec; + BasicFile ManifestFile; + ManifestFile.Open(ManifestPath.c_str(), BasicFile::Mode::kRead, Ec); + + bool ManifestIsOk = false; + + if (Ec) + { + if (Ec == std::errc::no_such_file_or_directory) + { + IsNewStore = true; + } + } + else + { + IoBuffer ManifestBuffer = ManifestFile.ReadAll(); + ManifestFile.Close(); + + if (ManifestBuffer.Size() > 0 && ManifestBuffer.Data<uint8_t>()[0] == '#') + { + // Old-style manifest, does not contain any useful information, so we may as well update it + } + else + { + CbObject Manifest{SharedBuffer(ManifestBuffer)}; + CbValidateError ValidationResult = ValidateCompactBinary(ManifestBuffer, CbValidateMode::All); + + if (ValidationResult == CbValidateError::None) + { + if (Manifest["id"]) + { + ManifestIsOk = true; + } + } + else + { + ZEN_WARN("Store manifest validation failed: {:#x}, will generate new manifest to recover", uint32_t(ValidationResult)); + } + + if (ManifestIsOk) + { + m_ManifestObject = std::move(Manifest); + } + } + } + + if (!ManifestIsOk) + { + UpdateManifest(); + } + + return IsNewStore; +} + +void +CasImpl::UpdateManifest() +{ + if (!m_ManifestObject) + { + CbObjectWriter Cbo; + Cbo << "id" << zen::Oid::NewOid() << "created" << DateTime::Now(); + m_ManifestObject = Cbo.Save(); + } + + // Write manifest to file + + std::filesystem::path ManifestPath = m_Config.RootDirectory; + ManifestPath /= ".ucas_root"; + + // This will throw on failure + + ZEN_TRACE("Writing new manifest to '{}'", ManifestPath); + + BasicFile Marker; + Marker.Open(ManifestPath.c_str(), BasicFile::Mode::kTruncate); + Marker.Write(m_ManifestObject.GetBuffer(), 0); +} + +CasStore::InsertResult +CasImpl::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, InsertMode Mode) +{ + ZEN_TRACE_CPU("CAS::InsertChunk"); + + const uint64_t ChunkSize = Chunk.Size(); + + if (ChunkSize < m_Config.TinyValueThreshold) + { + ZEN_ASSERT(ChunkSize); + + return m_TinyStrategy.InsertChunk(Chunk, ChunkHash); + } + else if (ChunkSize < m_Config.HugeValueThreshold) + { + return m_SmallStrategy.InsertChunk(Chunk, ChunkHash); + } + + return m_LargeStrategy.InsertChunk(Chunk, ChunkHash, Mode); +} + +IoBuffer +CasImpl::FindChunk(const IoHash& ChunkHash) +{ + ZEN_TRACE_CPU("CAS::FindChunk"); + + if (IoBuffer Found = m_SmallStrategy.FindChunk(ChunkHash)) + { + return Found; + } + + if (IoBuffer Found = m_TinyStrategy.FindChunk(ChunkHash)) + { + return Found; + } + + if (IoBuffer Found = m_LargeStrategy.FindChunk(ChunkHash)) + { + return Found; + } + + // Not found + return IoBuffer{}; +} + +bool +CasImpl::ContainsChunk(const IoHash& ChunkHash) +{ + return m_SmallStrategy.HaveChunk(ChunkHash) || m_TinyStrategy.HaveChunk(ChunkHash) || m_LargeStrategy.HaveChunk(ChunkHash); +} + +void +CasImpl::FilterChunks(HashKeySet& InOutChunks) +{ + m_SmallStrategy.FilterChunks(InOutChunks); + m_TinyStrategy.FilterChunks(InOutChunks); + m_LargeStrategy.FilterChunks(InOutChunks); +} + +void +CasImpl::Flush() +{ + m_SmallStrategy.Flush(); + m_TinyStrategy.Flush(); + m_LargeStrategy.Flush(); +} + +void +CasImpl::Scrub(ScrubContext& Ctx) +{ + if (m_LastScrubTime == Ctx.ScrubTimestamp()) + { + return; + } + + m_LastScrubTime = Ctx.ScrubTimestamp(); + + m_SmallStrategy.Scrub(Ctx); + m_TinyStrategy.Scrub(Ctx); + m_LargeStrategy.Scrub(Ctx); +} + +void +CasImpl::GarbageCollect(GcContext& GcCtx) +{ + m_SmallStrategy.CollectGarbage(GcCtx); + m_TinyStrategy.CollectGarbage(GcCtx); + m_LargeStrategy.CollectGarbage(GcCtx); +} + +CidStoreSize +CasImpl::TotalSize() const +{ + const uint64_t Tiny = m_TinyStrategy.StorageSize().DiskSize; + const uint64_t Small = m_SmallStrategy.StorageSize().DiskSize; + const uint64_t Large = m_LargeStrategy.StorageSize().DiskSize; + + return {.TinySize = Tiny, .SmallSize = Small, .LargeSize = Large, .TotalSize = Tiny + Small + Large}; +} + +////////////////////////////////////////////////////////////////////////// + +std::unique_ptr<CasStore> +CreateCasStore(GcManager& Gc) +{ + return std::make_unique<CasImpl>(Gc); +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +TEST_CASE("CasStore") +{ + ScopedTemporaryDirectory TempDir; + + CidStoreConfiguration config; + config.RootDirectory = TempDir.Path(); + + GcManager Gc; + + std::unique_ptr<CasStore> Store = CreateCasStore(Gc); + Store->Initialize(config); + + ScrubContext Ctx; + Store->Scrub(Ctx); + + IoBuffer Value1{16}; + memcpy(Value1.MutableData(), "1234567890123456", 16); + IoHash Hash1 = IoHash::HashBuffer(Value1.Data(), Value1.Size()); + CasStore::InsertResult Result1 = Store->InsertChunk(Value1, Hash1); + CHECK(Result1.New); + + IoBuffer Value2{16}; + memcpy(Value2.MutableData(), "ABCDEFGHIJKLMNOP", 16); + IoHash Hash2 = IoHash::HashBuffer(Value2.Data(), Value2.Size()); + CasStore::InsertResult Result2 = Store->InsertChunk(Value2, Hash2); + CHECK(Result2.New); + + HashKeySet ChunkSet; + ChunkSet.AddHashToSet(Hash1); + ChunkSet.AddHashToSet(Hash2); + + Store->FilterChunks(ChunkSet); + CHECK(ChunkSet.IsEmpty()); + + IoBuffer Lookup1 = Store->FindChunk(Hash1); + CHECK(Lookup1); + IoBuffer Lookup2 = Store->FindChunk(Hash2); + CHECK(Lookup2); +} + +void +CAS_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenstore/cas.h b/src/zenstore/cas.h new file mode 100644 index 000000000..9c48d4707 --- /dev/null +++ b/src/zenstore/cas.h @@ -0,0 +1,67 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/blake3.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/refcount.h> +#include <zencore/timer.h> +#include <zenstore/cidstore.h> +#include <zenstore/hashkeyset.h> + +#include <atomic> +#include <filesystem> +#include <functional> +#include <memory> +#include <string> +#include <unordered_set> + +namespace zen { + +class GcContext; +class GcManager; +class ScrubContext; + +/** Content Addressable Storage interface + + */ + +class CasStore +{ +public: + virtual ~CasStore() = default; + + const CidStoreConfiguration& Config() { return m_Config; } + + struct InsertResult + { + bool New = false; + }; + + enum class InsertMode + { + kCopyOnly, + kMayBeMovedInPlace + }; + + virtual void Initialize(const CidStoreConfiguration& Config) = 0; + virtual InsertResult InsertChunk(IoBuffer Data, const IoHash& ChunkHash, InsertMode Mode = InsertMode::kMayBeMovedInPlace) = 0; + virtual IoBuffer FindChunk(const IoHash& ChunkHash) = 0; + virtual bool ContainsChunk(const IoHash& ChunkHash) = 0; + virtual void FilterChunks(HashKeySet& InOutChunks) = 0; + virtual void Flush() = 0; + virtual void Scrub(ScrubContext& Ctx) = 0; + virtual void GarbageCollect(GcContext& GcCtx) = 0; + virtual CidStoreSize TotalSize() const = 0; + +protected: + CidStoreConfiguration m_Config; + uint64_t m_LastScrubTime = 0; +}; + +ZENCORE_API std::unique_ptr<CasStore> CreateCasStore(GcManager& Gc); + +void CAS_forcelink(); + +} // namespace zen diff --git a/src/zenstore/caslog.cpp b/src/zenstore/caslog.cpp new file mode 100644 index 000000000..2a978ae12 --- /dev/null +++ b/src/zenstore/caslog.cpp @@ -0,0 +1,236 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenstore/caslog.h> + +#include "compactcas.h" + +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/string.h> +#include <zencore/thread.h> +#include <zencore/uid.h> + +#include <xxhash.h> + +#include <gsl/gsl-lite.hpp> + +#include <filesystem> +#include <functional> + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +uint32_t +CasLogFile::FileHeader::ComputeChecksum() +{ + return XXH32(&this->Magic, sizeof(FileHeader) - 4, 0xC0C0'BABA); +} + +CasLogFile::CasLogFile() +{ +} + +CasLogFile::~CasLogFile() +{ +} + +bool +CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize) +{ + if (!std::filesystem::is_regular_file(FileName)) + { + return false; + } + BasicFile File; + + std::error_code Ec; + File.Open(FileName, BasicFile::Mode::kRead, Ec); + if (Ec) + { + return false; + } + + FileHeader Header; + if (File.FileSize() < sizeof(Header)) + { + return false; + } + + // Validate header and log contents and prepare for appending/replay + File.Read(&Header, sizeof Header, 0); + + if ((0 != memcmp(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic)) || (Header.Checksum != Header.ComputeChecksum())) + { + return false; + } + if (Header.RecordSize != RecordSize) + { + return false; + } + return true; +} + +void +CasLogFile::Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode) +{ + m_RecordSize = RecordSize; + + std::error_code Ec; + BasicFile::Mode FileMode = BasicFile::Mode::kRead; + switch (Mode) + { + case Mode::kWrite: + FileMode = BasicFile::Mode::kWrite; + break; + case Mode::kTruncate: + FileMode = BasicFile::Mode::kTruncate; + break; + } + + m_File.Open(FileName, FileMode, Ec); + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", FileName)); + } + + uint64_t AppendOffset = 0; + + if ((Mode == Mode::kTruncate) || (m_File.FileSize() < sizeof(FileHeader))) + { + if (Mode == Mode::kRead) + { + throw std::runtime_error(fmt::format("Mangled log header (file to small) in '{}'", FileName)); + } + // Initialize log by writing header + FileHeader Header = {.RecordSize = gsl::narrow<uint32_t>(RecordSize), .LogId = Oid::NewOid(), .ValidatedTail = 0}; + memcpy(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic); + Header.Finalize(); + + m_File.Write(&Header, sizeof Header, 0); + + AppendOffset = sizeof(FileHeader); + + m_Header = Header; + } + else + { + FileHeader Header; + m_File.Read(&Header, sizeof Header, 0); + + if ((0 != memcmp(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic)) || (Header.Checksum != Header.ComputeChecksum())) + { + throw std::runtime_error(fmt::format("Mangled log header (invalid header magic) in '{}'", FileName)); + } + if (Header.RecordSize != RecordSize) + { + throw std::runtime_error(fmt::format("Mangled log header (mismatch in record size, expected {}, found {}) in '{}'", + RecordSize, + Header.RecordSize, + FileName)); + } + + AppendOffset = m_File.FileSize(); + + // Adjust the offset to ensure we end up on a good boundary, in case there is some garbage appended + + AppendOffset -= sizeof Header; + AppendOffset -= AppendOffset % RecordSize; + AppendOffset += sizeof Header; + + m_Header = Header; + } + + m_AppendOffset = AppendOffset; +} + +void +CasLogFile::Close() +{ + // TODO: update header and maybe add trailer + Flush(); + + m_File.Close(); +} + +uint64_t +CasLogFile::GetLogSize() +{ + return m_File.FileSize(); +} + +uint64_t +CasLogFile::GetLogCount() +{ + uint64_t LogFileSize = m_AppendOffset.load(std::memory_order_acquire); + if (LogFileSize < sizeof(FileHeader)) + { + return 0; + } + const uint64_t LogBaseOffset = sizeof(FileHeader); + const size_t LogEntryCount = (LogFileSize - LogBaseOffset) / m_RecordSize; + return LogEntryCount; +} + +void +CasLogFile::Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntryCount) +{ + uint64_t LogFileSize = m_File.FileSize(); + + // Ensure we end up on a clean boundary + uint64_t LogBaseOffset = sizeof(FileHeader); + size_t LogEntryCount = (LogFileSize - LogBaseOffset) / m_RecordSize; + + if (LogEntryCount <= SkipEntryCount) + { + return; + } + + LogBaseOffset += SkipEntryCount * m_RecordSize; + LogEntryCount -= SkipEntryCount; + + // This should really be streaming the data rather than just + // reading it into memory, though we don't tend to get very + // large logs so it may not matter + + const uint64_t LogDataSize = LogEntryCount * m_RecordSize; + + std::vector<uint8_t> ReadBuffer; + ReadBuffer.resize(LogDataSize); + + m_File.Read(ReadBuffer.data(), LogDataSize, LogBaseOffset); + + for (int i = 0; i < int(LogEntryCount); ++i) + { + Handler(ReadBuffer.data() + (i * m_RecordSize)); + } + + m_AppendOffset = LogBaseOffset + (m_RecordSize * LogEntryCount); +} + +void +CasLogFile::Append(const void* DataPointer, uint64_t DataSize) +{ + ZEN_ASSERT((DataSize % m_RecordSize) == 0); + + uint64_t AppendOffset = m_AppendOffset.fetch_add(DataSize); + + std::error_code Ec; + m_File.Write(DataPointer, gsl::narrow<uint32_t>(DataSize), AppendOffset, Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to write to log file '{}'", PathFromHandle(m_File.Handle()))); + } +} + +void +CasLogFile::Flush() +{ + m_File.Flush(); +} + +} // namespace zen diff --git a/src/zenstore/cidstore.cpp b/src/zenstore/cidstore.cpp new file mode 100644 index 000000000..5a5116faf --- /dev/null +++ b/src/zenstore/cidstore.cpp @@ -0,0 +1,125 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenstore/cidstore.h" + +#include <zencore/compress.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/string.h> +#include <zenstore/scrubcontext.h> + +#include "cas.h" + +#include <filesystem> + +namespace zen { + +struct CidStore::Impl +{ + Impl(CasStore& InCasStore) : m_CasStore(InCasStore) {} + + CasStore& m_CasStore; + + void Initialize(const CidStoreConfiguration& Config) { m_CasStore.Initialize(Config); } + + CidStore::InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, CidStore::InsertMode Mode) + { +#ifndef NDEBUG + IoHash VerifyRawHash; + uint64_t _; + ZEN_ASSERT(CompressedBuffer::ValidateCompressedHeader(ChunkData, VerifyRawHash, _) && RawHash == VerifyRawHash); +#endif // NDEBUG + IoBuffer Payload(ChunkData); + Payload.SetContentType(ZenContentType::kCompressedBinary); + + CasStore::InsertResult Result = m_CasStore.InsertChunk(Payload, RawHash, static_cast<CasStore::InsertMode>(Mode)); + + return {.New = Result.New}; + } + + IoBuffer FindChunkByCid(const IoHash& DecompressedId) { return m_CasStore.FindChunk(DecompressedId); } + + bool ContainsChunk(const IoHash& DecompressedId) { return m_CasStore.ContainsChunk(DecompressedId); } + + void FilterChunks(HashKeySet& InOutChunks) + { + InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return ContainsChunk(Hash); }); + } + + void Flush() { m_CasStore.Flush(); } + + void Scrub(ScrubContext& Ctx) + { + if (Ctx.ScrubTimestamp() == m_LastScrubTime) + { + return; + } + + m_LastScrubTime = Ctx.ScrubTimestamp(); + + m_CasStore.Scrub(Ctx); + } + + uint64_t m_LastScrubTime = 0; +}; + +////////////////////////////////////////////////////////////////////////// + +CidStore::CidStore(GcManager& Gc) : m_CasStore(CreateCasStore(Gc)), m_Impl(std::make_unique<Impl>(*m_CasStore)) +{ +} + +CidStore::~CidStore() +{ +} + +void +CidStore::Initialize(const CidStoreConfiguration& Config) +{ + m_Impl->Initialize(Config); +} + +CidStore::InsertResult +CidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode) +{ + return m_Impl->AddChunk(ChunkData, RawHash, Mode); +} + +IoBuffer +CidStore::FindChunkByCid(const IoHash& DecompressedId) +{ + return m_Impl->FindChunkByCid(DecompressedId); +} + +bool +CidStore::ContainsChunk(const IoHash& DecompressedId) +{ + return m_Impl->ContainsChunk(DecompressedId); +} + +void +CidStore::FilterChunks(HashKeySet& InOutChunks) +{ + return m_Impl->FilterChunks(InOutChunks); +} + +void +CidStore::Flush() +{ + m_Impl->Flush(); +} + +void +CidStore::Scrub(ScrubContext& Ctx) +{ + m_Impl->Scrub(Ctx); +} + +CidStoreSize +CidStore::TotalSize() const +{ + return m_Impl->m_CasStore.TotalSize(); +} + +} // namespace zen diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp new file mode 100644 index 000000000..7b2c21b0f --- /dev/null +++ b/src/zenstore/compactcas.cpp @@ -0,0 +1,1511 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "compactcas.h" + +#include "cas.h" + +#include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zenstore/scrubcontext.h> + +#include <gsl/gsl-lite.hpp> + +#include <xxhash.h> + +#if ZEN_WITH_TESTS +# include <zencore/compactbinarybuilder.h> +# include <zencore/testing.h> +# include <zencore/testutils.h> +# include <zencore/workthreadpool.h> +# include <zenstore/cidstore.h> +# include <algorithm> +# include <random> +#endif + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +struct CasDiskIndexHeader +{ + static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx'; + static constexpr uint32_t CurrentVersion = 1; + + uint32_t Magic = ExpectedMagic; + uint32_t Version = CurrentVersion; + uint64_t EntryCount = 0; + uint64_t LogPosition = 0; + uint32_t PayloadAlignment = 0; + uint32_t Checksum = 0; + + static uint32_t ComputeChecksum(const CasDiskIndexHeader& Header) + { + return XXH32(&Header.Magic, sizeof(CasDiskIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA); + } +}; + +static_assert(sizeof(CasDiskIndexHeader) == 32); + +namespace { + const char* IndexExtension = ".uidx"; + const char* LogExtension = ".ulog"; + + std::filesystem::path GetBasePath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName) + { + return RootPath / ContainerBaseName; + } + + std::filesystem::path GetIndexPath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName) + { + return GetBasePath(RootPath, ContainerBaseName) / (ContainerBaseName + IndexExtension); + } + + std::filesystem::path GetTempIndexPath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName) + { + return GetBasePath(RootPath, ContainerBaseName) / (ContainerBaseName + ".tmp" + LogExtension); + } + + std::filesystem::path GetLogPath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName) + { + return GetBasePath(RootPath, ContainerBaseName) / (ContainerBaseName + LogExtension); + } + + std::filesystem::path GetBlocksBasePath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName) + { + return GetBasePath(RootPath, ContainerBaseName) / "blocks"; + } + + bool ValidateEntry(const CasDiskIndexEntry& Entry, std::string& OutReason) + { + if (Entry.Key == IoHash::Zero) + { + OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString()); + return false; + } + if ((Entry.Flags & ~CasDiskIndexEntry::kTombstone) != 0) + { + OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Flags, Entry.Key.ToHexString()); + return false; + } + if (Entry.Flags & CasDiskIndexEntry::kTombstone) + { + return true; + } + if (Entry.ContentType != ZenContentType::kUnknownContentType) + { + OutReason = + fmt::format("Invalid content type {} for entry {}", static_cast<uint8_t>(Entry.ContentType), Entry.Key.ToHexString()); + return false; + } + uint64_t Size = Entry.Location.GetSize(); + if (Size == 0) + { + OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString()); + return false; + } + return true; + } + +} // namespace + +////////////////////////////////////////////////////////////////////////// + +CasContainerStrategy::CasContainerStrategy(GcManager& Gc) : GcStorage(Gc), m_Log(logging::Get("containercas")) +{ +} + +CasContainerStrategy::~CasContainerStrategy() +{ +} + +void +CasContainerStrategy::Initialize(const std::filesystem::path& RootDirectory, + const std::string_view ContainerBaseName, + uint32_t MaxBlockSize, + uint64_t Alignment, + bool IsNewStore) +{ + ZEN_ASSERT(IsPow2(Alignment)); + ZEN_ASSERT(!m_IsInitialized); + ZEN_ASSERT(MaxBlockSize > 0); + + m_RootDirectory = RootDirectory; + m_ContainerBaseName = ContainerBaseName; + m_PayloadAlignment = Alignment; + m_MaxBlockSize = MaxBlockSize; + m_BlocksBasePath = GetBlocksBasePath(m_RootDirectory, m_ContainerBaseName); + + OpenContainer(IsNewStore); + + m_IsInitialized = true; +} + +CasStore::InsertResult +CasContainerStrategy::InsertChunk(const void* ChunkData, size_t ChunkSize, const IoHash& ChunkHash) +{ + { + RwLock::SharedLockScope _(m_LocationMapLock); + if (m_LocationMap.contains(ChunkHash)) + { + return CasStore::InsertResult{.New = false}; + } + } + + // We can end up in a situation that InsertChunk writes the same chunk data in + // different locations. + // We release the insert lock once we have the correct WriteBlock ready and we know + // where to write the data. If a new InsertChunk request for the same chunk hash/data + // comes in before we update m_LocationMap below we will have a race. + // The outcome of that is that we will write the chunk data in more than one location + // but the chunk hash will only point to one of the chunks. + // We will in that case waste space until the next GC operation. + // + // This should be a rare occasion and the current flow reduces the time we block for + // reads, insert and GC. + + m_BlockStore.WriteChunk(ChunkData, ChunkSize, m_PayloadAlignment, [&](const BlockStoreLocation& Location) { + BlockStoreDiskLocation DiskLocation(Location, m_PayloadAlignment); + const CasDiskIndexEntry IndexEntry{.Key = ChunkHash, .Location = DiskLocation}; + m_CasLog.Append(IndexEntry); + { + RwLock::ExclusiveLockScope _(m_LocationMapLock); + m_LocationMap.emplace(ChunkHash, DiskLocation); + } + }); + + return CasStore::InsertResult{.New = true}; +} + +CasStore::InsertResult +CasContainerStrategy::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash) +{ +#if !ZEN_WITH_TESTS + ZEN_ASSERT(Chunk.GetContentType() == ZenContentType::kCompressedBinary); +#endif + return InsertChunk(Chunk.Data(), Chunk.Size(), ChunkHash); +} + +IoBuffer +CasContainerStrategy::FindChunk(const IoHash& ChunkHash) +{ + RwLock::SharedLockScope _(m_LocationMapLock); + auto KeyIt = m_LocationMap.find(ChunkHash); + if (KeyIt == m_LocationMap.end()) + { + return IoBuffer(); + } + const BlockStoreLocation& Location = KeyIt->second.Get(m_PayloadAlignment); + + IoBuffer Chunk = m_BlockStore.TryGetChunk(Location); + return Chunk; +} + +bool +CasContainerStrategy::HaveChunk(const IoHash& ChunkHash) +{ + RwLock::SharedLockScope _(m_LocationMapLock); + return m_LocationMap.contains(ChunkHash); +} + +void +CasContainerStrategy::FilterChunks(HashKeySet& InOutChunks) +{ + // This implementation is good enough for relatively small + // chunk sets (in terms of chunk identifiers), but would + // benefit from a better implementation which removes + // items incrementally for large sets, especially when + // we're likely to already have a large proportion of the + // chunks in the set + + InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return HaveChunk(Hash); }); +} + +void +CasContainerStrategy::Flush() +{ + m_BlockStore.Flush(); + m_CasLog.Flush(); + MakeIndexSnapshot(); +} + +void +CasContainerStrategy::Scrub(ScrubContext& Ctx) +{ + std::vector<IoHash> BadKeys; + uint64_t ChunkCount{0}, ChunkBytes{0}; + std::vector<BlockStoreLocation> ChunkLocations; + std::vector<IoHash> ChunkIndexToChunkHash; + + RwLock::SharedLockScope _(m_LocationMapLock); + + uint64_t TotalChunkCount = m_LocationMap.size(); + ChunkLocations.reserve(TotalChunkCount); + ChunkIndexToChunkHash.reserve(TotalChunkCount); + { + for (const auto& Entry : m_LocationMap) + { + const IoHash& ChunkHash = Entry.first; + const BlockStoreDiskLocation& DiskLocation = Entry.second; + BlockStoreLocation Location = DiskLocation.Get(m_PayloadAlignment); + + ChunkLocations.push_back(Location); + ChunkIndexToChunkHash.push_back(ChunkHash); + } + } + + const auto ValidateSmallChunk = [&](size_t ChunkIndex, const void* Data, uint64_t Size) { + ++ChunkCount; + ChunkBytes += Size; + + const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex]; + if (!Data) + { + // ChunkLocation out of range of stored blocks + BadKeys.push_back(Hash); + return; + } + + IoBuffer Buffer(IoBuffer::Wrap, Data, Size); + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize)) + { + if (RawHash != Hash) + { + // Hash mismatch + BadKeys.push_back(Hash); + return; + } + return; + } +#if ZEN_WITH_TESTS + IoHash ComputedHash = IoHash::HashBuffer(Data, Size); + if (ComputedHash == Hash) + { + return; + } +#endif + BadKeys.push_back(Hash); + }; + + const auto ValidateLargeChunk = [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) { + ++ChunkCount; + ChunkBytes += Size; + + const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex]; + IoBuffer Buffer(IoBuffer::BorrowedFile, File.GetBasicFile().Handle(), Offset, Size); + + IoHash RawHash; + uint64_t RawSize; + // TODO: Add API to verify compressed buffer without having to memorymap the whole file + if (CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize)) + { + if (RawHash != Hash) + { + // Hash mismatch + BadKeys.push_back(Hash); + return; + } + return; + } +#if ZEN_WITH_TESTS + IoHashStream Hasher; + File.StreamByteRange(Offset, Size, [&](const void* Data, size_t Size) { Hasher.Append(Data, Size); }); + IoHash ComputedHash = Hasher.GetHash(); + if (ComputedHash == Hash) + { + return; + } +#endif + BadKeys.push_back(Hash); + }; + + m_BlockStore.IterateChunks(ChunkLocations, ValidateSmallChunk, ValidateLargeChunk); + + _.ReleaseNow(); + + Ctx.ReportScrubbed(ChunkCount, ChunkBytes); + + if (!BadKeys.empty()) + { + ZEN_WARN("Scrubbing found {} bad chunks in '{}'", BadKeys.size(), m_RootDirectory / m_ContainerBaseName); + + if (Ctx.RunRecovery()) + { + // Deal with bad chunks by removing them from our lookup map + + std::vector<CasDiskIndexEntry> LogEntries; + LogEntries.reserve(BadKeys.size()); + { + RwLock::ExclusiveLockScope __(m_LocationMapLock); + for (const IoHash& ChunkHash : BadKeys) + { + const auto KeyIt = m_LocationMap.find(ChunkHash); + if (KeyIt == m_LocationMap.end()) + { + // Might have been GC'd + continue; + } + LogEntries.push_back({.Key = KeyIt->first, .Location = KeyIt->second, .Flags = CasDiskIndexEntry::kTombstone}); + m_LocationMap.erase(KeyIt); + } + } + m_CasLog.Append(LogEntries); + } + } + + // Let whomever it concerns know about the bad chunks. This could + // be used to invalidate higher level data structures more efficiently + // than a full validation pass might be able to do + Ctx.ReportBadCidChunks(BadKeys); + + ZEN_INFO("compact cas scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes)); +} + +void +CasContainerStrategy::CollectGarbage(GcContext& GcCtx) +{ + // It collects all the blocks that we want to delete chunks from. For each such + // block we keep a list of chunks to retain and a list of chunks to delete. + // + // If there is a block that we are currently writing to, that block is omitted + // from the garbage collection. + // + // Next it will iterate over all blocks that we want to remove chunks from. + // If the block is empty after removal of chunks we mark the block as pending + // delete - we want to delete it as soon as there are no IoBuffers using the + // block file. + // Once complete we update the m_LocationMap by removing the chunks. + // + // If the block is non-empty we write out the chunks we want to keep to a new + // block file (creating new block files as needed). + // + // We update the index as we complete each new block file. This makes it possible + // to break the GC if we want to limit time for execution. + // + // GC can very parallell to regular operation - it will block while taking + // a snapshot of the current m_LocationMap state and while moving blocks it will + // do a blocking operation and update the m_LocationMap after each new block is + // written and figuring out the path to the next new block. + + ZEN_DEBUG("collecting garbage from '{}'", m_RootDirectory / m_ContainerBaseName); + + uint64_t WriteBlockTimeUs = 0; + uint64_t WriteBlockLongestTimeUs = 0; + uint64_t ReadBlockTimeUs = 0; + uint64_t ReadBlockLongestTimeUs = 0; + + LocationMap_t LocationMap; + BlockStore::ReclaimSnapshotState BlockStoreState; + { + RwLock::SharedLockScope ___(m_LocationMapLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&Timer, &WriteBlockTimeUs, &WriteBlockLongestTimeUs] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + WriteBlockTimeUs += ElapsedUs; + WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs); + }); + LocationMap = m_LocationMap; + BlockStoreState = m_BlockStore.GetReclaimSnapshotState(); + } + + uint64_t TotalChunkCount = LocationMap.size(); + + std::vector<IoHash> TotalChunkHashes; + TotalChunkHashes.reserve(TotalChunkCount); + for (const auto& Entry : LocationMap) + { + TotalChunkHashes.push_back(Entry.first); + } + + std::vector<BlockStoreLocation> ChunkLocations; + BlockStore::ChunkIndexArray KeepChunkIndexes; + std::vector<IoHash> ChunkIndexToChunkHash; + ChunkLocations.reserve(TotalChunkCount); + ChunkIndexToChunkHash.reserve(TotalChunkCount); + + GcCtx.FilterCids(TotalChunkHashes, [&](const IoHash& ChunkHash, bool Keep) { + auto KeyIt = LocationMap.find(ChunkHash); + const BlockStoreDiskLocation& DiskLocation = KeyIt->second; + BlockStoreLocation Location = DiskLocation.Get(m_PayloadAlignment); + size_t ChunkIndex = ChunkLocations.size(); + + ChunkLocations.push_back(Location); + ChunkIndexToChunkHash[ChunkIndex] = ChunkHash; + if (Keep) + { + KeepChunkIndexes.push_back(ChunkIndex); + } + }); + + const bool PerformDelete = GcCtx.IsDeletionMode() && GcCtx.CollectSmallObjects(); + if (!PerformDelete) + { + m_BlockStore.ReclaimSpace(BlockStoreState, ChunkLocations, KeepChunkIndexes, m_PayloadAlignment, true); + return; + } + + std::vector<IoHash> DeletedChunks; + m_BlockStore.ReclaimSpace( + BlockStoreState, + ChunkLocations, + KeepChunkIndexes, + m_PayloadAlignment, + false, + [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& RemovedChunks) { + std::vector<CasDiskIndexEntry> LogEntries; + LogEntries.reserve(MovedChunks.size() + RemovedChunks.size()); + for (const auto& Entry : MovedChunks) + { + size_t ChunkIndex = Entry.first; + const BlockStoreLocation& NewLocation = Entry.second; + const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex]; + LogEntries.push_back({.Key = ChunkHash, .Location = {NewLocation, m_PayloadAlignment}}); + } + for (const size_t ChunkIndex : RemovedChunks) + { + const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex]; + const BlockStoreDiskLocation& OldDiskLocation = LocationMap[ChunkHash]; + LogEntries.push_back({.Key = ChunkHash, .Location = OldDiskLocation, .Flags = CasDiskIndexEntry::kTombstone}); + DeletedChunks.push_back(ChunkHash); + } + + m_CasLog.Append(LogEntries); + m_CasLog.Flush(); + { + RwLock::ExclusiveLockScope __(m_LocationMapLock); + Stopwatch Timer; + const auto ____ = MakeGuard([&] { + uint64_t ElapsedUs = Timer.GetElapsedTimeUs(); + ReadBlockTimeUs += ElapsedUs; + ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs); + }); + for (const CasDiskIndexEntry& Entry : LogEntries) + { + if (Entry.Flags & CasDiskIndexEntry::kTombstone) + { + m_LocationMap.erase(Entry.Key); + continue; + } + m_LocationMap[Entry.Key] = Entry.Location; + } + } + }, + [&GcCtx]() { return GcCtx.CollectSmallObjects(); }); + + GcCtx.AddDeletedCids(DeletedChunks); +} + +void +CasContainerStrategy::MakeIndexSnapshot() +{ + uint64_t LogCount = m_CasLog.GetLogCount(); + if (m_LogFlushPosition == LogCount) + { + return; + } + + ZEN_DEBUG("write store snapshot for '{}'", m_RootDirectory / m_ContainerBaseName); + uint64_t EntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}", + m_RootDirectory / m_ContainerBaseName, + EntryCount, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + namespace fs = std::filesystem; + + fs::path IndexPath = GetIndexPath(m_RootDirectory, m_ContainerBaseName); + fs::path TempIndexPath = GetTempIndexPath(m_RootDirectory, m_ContainerBaseName); + + // Move index away, we keep it if something goes wrong + if (fs::is_regular_file(TempIndexPath)) + { + fs::remove(TempIndexPath); + } + if (fs::is_regular_file(IndexPath)) + { + fs::rename(IndexPath, TempIndexPath); + } + + try + { + // Write the current state of the location map to a new index state + std::vector<CasDiskIndexEntry> Entries; + + { + RwLock::SharedLockScope ___(m_LocationMapLock); + Entries.resize(m_LocationMap.size()); + + uint64_t EntryIndex = 0; + for (auto& Entry : m_LocationMap) + { + CasDiskIndexEntry& IndexEntry = Entries[EntryIndex++]; + IndexEntry.Key = Entry.first; + IndexEntry.Location = Entry.second; + } + } + + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate); + CasDiskIndexHeader Header = {.EntryCount = Entries.size(), + .LogPosition = LogCount, + .PayloadAlignment = gsl::narrow<uint32_t>(m_PayloadAlignment)}; + + Header.Checksum = CasDiskIndexHeader::ComputeChecksum(Header); + + ObjectIndexFile.Write(&Header, sizeof(CasDiskIndexEntry), 0); + ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(CasDiskIndexEntry), sizeof(CasDiskIndexEntry)); + ObjectIndexFile.Flush(); + ObjectIndexFile.Close(); + EntryCount = Entries.size(); + m_LogFlushPosition = LogCount; + } + catch (std::exception& Err) + { + ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what()); + + // Restore any previous snapshot + + if (fs::is_regular_file(TempIndexPath)) + { + fs::remove(IndexPath); + fs::rename(TempIndexPath, IndexPath); + } + } + if (fs::is_regular_file(TempIndexPath)) + { + fs::remove(TempIndexPath); + } +} + +uint64_t +CasContainerStrategy::ReadIndexFile() +{ + std::vector<CasDiskIndexEntry> Entries; + std::filesystem::path IndexPath = GetIndexPath(m_RootDirectory, m_ContainerBaseName); + if (std::filesystem::is_regular_file(IndexPath)) + { + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' index containing {} entries in {}", + IndexPath, + Entries.size(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead); + uint64_t Size = ObjectIndexFile.FileSize(); + if (Size >= sizeof(CasDiskIndexHeader)) + { + uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CasDiskIndexHeader))) / sizeof(CasDiskIndexEntry); + CasDiskIndexHeader Header; + ObjectIndexFile.Read(&Header, sizeof(Header), 0); + if ((Header.Magic == CasDiskIndexHeader::ExpectedMagic) && (Header.Version == CasDiskIndexHeader::CurrentVersion) && + (Header.Checksum == CasDiskIndexHeader::ComputeChecksum(Header)) && (Header.PayloadAlignment > 0) && + (Header.EntryCount <= ExpectedEntryCount)) + { + Entries.resize(Header.EntryCount); + ObjectIndexFile.Read(Entries.data(), Header.EntryCount * sizeof(CasDiskIndexEntry), sizeof(CasDiskIndexHeader)); + m_PayloadAlignment = Header.PayloadAlignment; + + std::string InvalidEntryReason; + for (const CasDiskIndexEntry& Entry : Entries) + { + if (!ValidateEntry(Entry, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason); + continue; + } + m_LocationMap[Entry.Key] = Entry.Location; + } + + return Header.LogPosition; + } + else + { + ZEN_WARN("skipping invalid index file '{}'", IndexPath); + } + } + } + return 0; +} + +uint64_t +CasContainerStrategy::ReadLog(uint64_t SkipEntryCount) +{ + std::filesystem::path LogPath = GetLogPath(m_RootDirectory, m_ContainerBaseName); + if (std::filesystem::is_regular_file(LogPath)) + { + size_t LogEntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' log containing {} entries in {}", + m_RootDirectory / m_ContainerBaseName, + LogEntryCount, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + TCasLogFile<CasDiskIndexEntry> CasLog; + CasLog.Open(LogPath, CasLogFile::Mode::kRead); + if (CasLog.Initialize()) + { + uint64_t EntryCount = CasLog.GetLogCount(); + if (EntryCount < SkipEntryCount) + { + ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath); + SkipEntryCount = 0; + } + LogEntryCount = EntryCount - SkipEntryCount; + CasLog.Replay( + [&](const CasDiskIndexEntry& Record) { + LogEntryCount++; + std::string InvalidEntryReason; + if (Record.Flags & CasDiskIndexEntry::kTombstone) + { + m_LocationMap.erase(Record.Key); + return; + } + if (!ValidateEntry(Record, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason); + return; + } + m_LocationMap[Record.Key] = Record.Location; + }, + SkipEntryCount); + return LogEntryCount; + } + } + return 0; +} + +void +CasContainerStrategy::OpenContainer(bool IsNewStore) +{ + // Add .running file and delete on clean on close to detect bad termination + + m_LocationMap.clear(); + + std::filesystem::path BasePath = GetBasePath(m_RootDirectory, m_ContainerBaseName); + + if (IsNewStore) + { + std::filesystem::remove_all(BasePath); + } + + m_LogFlushPosition = ReadIndexFile(); + uint64_t LogEntryCount = ReadLog(m_LogFlushPosition); + + CreateDirectories(BasePath); + + std::filesystem::path LogPath = GetLogPath(m_RootDirectory, m_ContainerBaseName); + m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite); + + std::vector<BlockStoreLocation> KnownLocations; + KnownLocations.reserve(m_LocationMap.size()); + for (const auto& Entry : m_LocationMap) + { + const BlockStoreDiskLocation& Location = Entry.second; + KnownLocations.push_back(Location.Get(m_PayloadAlignment)); + } + + m_BlockStore.Initialize(m_BlocksBasePath, m_MaxBlockSize, BlockStoreDiskLocation::MaxBlockIndex + 1, KnownLocations); + + if (IsNewStore || (LogEntryCount > 0)) + { + MakeIndexSnapshot(); + } + + // TODO: should validate integrity of container files here +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace { + static IoBuffer CreateRandomChunk(uint64_t Size) + { + static std::random_device rd; + static std::mt19937 g(rd()); + + std::vector<uint8_t> Values; + Values.resize(Size); + for (size_t Idx = 0; Idx < Size; ++Idx) + { + Values[Idx] = static_cast<uint8_t>(Idx); + } + std::shuffle(Values.begin(), Values.end(), g); + + return IoBufferBuilder::MakeCloneFromMemory(Values.data(), Values.size()); + } +} // namespace + +TEST_CASE("compactcas.hex") +{ + uint32_t Value; + std::string HexString; + CHECK(!ParseHexNumber("", Value)); + char Hex[9]; + + ToHexNumber(0u, Hex); + HexString = std::string(Hex); + CHECK(ParseHexNumber(HexString, Value)); + CHECK(Value == 0u); + + ToHexNumber(std::numeric_limits<std::uint32_t>::max(), Hex); + HexString = std::string(Hex); + CHECK(HexString == "ffffffff"); + CHECK(ParseHexNumber(HexString, Value)); + CHECK(Value == std::numeric_limits<std::uint32_t>::max()); + + ToHexNumber(0xadf14711u, Hex); + HexString = std::string(Hex); + CHECK(HexString == "adf14711"); + CHECK(ParseHexNumber(HexString, Value)); + CHECK(Value == 0xadf14711u); + + ToHexNumber(0x80000000u, Hex); + HexString = std::string(Hex); + CHECK(HexString == "80000000"); + CHECK(ParseHexNumber(HexString, Value)); + CHECK(Value == 0x80000000u); + + ToHexNumber(0x718293a4u, Hex); + HexString = std::string(Hex); + CHECK(HexString == "718293a4"); + CHECK(ParseHexNumber(HexString, Value)); + CHECK(Value == 0x718293a4u); +} + +TEST_CASE("compactcas.compact.gc") +{ + ScopedTemporaryDirectory TempDir; + + const int kIterationCount = 1000; + + std::vector<IoHash> Keys(kIterationCount); + + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 65536, 16, true); + + for (int i = 0; i < kIterationCount; ++i) + { + CbObjectWriter Cbo; + Cbo << "id" << i; + CbObject Obj = Cbo.Save(); + + IoBuffer ObjBuffer = Obj.GetBuffer().AsIoBuffer(); + const IoHash Hash = HashBuffer(ObjBuffer); + + Cas.InsertChunk(ObjBuffer, Hash); + + Keys[i] = Hash; + } + + for (int i = 0; i < kIterationCount; ++i) + { + IoBuffer Chunk = Cas.FindChunk(Keys[i]); + + CHECK(!!Chunk); + + CbObject Value = LoadCompactBinaryObject(Chunk); + + CHECK_EQ(Value["id"].AsInt32(), i); + } + } + + // Validate that we can still read the inserted data after closing + // the original cas store + + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 65536, 16, false); + + for (int i = 0; i < kIterationCount; ++i) + { + IoBuffer Chunk = Cas.FindChunk(Keys[i]); + + CHECK(!!Chunk); + + CbObject Value = LoadCompactBinaryObject(Chunk); + + CHECK_EQ(Value["id"].AsInt32(), i); + } + } +} + +TEST_CASE("compactcas.compact.totalsize") +{ + std::random_device rd; + std::mt19937 g(rd()); + + // for (uint32_t i = 0; i < 100; ++i) + { + ScopedTemporaryDirectory TempDir; + + const uint64_t kChunkSize = 1024; + const int32_t kChunkCount = 16; + + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 65536, 16, true); + + for (int32_t Idx = 0; Idx < kChunkCount; ++Idx) + { + IoBuffer Chunk = CreateRandomChunk(kChunkSize); + const IoHash Hash = HashBuffer(Chunk); + CasStore::InsertResult InsertResult = Cas.InsertChunk(Chunk, Hash); + ZEN_ASSERT(InsertResult.New); + } + + const uint64_t TotalSize = Cas.StorageSize().DiskSize; + CHECK_EQ(kChunkSize * kChunkCount, TotalSize); + } + + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 65536, 16, false); + + const uint64_t TotalSize = Cas.StorageSize().DiskSize; + CHECK_EQ(kChunkSize * kChunkCount, TotalSize); + } + + // Re-open again, this time we should have a snapshot + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 65536, 16, false); + + const uint64_t TotalSize = Cas.StorageSize().DiskSize; + CHECK_EQ(kChunkSize * kChunkCount, TotalSize); + } + } +} + +TEST_CASE("compactcas.gc.basic") +{ + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "cb", 65536, 1 << 4, true); + + IoBuffer Chunk = CreateRandomChunk(128); + IoHash ChunkHash = IoHash::HashBuffer(Chunk); + + const CasStore::InsertResult InsertResult = Cas.InsertChunk(Chunk, ChunkHash); + CHECK(InsertResult.New); + Cas.Flush(); + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + + Cas.CollectGarbage(GcCtx); + + CHECK(!Cas.HaveChunk(ChunkHash)); +} + +TEST_CASE("compactcas.gc.removefile") +{ + ScopedTemporaryDirectory TempDir; + + IoBuffer Chunk = CreateRandomChunk(128); + IoHash ChunkHash = IoHash::HashBuffer(Chunk); + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "cb", 65536, 1 << 4, true); + + const CasStore::InsertResult InsertResult = Cas.InsertChunk(Chunk, ChunkHash); + CHECK(InsertResult.New); + const CasStore::InsertResult InsertResultDup = Cas.InsertChunk(Chunk, ChunkHash); + CHECK(!InsertResultDup.New); + Cas.Flush(); + } + + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "cb", 65536, 1 << 4, false); + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + + Cas.CollectGarbage(GcCtx); + + CHECK(!Cas.HaveChunk(ChunkHash)); +} + +TEST_CASE("compactcas.gc.compact") +{ + // for (uint32_t i = 0; i < 100; ++i) + { + ScopedTemporaryDirectory TempDir; + + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "cb", 2048, 1 << 4, true); + + uint64_t ChunkSizes[9] = {128, 541, 1023, 781, 218, 37, 4, 997, 5}; + std::vector<IoBuffer> Chunks; + Chunks.reserve(9); + for (uint64_t Size : ChunkSizes) + { + Chunks.push_back(CreateRandomChunk(Size)); + } + + std::vector<IoHash> ChunkHashes; + ChunkHashes.reserve(9); + for (const IoBuffer& Chunk : Chunks) + { + ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size())); + } + + CHECK(Cas.InsertChunk(Chunks[0], ChunkHashes[0]).New); + CHECK(Cas.InsertChunk(Chunks[1], ChunkHashes[1]).New); + CHECK(Cas.InsertChunk(Chunks[2], ChunkHashes[2]).New); + CHECK(Cas.InsertChunk(Chunks[3], ChunkHashes[3]).New); + CHECK(Cas.InsertChunk(Chunks[4], ChunkHashes[4]).New); + CHECK(Cas.InsertChunk(Chunks[5], ChunkHashes[5]).New); + CHECK(Cas.InsertChunk(Chunks[6], ChunkHashes[6]).New); + CHECK(Cas.InsertChunk(Chunks[7], ChunkHashes[7]).New); + CHECK(Cas.InsertChunk(Chunks[8], ChunkHashes[8]).New); + + CHECK(Cas.HaveChunk(ChunkHashes[0])); + CHECK(Cas.HaveChunk(ChunkHashes[1])); + CHECK(Cas.HaveChunk(ChunkHashes[2])); + CHECK(Cas.HaveChunk(ChunkHashes[3])); + CHECK(Cas.HaveChunk(ChunkHashes[4])); + CHECK(Cas.HaveChunk(ChunkHashes[5])); + CHECK(Cas.HaveChunk(ChunkHashes[6])); + CHECK(Cas.HaveChunk(ChunkHashes[7])); + CHECK(Cas.HaveChunk(ChunkHashes[8])); + + // Keep first and last + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[0]); + KeepChunks.push_back(ChunkHashes[8]); + GcCtx.AddRetainedCids(KeepChunks); + + Cas.Flush(); + Cas.CollectGarbage(GcCtx); + + CHECK(Cas.HaveChunk(ChunkHashes[0])); + CHECK(!Cas.HaveChunk(ChunkHashes[1])); + CHECK(!Cas.HaveChunk(ChunkHashes[2])); + CHECK(!Cas.HaveChunk(ChunkHashes[3])); + CHECK(!Cas.HaveChunk(ChunkHashes[4])); + CHECK(!Cas.HaveChunk(ChunkHashes[5])); + CHECK(!Cas.HaveChunk(ChunkHashes[6])); + CHECK(!Cas.HaveChunk(ChunkHashes[7])); + CHECK(Cas.HaveChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[0] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[0]))); + CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8]))); + + Cas.InsertChunk(Chunks[1], ChunkHashes[1]); + Cas.InsertChunk(Chunks[2], ChunkHashes[2]); + Cas.InsertChunk(Chunks[3], ChunkHashes[3]); + Cas.InsertChunk(Chunks[4], ChunkHashes[4]); + Cas.InsertChunk(Chunks[5], ChunkHashes[5]); + Cas.InsertChunk(Chunks[6], ChunkHashes[6]); + Cas.InsertChunk(Chunks[7], ChunkHashes[7]); + } + + // Keep last + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[8]); + GcCtx.AddRetainedCids(KeepChunks); + + Cas.Flush(); + Cas.CollectGarbage(GcCtx); + + CHECK(!Cas.HaveChunk(ChunkHashes[0])); + CHECK(!Cas.HaveChunk(ChunkHashes[1])); + CHECK(!Cas.HaveChunk(ChunkHashes[2])); + CHECK(!Cas.HaveChunk(ChunkHashes[3])); + CHECK(!Cas.HaveChunk(ChunkHashes[4])); + CHECK(!Cas.HaveChunk(ChunkHashes[5])); + CHECK(!Cas.HaveChunk(ChunkHashes[6])); + CHECK(!Cas.HaveChunk(ChunkHashes[7])); + CHECK(Cas.HaveChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8]))); + + Cas.InsertChunk(Chunks[1], ChunkHashes[1]); + Cas.InsertChunk(Chunks[2], ChunkHashes[2]); + Cas.InsertChunk(Chunks[3], ChunkHashes[3]); + Cas.InsertChunk(Chunks[4], ChunkHashes[4]); + Cas.InsertChunk(Chunks[5], ChunkHashes[5]); + Cas.InsertChunk(Chunks[6], ChunkHashes[6]); + Cas.InsertChunk(Chunks[7], ChunkHashes[7]); + } + + // Keep mixed + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[1]); + KeepChunks.push_back(ChunkHashes[4]); + KeepChunks.push_back(ChunkHashes[7]); + GcCtx.AddRetainedCids(KeepChunks); + + Cas.Flush(); + Cas.CollectGarbage(GcCtx); + + CHECK(!Cas.HaveChunk(ChunkHashes[0])); + CHECK(Cas.HaveChunk(ChunkHashes[1])); + CHECK(!Cas.HaveChunk(ChunkHashes[2])); + CHECK(!Cas.HaveChunk(ChunkHashes[3])); + CHECK(Cas.HaveChunk(ChunkHashes[4])); + CHECK(!Cas.HaveChunk(ChunkHashes[5])); + CHECK(!Cas.HaveChunk(ChunkHashes[6])); + CHECK(Cas.HaveChunk(ChunkHashes[7])); + CHECK(!Cas.HaveChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[1] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[1]))); + CHECK(ChunkHashes[4] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[4]))); + CHECK(ChunkHashes[7] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[7]))); + + Cas.InsertChunk(Chunks[0], ChunkHashes[0]); + Cas.InsertChunk(Chunks[2], ChunkHashes[2]); + Cas.InsertChunk(Chunks[3], ChunkHashes[3]); + Cas.InsertChunk(Chunks[5], ChunkHashes[5]); + Cas.InsertChunk(Chunks[6], ChunkHashes[6]); + Cas.InsertChunk(Chunks[8], ChunkHashes[8]); + } + + // Keep multiple at end + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[6]); + KeepChunks.push_back(ChunkHashes[7]); + KeepChunks.push_back(ChunkHashes[8]); + GcCtx.AddRetainedCids(KeepChunks); + + Cas.Flush(); + Cas.CollectGarbage(GcCtx); + + CHECK(!Cas.HaveChunk(ChunkHashes[0])); + CHECK(!Cas.HaveChunk(ChunkHashes[1])); + CHECK(!Cas.HaveChunk(ChunkHashes[2])); + CHECK(!Cas.HaveChunk(ChunkHashes[3])); + CHECK(!Cas.HaveChunk(ChunkHashes[4])); + CHECK(!Cas.HaveChunk(ChunkHashes[5])); + CHECK(Cas.HaveChunk(ChunkHashes[6])); + CHECK(Cas.HaveChunk(ChunkHashes[7])); + CHECK(Cas.HaveChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[6] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[6]))); + CHECK(ChunkHashes[7] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[7]))); + CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8]))); + + Cas.InsertChunk(Chunks[0], ChunkHashes[0]); + Cas.InsertChunk(Chunks[1], ChunkHashes[1]); + Cas.InsertChunk(Chunks[2], ChunkHashes[2]); + Cas.InsertChunk(Chunks[3], ChunkHashes[3]); + Cas.InsertChunk(Chunks[4], ChunkHashes[4]); + Cas.InsertChunk(Chunks[5], ChunkHashes[5]); + } + + // Keep every other + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[0]); + KeepChunks.push_back(ChunkHashes[2]); + KeepChunks.push_back(ChunkHashes[4]); + KeepChunks.push_back(ChunkHashes[6]); + KeepChunks.push_back(ChunkHashes[8]); + GcCtx.AddRetainedCids(KeepChunks); + + Cas.Flush(); + Cas.CollectGarbage(GcCtx); + + CHECK(Cas.HaveChunk(ChunkHashes[0])); + CHECK(!Cas.HaveChunk(ChunkHashes[1])); + CHECK(Cas.HaveChunk(ChunkHashes[2])); + CHECK(!Cas.HaveChunk(ChunkHashes[3])); + CHECK(Cas.HaveChunk(ChunkHashes[4])); + CHECK(!Cas.HaveChunk(ChunkHashes[5])); + CHECK(Cas.HaveChunk(ChunkHashes[6])); + CHECK(!Cas.HaveChunk(ChunkHashes[7])); + CHECK(Cas.HaveChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[0] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[0]))); + CHECK(ChunkHashes[2] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[2]))); + CHECK(ChunkHashes[4] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[4]))); + CHECK(ChunkHashes[6] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[6]))); + CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8]))); + + Cas.InsertChunk(Chunks[1], ChunkHashes[1]); + Cas.InsertChunk(Chunks[3], ChunkHashes[3]); + Cas.InsertChunk(Chunks[5], ChunkHashes[5]); + Cas.InsertChunk(Chunks[7], ChunkHashes[7]); + } + + // Verify that we nicely appended blocks even after all GC operations + CHECK(ChunkHashes[0] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[0]))); + CHECK(ChunkHashes[1] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[1]))); + CHECK(ChunkHashes[2] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[2]))); + CHECK(ChunkHashes[3] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[3]))); + CHECK(ChunkHashes[4] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[4]))); + CHECK(ChunkHashes[5] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[5]))); + CHECK(ChunkHashes[6] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[6]))); + CHECK(ChunkHashes[7] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[7]))); + CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8]))); + } +} + +TEST_CASE("compactcas.gc.deleteblockonopen") +{ + ScopedTemporaryDirectory TempDir; + + uint64_t ChunkSizes[20] = {128, 541, 311, 181, 218, 37, 4, 397, 5, 92, 551, 721, 31, 92, 16, 99, 131, 41, 541, 84}; + std::vector<IoBuffer> Chunks; + Chunks.reserve(20); + for (uint64_t Size : ChunkSizes) + { + Chunks.push_back(CreateRandomChunk(Size)); + } + + std::vector<IoHash> ChunkHashes; + ChunkHashes.reserve(20); + for (const IoBuffer& Chunk : Chunks) + { + ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size())); + } + + { + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 1024, 16, true); + + for (size_t i = 0; i < 20; i++) + { + CHECK(Cas.InsertChunk(Chunks[i], ChunkHashes[i]).New); + } + + // GC every other block + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + for (size_t i = 0; i < 20; i += 2) + { + KeepChunks.push_back(ChunkHashes[i]); + } + GcCtx.AddRetainedCids(KeepChunks); + + Cas.Flush(); + Cas.CollectGarbage(GcCtx); + + for (size_t i = 0; i < 20; i += 2) + { + CHECK(Cas.HaveChunk(ChunkHashes[i])); + CHECK(!Cas.HaveChunk(ChunkHashes[i + 1])); + CHECK(ChunkHashes[i] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[i]))); + } + } + } + { + // Re-open + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 1024, 16, false); + + for (size_t i = 0; i < 20; i += 2) + { + CHECK(Cas.HaveChunk(ChunkHashes[i])); + CHECK(!Cas.HaveChunk(ChunkHashes[i + 1])); + CHECK(ChunkHashes[i] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[i]))); + } + } +} + +TEST_CASE("compactcas.gc.handleopeniobuffer") +{ + ScopedTemporaryDirectory TempDir; + + uint64_t ChunkSizes[20] = {128, 541, 311, 181, 218, 37, 4, 397, 5, 92, 551, 721, 31, 92, 16, 99, 131, 41, 541, 84}; + std::vector<IoBuffer> Chunks; + Chunks.reserve(20); + for (const uint64_t& Size : ChunkSizes) + { + Chunks.push_back(CreateRandomChunk(Size)); + } + + std::vector<IoHash> ChunkHashes; + ChunkHashes.reserve(20); + for (const IoBuffer& Chunk : Chunks) + { + ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size())); + } + + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 1024, 16, true); + + for (size_t i = 0; i < 20; i++) + { + CHECK(Cas.InsertChunk(Chunks[i], ChunkHashes[i]).New); + } + + IoBuffer RetainChunk = Cas.FindChunk(ChunkHashes[5]); + Cas.Flush(); + + // GC everything + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + Cas.CollectGarbage(GcCtx); + + for (size_t i = 0; i < 20; i++) + { + CHECK(!Cas.HaveChunk(ChunkHashes[i])); + } + + CHECK(ChunkHashes[5] == IoHash::HashBuffer(RetainChunk)); +} + +TEST_CASE("compactcas.threadedinsert") +{ + // for (uint32_t i = 0; i < 100; ++i) + { + ScopedTemporaryDirectory TempDir; + + const uint64_t kChunkSize = 1048; + const int32_t kChunkCount = 4096; + uint64_t ExpectedSize = 0; + + std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> Chunks; + Chunks.reserve(kChunkCount); + + for (int32_t Idx = 0; Idx < kChunkCount; ++Idx) + { + while (true) + { + IoBuffer Chunk = CreateRandomChunk(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + if (Chunks.contains(Hash)) + { + continue; + } + Chunks[Hash] = Chunk; + ExpectedSize += Chunk.Size(); + break; + } + } + + std::atomic<size_t> WorkCompleted = 0; + WorkerThreadPool ThreadPool(4); + GcManager Gc; + CasContainerStrategy Cas(Gc); + Cas.Initialize(TempDir.Path(), "test", 32768, 16, true); + { + for (const auto& Chunk : Chunks) + { + const IoHash& Hash = Chunk.first; + const IoBuffer& Buffer = Chunk.second; + ThreadPool.ScheduleWork([&Cas, &WorkCompleted, Buffer, Hash]() { + CasStore::InsertResult InsertResult = Cas.InsertChunk(Buffer, Hash); + ZEN_ASSERT(InsertResult.New); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + } + + WorkCompleted = 0; + const uint64_t TotalSize = Cas.StorageSize().DiskSize; + CHECK_LE(ExpectedSize, TotalSize); + CHECK_GE(ExpectedSize + 32768, TotalSize); + + { + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Cas, &WorkCompleted, &Chunk]() { + IoHash ChunkHash = Chunk.first; + IoBuffer Buffer = Cas.FindChunk(ChunkHash); + IoHash Hash = IoHash::HashBuffer(Buffer); + CHECK(ChunkHash == Hash); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < Chunks.size()) + { + Sleep(1); + } + } + + std::unordered_set<IoHash, IoHash::Hasher> GcChunkHashes; + GcChunkHashes.reserve(Chunks.size()); + for (const auto& Chunk : Chunks) + { + GcChunkHashes.insert(Chunk.first); + } + { + WorkCompleted = 0; + std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> NewChunks; + NewChunks.reserve(kChunkCount); + + for (int32_t Idx = 0; Idx < kChunkCount; ++Idx) + { + IoBuffer Chunk = CreateRandomChunk(kChunkSize); + IoHash Hash = HashBuffer(Chunk); + NewChunks[Hash] = Chunk; + } + + std::atomic_uint32_t AddedChunkCount; + + for (const auto& Chunk : NewChunks) + { + ThreadPool.ScheduleWork([&Cas, &WorkCompleted, Chunk, &AddedChunkCount]() { + Cas.InsertChunk(Chunk.second, Chunk.first); + AddedChunkCount.fetch_add(1); + WorkCompleted.fetch_add(1); + }); + } + for (const auto& Chunk : Chunks) + { + ThreadPool.ScheduleWork([&Cas, &WorkCompleted, Chunk]() { + IoHash ChunkHash = Chunk.first; + IoBuffer Buffer = Cas.FindChunk(ChunkHash); + if (Buffer) + { + CHECK(ChunkHash == IoHash::HashBuffer(Buffer)); + } + WorkCompleted.fetch_add(1); + }); + } + + while (AddedChunkCount.load() < NewChunks.size()) + { + // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope + for (const auto& Chunk : NewChunks) + { + if (Cas.HaveChunk(Chunk.first)) + { + GcChunkHashes.emplace(Chunk.first); + } + } + std::vector<IoHash> KeepHashes(GcChunkHashes.begin(), GcChunkHashes.end()); + size_t C = 0; + while (C < KeepHashes.size()) + { + if (C % 155 == 0) + { + if (C < KeepHashes.size() - 1) + { + KeepHashes[C] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + if (C + 3 < KeepHashes.size() - 1) + { + KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + } + C++; + } + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + GcCtx.AddRetainedCids(KeepHashes); + Cas.CollectGarbage(GcCtx); + const HashKeySet& Deleted = GcCtx.DeletedCids(); + Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); }); + } + + while (WorkCompleted < NewChunks.size() + Chunks.size()) + { + Sleep(1); + } + + // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope + for (const auto& Chunk : NewChunks) + { + if (Cas.HaveChunk(Chunk.first)) + { + GcChunkHashes.emplace(Chunk.first); + } + } + std::vector<IoHash> KeepHashes(GcChunkHashes.begin(), GcChunkHashes.end()); + size_t C = 0; + while (C < KeepHashes.size()) + { + if (C % 155 == 0) + { + if (C < KeepHashes.size() - 1) + { + KeepHashes[C] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + if (C + 3 < KeepHashes.size() - 1) + { + KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1]; + KeepHashes.pop_back(); + } + } + C++; + } + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + GcCtx.AddRetainedCids(KeepHashes); + Cas.CollectGarbage(GcCtx); + const HashKeySet& Deleted = GcCtx.DeletedCids(); + Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); }); + } + { + WorkCompleted = 0; + for (const IoHash& ChunkHash : GcChunkHashes) + { + ThreadPool.ScheduleWork([&Cas, &WorkCompleted, ChunkHash]() { + CHECK(Cas.HaveChunk(ChunkHash)); + CHECK(ChunkHash == IoHash::HashBuffer(Cas.FindChunk(ChunkHash))); + WorkCompleted.fetch_add(1); + }); + } + while (WorkCompleted < GcChunkHashes.size()) + { + Sleep(1); + } + } + } +} + +#endif + +void +compactcas_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenstore/compactcas.h b/src/zenstore/compactcas.h new file mode 100644 index 000000000..b0c6699eb --- /dev/null +++ b/src/zenstore/compactcas.h @@ -0,0 +1,95 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> +#include <zenstore/blockstore.h> +#include <zenstore/caslog.h> +#include <zenstore/gc.h> + +#include "cas.h" + +#include <atomic> +#include <limits> +#include <unordered_map> + +namespace spdlog { +class logger; +} + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +#pragma pack(push) +#pragma pack(1) + +struct CasDiskIndexEntry +{ + static const uint8_t kTombstone = 0x01; + + IoHash Key; + BlockStoreDiskLocation Location; + ZenContentType ContentType = ZenContentType::kUnknownContentType; + uint8_t Flags = 0; +}; + +#pragma pack(pop) + +static_assert(sizeof(CasDiskIndexEntry) == 32); + +/** This implements a storage strategy for small CAS values + * + * New chunks are simply appended to a small object file, and an index is + * maintained to allow chunks to be looked up within the active small object + * files + * + */ + +struct CasContainerStrategy final : public GcStorage +{ + CasContainerStrategy(GcManager& Gc); + ~CasContainerStrategy(); + + CasStore::InsertResult InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash); + IoBuffer FindChunk(const IoHash& ChunkHash); + bool HaveChunk(const IoHash& ChunkHash); + void FilterChunks(HashKeySet& InOutChunks); + void Initialize(const std::filesystem::path& RootDirectory, + const std::string_view ContainerBaseName, + uint32_t MaxBlockSize, + uint64_t Alignment, + bool IsNewStore); + void Flush(); + void Scrub(ScrubContext& Ctx); + virtual void CollectGarbage(GcContext& GcCtx) override; + virtual GcStorageSize StorageSize() const override { return {.DiskSize = m_BlockStore.TotalSize()}; } + +private: + CasStore::InsertResult InsertChunk(const void* ChunkData, size_t ChunkSize, const IoHash& ChunkHash); + void MakeIndexSnapshot(); + uint64_t ReadIndexFile(); + uint64_t ReadLog(uint64_t SkipEntryCount); + void OpenContainer(bool IsNewStore); + + spdlog::logger& Log() { return m_Log; } + + std::filesystem::path m_RootDirectory; + spdlog::logger& m_Log; + uint64_t m_PayloadAlignment = 1u << 4; + uint64_t m_MaxBlockSize = 1u << 28; + bool m_IsInitialized = false; + TCasLogFile<CasDiskIndexEntry> m_CasLog; + uint64_t m_LogFlushPosition = 0; + std::string m_ContainerBaseName; + std::filesystem::path m_BlocksBasePath; + BlockStore m_BlockStore; + + RwLock m_LocationMapLock; + typedef std::unordered_map<IoHash, BlockStoreDiskLocation, IoHash::Hasher> LocationMap_t; + LocationMap_t m_LocationMap; +}; + +void compactcas_forcelink(); + +} // namespace zen diff --git a/src/zenstore/filecas.cpp b/src/zenstore/filecas.cpp new file mode 100644 index 000000000..1d25920c4 --- /dev/null +++ b/src/zenstore/filecas.cpp @@ -0,0 +1,1452 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "filecas.h" + +#include <zencore/compress.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/memory.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> +#include <zencore/thread.h> +#include <zencore/timer.h> +#include <zencore/uid.h> +#include <zenstore/gc.h> +#include <zenstore/scrubcontext.h> +#include <zenutil/basicfile.h> + +#if ZEN_WITH_TESTS +# include <zencore/compactbinarybuilder.h> +#endif + +#include <gsl/gsl-lite.hpp> + +#include <barrier> +#include <filesystem> +#include <functional> +#include <unordered_map> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <xxhash.h> +#if ZEN_PLATFORM_WINDOWS +# include <atlfile.h> +#endif +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen { + +namespace filecas::impl { + const char* IndexExtension = ".uidx"; + const char* LogExtension = ".ulog"; + + std::filesystem::path GetIndexPath(const std::filesystem::path& RootDir) { return RootDir / fmt::format("cas{}", IndexExtension); } + + std::filesystem::path GetTempIndexPath(const std::filesystem::path& RootDir) + { + return RootDir / fmt::format("cas.tmp{}", IndexExtension); + } + + std::filesystem::path GetLogPath(const std::filesystem::path& RootDir) { return RootDir / fmt::format("cas{}", LogExtension); } + +#pragma pack(push) +#pragma pack(1) + + struct FileCasIndexHeader + { + static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx'; + static constexpr uint32_t CurrentVersion = 1; + + uint32_t Magic = ExpectedMagic; + uint32_t Version = CurrentVersion; + uint64_t EntryCount = 0; + uint64_t LogPosition = 0; + uint32_t Reserved = 0; + uint32_t Checksum = 0; + + static uint32_t ComputeChecksum(const FileCasIndexHeader& Header) + { + return XXH32(&Header.Magic, sizeof(FileCasIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA); + } + }; + + static_assert(sizeof(FileCasIndexHeader) == 32); + +#pragma pack(pop) + +} // namespace filecas::impl + +FileCasStrategy::ShardingHelper::ShardingHelper(const std::filesystem::path& RootPath, const IoHash& ChunkHash) +{ + ShardedPath.Append(RootPath.c_str()); + + ExtendableStringBuilder<64> HashString; + ChunkHash.ToHexString(HashString); + + const char* str = HashString.c_str(); + + // Shard into a path with two directory levels containing 12 bits and 8 bits + // respectively. + // + // This results in a maximum of 4096 * 256 directories + // + // The numbers have been chosen somewhat arbitrarily but are large to scale + // to very large chunk repositories without creating too many directories + // on a single level since NTFS does not deal very well with this. + // + // It may or may not make sense to make this a configurable policy, and it + // would probably be a good idea to measure performance for different + // policies and chunk counts + + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str, str + 3); + + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str + 3, str + 5); + Shard2len = ShardedPath.Size(); + + ShardedPath.AppendSeparator(); + ShardedPath.AppendAsciiRange(str + 5, str + 40); +} + +////////////////////////////////////////////////////////////////////////// + +FileCasStrategy::FileCasStrategy(GcManager& Gc) : GcStorage(Gc), m_Log(logging::Get("filecas")) +{ +} + +FileCasStrategy::~FileCasStrategy() +{ +} + +void +FileCasStrategy::Initialize(const std::filesystem::path& RootDirectory, bool IsNewStore) +{ + using namespace filecas::impl; + + m_IsInitialized = true; + + m_RootDirectory = RootDirectory; + + m_Index.clear(); + + std::filesystem::path LogPath = GetLogPath(m_RootDirectory); + std::filesystem::path IndexPath = GetIndexPath(m_RootDirectory); + + if (IsNewStore) + { + std::filesystem::remove(LogPath); + std::filesystem::remove(IndexPath); + + if (std::filesystem::is_directory(m_RootDirectory)) + { + // We need to explicitly only delete sharded root folders as the cas manifest, tinyobject and smallobject cas folders may reside + // in this folder as well + struct Visitor : public FileSystemTraversal::TreeVisitor + { + virtual void VisitFile(const std::filesystem::path&, const path_view&, uint64_t) override + { + // We don't care about files + } + static bool IsHexChar(std::filesystem::path::value_type C) + { + return std::find(&HexChars[0], &HexChars[16], C) != &HexChars[16]; + } + virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent, + [[maybe_unused]] const path_view& DirectoryName) override + { + if (DirectoryName.length() == 3) + { + if (IsHexChar(DirectoryName[0]) && IsHexChar(DirectoryName[1]) && IsHexChar(DirectoryName[2])) + { + ShardedRoots.push_back(Parent / DirectoryName); + } + } + return false; + } + std::vector<std::filesystem::path> ShardedRoots; + } CasVisitor; + + FileSystemTraversal Traversal; + Traversal.TraverseFileSystem(m_RootDirectory, CasVisitor); + for (const std::filesystem::path& SharededRoot : CasVisitor.ShardedRoots) + { + std::filesystem::remove_all(SharededRoot); + } + } + } + + m_LogFlushPosition = ReadIndexFile(); + uint64_t LogEntryCount = ReadLog(m_LogFlushPosition); + for (const auto& Entry : m_Index) + { + m_TotalSize.fetch_add(Entry.second.Size, std::memory_order::relaxed); + } + + CreateDirectories(m_RootDirectory); + m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite); + + if (IsNewStore || LogEntryCount > 0) + { + MakeIndexSnapshot(); + } +} + +#if ZEN_PLATFORM_WINDOWS +static void +DeletePayloadFileOnClose(const void* FileHandle) +{ + const HANDLE WinFileHandle = (const HANDLE)FileHandle; + // This will cause the file to be deleted when the last handle to it is closed + FILE_DISPOSITION_INFO Fdi{}; + Fdi.DeleteFile = TRUE; + BOOL Success = SetFileInformationByHandle(WinFileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); + + if (!Success) + { + // TODO: We should provide information to this function to tell it if the payload is temporary or not and if we are allowed + // to delete it. + ZEN_WARN("Failed to flag CAS temporary payload file '{}' for deletion: '{}'", + PathFromHandle(WinFileHandle), + GetLastErrorAsString()); + } +} +#endif + +CasStore::InsertResult +FileCasStrategy::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, CasStore::InsertMode Mode) +{ + ZEN_ASSERT(m_IsInitialized); + +#if !ZEN_WITH_TESTS + ZEN_ASSERT(Chunk.GetContentType() == ZenContentType::kCompressedBinary); +#endif + + if (Mode == CasStore::InsertMode::kCopyOnly) + { + { + RwLock::SharedLockScope _(m_Lock); + if (m_Index.contains(ChunkHash)) + { + return CasStore::InsertResult{.New = false}; + } + } + return InsertChunk(Chunk.Data(), Chunk.Size(), ChunkHash); + } + + // File-based chunks have special case handling whereby we move the file into + // place in the file store directory, thus avoiding unnecessary copying + + IoBufferFileReference FileRef; + if (Chunk.IsWholeFile() && Chunk.GetFileReference(/* out */ FileRef)) + { + { + bool Exists = true; + { + RwLock::SharedLockScope _(m_Lock); + Exists = m_Index.contains(ChunkHash); + } + if (Exists) + { +#if ZEN_PLATFORM_WINDOWS + DeletePayloadFileOnClose(FileRef.FileHandle); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + std::filesystem::path FilePath = PathFromHandle(FileRef.FileHandle); + if (unlink(FilePath.c_str()) < 0) + { + int UnlinkError = zen::GetLastError(); + if (UnlinkError != ENOENT) + { + ZEN_WARN("Failed to unlink CAS temporary payload file '{}': '{}'", + FilePath.string(), + GetSystemErrorAsString(UnlinkError)); + } + } +#endif + return CasStore::InsertResult{.New = false}; + } + } + + ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash); + + RwLock::ExclusiveLockScope HashLock(LockForHash(ChunkHash)); + +#if ZEN_PLATFORM_WINDOWS + const HANDLE ChunkFileHandle = FileRef.FileHandle; + // See if file already exists + { + CAtlFile PayloadFile; + + if (HRESULT hRes = PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING); SUCCEEDED(hRes)) + { + // If we succeeded in opening the target file then we don't need to do anything else because it already exists + // and should contain the content we were about to insert + + // We do need to ensure the source file goes away on close, however + size_t ChunkSize = Chunk.GetSize(); + uint64_t FileSize = 0; + if (HRESULT hSizeRes = PayloadFile.GetSize(FileSize); SUCCEEDED(hSizeRes) && FileSize == ChunkSize) + { + HashLock.ReleaseNow(); + + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second; + } + if (IsNew) + + { + m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed); + } + + DeletePayloadFileOnClose(ChunkFileHandle); + + return CasStore::InsertResult{.New = IsNew}; + } + else + { + ZEN_WARN("get file size FAILED or file size mismatch of file cas '{}'. Expected {}, found {}. Trying to overwrite", + Name.ShardedPath.ToUtf8(), + ChunkSize, + FileSize); + } + } + else + { + if (hRes == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) + { + // Shard directory does not exist + } + else if (hRes == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) + { + // Shard directory exists, but not the file + } + else if (hRes == HRESULT_FROM_WIN32(ERROR_SHARING_VIOLATION)) + { + // Sharing violation, likely because we are trying to open a file + // which has been renamed on another thread, and the file handle + // used to rename it is still open. We handle this case below + // instead of here + } + else + { + ZEN_INFO("Unexpected error opening file '{}': {}", Name.ShardedPath.ToUtf8(), hRes); + } + } + } + + std::filesystem::path FullPath(Name.ShardedPath.c_str()); + + std::filesystem::path FilePath = FullPath.parent_path(); + std::wstring FileName = FullPath.native(); + + const DWORD BufferSize = sizeof(FILE_RENAME_INFO) + gsl::narrow<DWORD>(FileName.size() * sizeof(WCHAR)); + FILE_RENAME_INFO* RenameInfo = reinterpret_cast<FILE_RENAME_INFO*>(Memory::Alloc(BufferSize)); + memset(RenameInfo, 0, BufferSize); + + RenameInfo->ReplaceIfExists = FALSE; + RenameInfo->FileNameLength = gsl::narrow<DWORD>(FileName.size()); + memcpy(RenameInfo->FileName, FileName.c_str(), FileName.size() * sizeof(WCHAR)); + RenameInfo->FileName[FileName.size()] = 0; + + auto $ = MakeGuard([&] { Memory::Free(RenameInfo); }); + + // Try to move file into place + BOOL Success = SetFileInformationByHandle(ChunkFileHandle, FileRenameInfo, RenameInfo, BufferSize); + + if (!Success) + { + // The rename/move could fail because the target directory does not yet exist. This code attempts + // to create it + + CAtlFile DirHandle; + + auto InternalCreateDirectoryHandle = [&] { + return DirHandle.Create(FilePath.c_str(), + GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS); + }; + + // It's possible for several threads to enter this logic trying to create the same + // directory. Only one will create the directory of course, but all threads will + // make it through okay + + HRESULT hRes = InternalCreateDirectoryHandle(); + + if (FAILED(hRes)) + { + // TODO: we can handle directory creation more intelligently and efficiently than + // this currently does + + CreateDirectories(FilePath.c_str()); + + hRes = InternalCreateDirectoryHandle(); + } + + if (FAILED(hRes)) + { + ThrowSystemException(hRes, fmt::format("Failed to open shard directory '{}'", FilePath)); + } + + // Retry rename/move + + Success = SetFileInformationByHandle(ChunkFileHandle, FileRenameInfo, RenameInfo, BufferSize); + } + + if (Success) + { + m_CasLog.Append({.Key = ChunkHash, .Size = Chunk.Size()}); + + HashLock.ReleaseNow(); + + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed); + } + + return CasStore::InsertResult{.New = IsNew}; + } + + const DWORD LastError = GetLastError(); + + if ((LastError == ERROR_FILE_EXISTS) || (LastError == ERROR_ALREADY_EXISTS)) + { + HashLock.ReleaseNow(); + DeletePayloadFileOnClose(ChunkFileHandle); + + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed); + } + + return CasStore::InsertResult{.New = IsNew}; + } + + ZEN_WARN("rename of CAS payload file failed ('{}'), falling back to regular write for insert of {}", + GetSystemErrorAsString(LastError), + ChunkHash); + + DeletePayloadFileOnClose(ChunkFileHandle); + +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + std::filesystem::path SourcePath = PathFromHandle(FileRef.FileHandle); + std::filesystem::path DestPath = Name.ShardedPath.c_str(); + int Ret = link(SourcePath.c_str(), DestPath.c_str()); + if (Ret < 0 && zen::GetLastError() == ENOENT) + { + // Destination directory doesn't exist. Create it any try again. + CreateDirectories(DestPath.parent_path().c_str()); + Ret = link(SourcePath.c_str(), DestPath.c_str()); + } + int LinkError = zen::GetLastError(); + + if (unlink(SourcePath.c_str()) < 0) + { + int UnlinkError = zen::GetLastError(); + if (UnlinkError != ENOENT) + { + ZEN_WARN("Failed to unlink CAS temporary payload file '{}': '{}'", + SourcePath.string(), + GetSystemErrorAsString(UnlinkError)); + } + } + + // It is possible that someone beat us to it in linking the file. In that + // case a "file exists" error is okay. All others are not. + if (Ret < 0) + { + if (LinkError == EEXIST) + { + HashLock.ReleaseNow(); + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed); + } + return CasStore::InsertResult{.New = IsNew}; + } + + ZEN_WARN("link of CAS payload file failed ('{}'), falling back to regular write for insert of {}", + GetSystemErrorAsString(LinkError), + ChunkHash); + } + else + { + HashLock.ReleaseNow(); + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed); + } + return CasStore::InsertResult{.New = IsNew}; + } +#endif // ZEN_PLATFORM_* + } + + return InsertChunk(Chunk.Data(), Chunk.Size(), ChunkHash); +} + +CasStore::InsertResult +FileCasStrategy::InsertChunk(const void* const ChunkData, const size_t ChunkSize, const IoHash& ChunkHash) +{ + ZEN_ASSERT(m_IsInitialized); + + { + RwLock::SharedLockScope _(m_Lock); + if (m_Index.contains(ChunkHash)) + { + return {.New = false}; + } + } + + ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash); + + // See if file already exists + +#if ZEN_PLATFORM_WINDOWS + CAtlFile PayloadFile; + + HRESULT hRes = PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING); + + if (SUCCEEDED(hRes)) + { + // If we succeeded in opening the file then we don't need to do anything else because it already exists and should contain the + // content we were about to insert + + bool IsNew = false; + { + RwLock::ExclusiveLockScope _(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed); + } + return CasStore::InsertResult{.New = IsNew}; + } + + PayloadFile.Close(); +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + if (access(Name.ShardedPath.c_str(), F_OK) == 0) + { + return CasStore::InsertResult{.New = false}; + } +#endif + + RwLock::ExclusiveLockScope HashLock(LockForHash(ChunkHash)); + +#if ZEN_PLATFORM_WINDOWS + // For now, use double-checked locking to see if someone else was first + + hRes = PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING); + + if (SUCCEEDED(hRes)) + { + uint64_t FileSize = 0; + if (HRESULT hSizeRes = PayloadFile.GetSize(FileSize); SUCCEEDED(hSizeRes) && FileSize == ChunkSize) + { + // If we succeeded in opening the file then and the size is correct we don't need to do anything + // else because someone else managed to create the file before we did. Just return. + + HashLock.ReleaseNow(); + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed); + } + return CasStore::InsertResult{.New = IsNew}; + } + else + { + ZEN_WARN("get file size FAILED or file size mismatch of file cas '{}'. Expected {}, found {}. Trying to overwrite", + Name.ShardedPath.ToUtf8(), + ChunkSize, + FileSize); + } + } + + if ((hRes != HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) && (hRes != HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND))) + { + ZEN_WARN("Unexpected error code when opening shard file for read: {:#x}", uint32_t(hRes)); + } + + auto InternalCreateFile = [&] { return PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_WRITE, FILE_SHARE_DELETE, CREATE_ALWAYS); }; + + hRes = InternalCreateFile(); + + if (hRes == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) + { + // Ensure parent directories exist and retry file creation + CreateDirectories(std::wstring_view(Name.ShardedPath.c_str(), Name.Shard2len)); + hRes = InternalCreateFile(); + } + + if (FAILED(hRes)) + { + ThrowSystemException(hRes, fmt::format("Failed to open shard file '{}'", Name.ShardedPath.ToUtf8())); + } +#else + // Attempt to exclusively create the file. + auto InternalCreateFile = [&] { + int Fd = open(Name.ShardedPath.c_str(), O_WRONLY | O_CREAT | O_EXCL | O_CLOEXEC, 0666); + if (Fd >= 0) + { + fchmod(Fd, 0666); + } + return Fd; + }; + int Fd = InternalCreateFile(); + if (Fd < 0) + { + switch (zen::GetLastError()) + { + case EEXIST: + // Another thread has beat us to it so we're golden. + { + HashLock.ReleaseNow(); + + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed); + } + return {.New = IsNew}; + } + break; + + case ENOENT: + if (zen::CreateDirectories(std::string_view(Name.ShardedPath.c_str(), Name.Shard2len))) + { + Fd = InternalCreateFile(); + if (Fd >= 0) + { + break; + } + } + ThrowLastError(fmt::format("Failed creating shard directory '{}'", Name.ShardedPath)); + + default: + ThrowLastError(fmt::format("Unexpected error occurred opening shard file '{}'", Name.ShardedPath.ToUtf8())); + } + } + + struct FdWrapper + { + ~FdWrapper() { Close(); } + void Write(const void* Cursor, size_t Size) { (void)!write(Fd, Cursor, Size); } + void Close() + { + if (Fd >= 0) + { + close(Fd); + Fd = -1; + } + } + int Fd; + } PayloadFile = {Fd}; +#endif // ZEN_PLATFORM_WINDOWS + + size_t ChunkRemain = ChunkSize; + auto ChunkCursor = reinterpret_cast<const uint8_t*>(ChunkData); + + while (ChunkRemain != 0) + { + uint32_t ByteCount = uint32_t(std::min<size_t>(4 * 1024 * 1024ull, ChunkRemain)); + + PayloadFile.Write(ChunkCursor, ByteCount); + + ChunkCursor += ByteCount; + ChunkRemain -= ByteCount; + } + + // We cannot rely on RAII to close the file handle since it would be closed + // *after* the lock is released due to the initialization order + PayloadFile.Close(); + + m_CasLog.Append({.Key = ChunkHash, .Size = ChunkSize}); + + HashLock.ReleaseNow(); + + bool IsNew = false; + { + RwLock::ExclusiveLockScope __(m_Lock); + IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second; + } + if (IsNew) + { + m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed); + } + + return {.New = IsNew}; +} + +IoBuffer +FileCasStrategy::FindChunk(const IoHash& ChunkHash) +{ + ZEN_ASSERT(m_IsInitialized); + + { + RwLock::SharedLockScope _(m_Lock); + if (!m_Index.contains(ChunkHash)) + { + return {}; + } + } + + ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash); + + RwLock::SharedLockScope _(LockForHash(ChunkHash)); + + return IoBufferBuilder::MakeFromFile(Name.ShardedPath.c_str()); +} + +bool +FileCasStrategy::HaveChunk(const IoHash& ChunkHash) +{ + ZEN_ASSERT(m_IsInitialized); + + RwLock::SharedLockScope _(m_Lock); + return m_Index.contains(ChunkHash); +} + +void +FileCasStrategy::DeleteChunk(const IoHash& ChunkHash, std::error_code& Ec) +{ + ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash); + + uint64_t FileSize = static_cast<uint64_t>(std::filesystem::file_size(Name.ShardedPath.c_str(), Ec)); + if (Ec) + { + ZEN_WARN("get file size FAILED, file cas '{}'", Name.ShardedPath.ToUtf8()); + FileSize = 0; + } + + ZEN_DEBUG("deleting CAS payload file '{}' {}", Name.ShardedPath.ToUtf8(), NiceBytes(FileSize)); + std::filesystem::remove(Name.ShardedPath.c_str(), Ec); + + if (!Ec || !std::filesystem::exists(Name.ShardedPath.c_str())) + { + { + RwLock::ExclusiveLockScope _(m_Lock); + if (auto It = m_Index.find(ChunkHash); It != m_Index.end()) + { + m_TotalSize.fetch_sub(It->second.Size, std::memory_order_relaxed); + m_Index.erase(It); + } + } + m_CasLog.Append({.Key = ChunkHash, .Flags = FileCasIndexEntry::kTombStone, .Size = FileSize}); + } +} + +void +FileCasStrategy::FilterChunks(HashKeySet& InOutChunks) +{ + ZEN_ASSERT(m_IsInitialized); + + // NOTE: it's not a problem now, but in the future if a GC should happen while this + // is in flight, the result could be wrong since chunks could go away in the meantime. + // + // It would be good to have a pinning mechanism to make this less likely but + // given that chunks could go away at any point after the results are returned to + // a caller, this is something which needs to be taken into account by anyone consuming + // this functionality in any case + + InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return HaveChunk(Hash); }); +} + +void +FileCasStrategy::IterateChunks(std::function<void(const IoHash& Hash, IoBuffer&& Payload)>&& Callback) +{ + ZEN_ASSERT(m_IsInitialized); + + RwLock::SharedLockScope _(m_Lock); + for (const auto& It : m_Index) + { + const IoHash& NameHash = It.first; + ShardingHelper Name(m_RootDirectory.c_str(), NameHash); + IoBuffer Payload = IoBufferBuilder::MakeFromFile(Name.ShardedPath.c_str()); + Callback(NameHash, std::move(Payload)); + } +} + +void +FileCasStrategy::Flush() +{ + // Since we don't keep files open after writing there's nothing specific + // to flush here. + // + // Depending on what semantics we want Flush() to provide, it could be + // argued that this should just flush the volume which we are using to + // store the CAS files on here, to ensure metadata is flushed along + // with file data + // + // Related: to facilitate more targeted validation during recovery we could + // maintain a log of when chunks were created +} + +void +FileCasStrategy::Scrub(ScrubContext& Ctx) +{ + ZEN_ASSERT(m_IsInitialized); + + std::vector<IoHash> BadHashes; + uint64_t ChunkCount{0}, ChunkBytes{0}; + + { + std::vector<FileCasStrategy::FileCasIndexEntry> ScannedEntries = FileCasStrategy::ScanFolderForCasFiles(m_RootDirectory); + RwLock::ExclusiveLockScope _(m_Lock); + for (const FileCasStrategy::FileCasIndexEntry& Entry : ScannedEntries) + { + if (m_Index.insert({Entry.Key, {.Size = Entry.Size}}).second) + { + m_TotalSize.fetch_add(static_cast<uint64_t>(Entry.Size), std::memory_order::relaxed); + m_CasLog.Append({.Key = Entry.Key, .Size = Entry.Size}); + } + } + } + + IterateChunks([&](const IoHash& Hash, IoBuffer&& Payload) { + if (!Payload) + { + BadHashes.push_back(Hash); + return; + } + ++ChunkCount; + ChunkBytes += Payload.GetSize(); + + IoHash RawHash; + uint64_t RawSize; + if (CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize)) + { + if (RawHash != Hash) + { + // Hash mismatch + BadHashes.push_back(Hash); + return; + } + return; + } +#if ZEN_WITH_TESTS + IoHash ComputedHash = IoHash::HashBuffer(CompositeBuffer(SharedBuffer(std::move(Payload)))); + if (ComputedHash == Hash) + { + return; + } +#endif + BadHashes.push_back(Hash); + }); + + Ctx.ReportScrubbed(ChunkCount, ChunkBytes); + + if (!BadHashes.empty()) + { + ZEN_WARN("file CAS scrubbing: {} bad chunks found", BadHashes.size()); + + if (Ctx.RunRecovery()) + { + ZEN_WARN("recovery: deleting backing files for {} bad chunks which were identified as bad", BadHashes.size()); + + for (const IoHash& Hash : BadHashes) + { + std::error_code Ec; + DeleteChunk(Hash, Ec); + + if (Ec) + { + ZEN_WARN("failed to delete file for chunk {}", Hash); + } + } + } + } + + // Let whomever it concerns know about the bad chunks. This could + // be used to invalidate higher level data structures more efficiently + // than a full validation pass might be able to do + Ctx.ReportBadCidChunks(BadHashes); + + ZEN_INFO("file CAS scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes)); +} + +void +FileCasStrategy::CollectGarbage(GcContext& GcCtx) +{ + ZEN_ASSERT(m_IsInitialized); + + ZEN_DEBUG("collecting garbage from {}", m_RootDirectory); + + std::vector<IoHash> ChunksToDelete; + std::atomic<uint64_t> ChunksToDeleteBytes{0}; + std::atomic<uint64_t> ChunkCount{0}, ChunkBytes{0}; + + std::vector<IoHash> CandidateCas; + CandidateCas.resize(1); + + uint64_t DeletedCount = 0; + uint64_t OldTotalSize = m_TotalSize.load(std::memory_order::relaxed); + + Stopwatch TotalTimer; + const auto _ = MakeGuard([&] { + ZEN_DEBUG("garbage collect for '{}' DONE after {}, deleted {} out of {} files, removed {} out of {}", + m_RootDirectory, + NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()), + DeletedCount, + ChunkCount, + NiceBytes(OldTotalSize - m_TotalSize.load(std::memory_order::relaxed)), + NiceBytes(OldTotalSize)); + }); + + IterateChunks([&](const IoHash& Hash, IoBuffer&& Payload) { + bool KeepThis = false; + CandidateCas[0] = Hash; + GcCtx.FilterCids(CandidateCas, [&](const IoHash& Hash) { + ZEN_UNUSED(Hash); + KeepThis = true; + }); + + const uint64_t FileSize = Payload.GetSize(); + + if (!KeepThis) + { + ChunksToDelete.push_back(Hash); + ChunksToDeleteBytes.fetch_add(FileSize); + } + + ++ChunkCount; + ChunkBytes.fetch_add(FileSize); + }); + + // TODO, any entires we did not encounter during our IterateChunks should be removed from the index + + if (ChunksToDelete.empty()) + { + ZEN_DEBUG("gc for '{}' SKIPPED, nothing to delete", m_RootDirectory); + return; + } + + ZEN_DEBUG("deleting file CAS garbage for '{}': {} out of {} chunks ({})", + m_RootDirectory, + ChunksToDelete.size(), + ChunkCount.load(), + NiceBytes(ChunksToDeleteBytes)); + + if (GcCtx.IsDeletionMode() == false) + { + ZEN_DEBUG("NOTE: not actually deleting anything since deletion is disabled"); + + return; + } + + for (const IoHash& Hash : ChunksToDelete) + { + ZEN_TRACE("deleting chunk {}", Hash); + + std::error_code Ec; + DeleteChunk(Hash, Ec); + + if (Ec) + { + ZEN_WARN("gc for '{}' failed to delete file for chunk {}: '{}'", m_RootDirectory, Hash, Ec.message()); + continue; + } + DeletedCount++; + } + + GcCtx.AddDeletedCids(ChunksToDelete); +} + +bool +FileCasStrategy::ValidateEntry(const FileCasIndexEntry& Entry, std::string& OutReason) +{ + if (Entry.Key == IoHash::Zero) + { + OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString()); + return false; + } + if (Entry.Flags & (~FileCasIndexEntry::kTombStone)) + { + OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Flags, Entry.Key.ToHexString()); + return false; + } + if (Entry.IsFlagSet(FileCasIndexEntry::kTombStone)) + { + return true; + } + uint64_t Size = Entry.Size; + if (Size == 0) + { + OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString()); + return false; + } + return true; +} + +void +FileCasStrategy::MakeIndexSnapshot() +{ + using namespace filecas::impl; + + uint64_t LogCount = m_CasLog.GetLogCount(); + if (m_LogFlushPosition == LogCount) + { + return; + } + ZEN_DEBUG("write store snapshot for '{}'", m_RootDirectory); + uint64_t EntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}", + m_RootDirectory, + EntryCount, + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + namespace fs = std::filesystem; + + fs::path IndexPath = GetIndexPath(m_RootDirectory); + fs::path STmpIndexPath = GetTempIndexPath(m_RootDirectory); + + // Move index away, we keep it if something goes wrong + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(STmpIndexPath); + } + if (fs::is_regular_file(IndexPath)) + { + fs::rename(IndexPath, STmpIndexPath); + } + + try + { + // Write the current state of the location map to a new index state + std::vector<FileCasIndexEntry> Entries; + + { + Entries.resize(m_Index.size()); + + uint64_t EntryIndex = 0; + for (auto& Entry : m_Index) + { + FileCasIndexEntry& IndexEntry = Entries[EntryIndex++]; + IndexEntry.Key = Entry.first; + IndexEntry.Size = Entry.second.Size; + } + } + + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate); + filecas::impl::FileCasIndexHeader Header = {.EntryCount = Entries.size(), .LogPosition = LogCount}; + + Header.Checksum = filecas::impl::FileCasIndexHeader::ComputeChecksum(Header); + + ObjectIndexFile.Write(&Header, sizeof(filecas::impl::FileCasIndexHeader), 0); + ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(FileCasIndexEntry), sizeof(filecas::impl::FileCasIndexHeader)); + ObjectIndexFile.Flush(); + ObjectIndexFile.Close(); + EntryCount = Entries.size(); + m_LogFlushPosition = LogCount; + } + catch (std::exception& Err) + { + ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what()); + + // Restore any previous snapshot + + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(IndexPath); + fs::rename(STmpIndexPath, IndexPath); + } + } + if (fs::is_regular_file(STmpIndexPath)) + { + fs::remove(STmpIndexPath); + } +} +uint64_t +FileCasStrategy::ReadIndexFile() +{ + using namespace filecas::impl; + + std::vector<FileCasIndexEntry> Entries; + std::filesystem::path IndexPath = GetIndexPath(m_RootDirectory); + if (std::filesystem::is_regular_file(IndexPath)) + { + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' index containing {} entries in {}", + IndexPath, + Entries.size(), + NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + + BasicFile ObjectIndexFile; + ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead); + uint64_t Size = ObjectIndexFile.FileSize(); + if (Size >= sizeof(FileCasIndexHeader)) + { + uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(FileCasIndexHeader))) / sizeof(FileCasIndexEntry); + FileCasIndexHeader Header; + ObjectIndexFile.Read(&Header, sizeof(Header), 0); + if ((Header.Magic == FileCasIndexHeader::ExpectedMagic) && (Header.Version == FileCasIndexHeader::CurrentVersion) && + (Header.Checksum == FileCasIndexHeader::ComputeChecksum(Header)) && (Header.EntryCount <= ExpectedEntryCount)) + { + Entries.resize(Header.EntryCount); + ObjectIndexFile.Read(Entries.data(), Header.EntryCount * sizeof(FileCasIndexEntry), sizeof(FileCasIndexHeader)); + + std::string InvalidEntryReason; + for (const FileCasIndexEntry& Entry : Entries) + { + if (!ValidateEntry(Entry, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason); + continue; + } + m_Index.insert_or_assign(Entry.Key, IndexEntry{.Size = Entry.Size}); + } + + return Header.LogPosition; + } + else + { + ZEN_WARN("skipping invalid index file '{}'", IndexPath); + } + } + return 0; + } + + if (std::filesystem::is_directory(m_RootDirectory)) + { + ZEN_INFO("missing index for file cas, scanning for cas files in {}", m_RootDirectory); + TCasLogFile<FileCasIndexEntry> CasLog; + uint64_t TotalSize = 0; + Stopwatch TotalTimer; + const auto _ = MakeGuard([&] { + ZEN_INFO("scanned file cas folder '{}' DONE after {}, found {} files totalling {}", + m_RootDirectory, + NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()), + CasLog.GetLogCount(), + NiceBytes(TotalSize)); + }); + + std::filesystem::path LogPath = GetLogPath(m_RootDirectory); + + std::vector<FileCasStrategy::FileCasIndexEntry> ScannedEntries = FileCasStrategy::ScanFolderForCasFiles(m_RootDirectory); + CasLog.Open(LogPath, CasLogFile::Mode::kTruncate); + std::string InvalidEntryReason; + for (const FileCasStrategy::FileCasIndexEntry& Entry : ScannedEntries) + { + if (!ValidateEntry(Entry, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", m_RootDirectory, InvalidEntryReason); + continue; + } + m_Index.insert_or_assign(Entry.Key, IndexEntry{.Size = Entry.Size}); + CasLog.Append(Entry); + } + + CasLog.Close(); + } + + return 0; +} + +uint64_t +FileCasStrategy::ReadLog(uint64_t SkipEntryCount) +{ + using namespace filecas::impl; + + std::filesystem::path LogPath = GetLogPath(m_RootDirectory); + if (std::filesystem::is_regular_file(LogPath)) + { + uint64_t LogEntryCount = 0; + Stopwatch Timer; + const auto _ = MakeGuard([&] { + ZEN_INFO("read store '{}' log containing {} entries in {}", LogPath, LogEntryCount, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); + }); + TCasLogFile<FileCasIndexEntry> CasLog; + CasLog.Open(LogPath, CasLogFile::Mode::kRead); + if (CasLog.Initialize()) + { + uint64_t EntryCount = CasLog.GetLogCount(); + if (EntryCount < SkipEntryCount) + { + ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath); + SkipEntryCount = 0; + } + LogEntryCount = EntryCount - SkipEntryCount; + m_Index.reserve(LogEntryCount); + uint64_t InvalidEntryCount = 0; + CasLog.Replay( + [&](const FileCasIndexEntry& Record) { + std::string InvalidEntryReason; + if (Record.Flags & FileCasIndexEntry::kTombStone) + { + m_Index.erase(Record.Key); + return; + } + if (!ValidateEntry(Record, InvalidEntryReason)) + { + ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason); + ++InvalidEntryCount; + return; + } + m_Index.insert_or_assign(Record.Key, IndexEntry{.Size = Record.Size}); + }, + SkipEntryCount); + if (InvalidEntryCount) + { + ZEN_WARN("found {} invalid entries in '{}'", InvalidEntryCount, LogPath); + } + return LogEntryCount; + } + } + return 0; +} + +std::vector<FileCasStrategy::FileCasIndexEntry> +FileCasStrategy::ScanFolderForCasFiles(const std::filesystem::path& RootDir) +{ + using namespace filecas::impl; + + std::vector<FileCasIndexEntry> Entries; + struct Visitor : public FileSystemTraversal::TreeVisitor + { + Visitor(const std::filesystem::path& RootDir, std::vector<FileCasIndexEntry>& Entries) : RootDirectory(RootDir), Entries(Entries) {} + virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t FileSize) override + { + std::filesystem::path RelPath = std::filesystem::relative(Parent, RootDirectory); + + std::filesystem::path::string_type PathString = RelPath.native(); + + if ((PathString.size() == (3 + 2 + 1)) && (File.size() == (40 - 3 - 2))) + { + if (PathString.at(3) == std::filesystem::path::preferred_separator) + { + PathString.erase(3, 1); + } + PathString.append(File); + + // TODO: should validate that we're actually dealing with a valid hex string here +#if ZEN_PLATFORM_WINDOWS + StringBuilder<64> Utf8; + WideToUtf8(PathString, Utf8); + IoHash NameHash = IoHash::FromHexString({Utf8.Data(), Utf8.Size()}); +#else + IoHash NameHash = IoHash::FromHexString(PathString); +#endif + Entries.emplace_back(FileCasIndexEntry{.Key = NameHash, .Size = FileSize}); + } + } + + virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent, + [[maybe_unused]] const path_view& DirectoryName) override + { + return true; + } + + const std::filesystem::path& RootDirectory; + std::vector<FileCasIndexEntry>& Entries; + } CasVisitor{RootDir, Entries}; + + FileSystemTraversal Traversal; + Traversal.TraverseFileSystem(RootDir, CasVisitor); + return Entries; +}; + + ////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +TEST_CASE("cas.file.move") +{ + // specifying an absolute path here can be helpful when using procmon to dig into things + ScopedTemporaryDirectory TempDir; // {"d:\\filecas_testdir"}; + + GcManager Gc; + + FileCasStrategy FileCas(Gc); + FileCas.Initialize(TempDir.Path() / "cas", /* IsNewStore */ true); + + { + std::filesystem::path Payload1Path{TempDir.Path() / "payload_1"}; + + IoBuffer ZeroBytes{1024 * 1024}; + IoHash ZeroHash = IoHash::HashBuffer(ZeroBytes); + + BasicFile PayloadFile; + PayloadFile.Open(Payload1Path, BasicFile::Mode::kTruncate); + PayloadFile.Write(ZeroBytes, 0); + PayloadFile.Close(); + + IoBuffer Payload1 = IoBufferBuilder::MakeFromTemporaryFile(Payload1Path); + + CasStore::InsertResult Result = FileCas.InsertChunk(Payload1, ZeroHash); + CHECK_EQ(Result.New, true); + } + +# if 0 + SUBCASE("stresstest") + { + std::vector<IoHash> PayloadHashes; + + const int kWorkers = 64; + const int kItemCount = 128; + + for (int w = 0; w < kWorkers; ++w) + { + for (int i = 0; i < kItemCount; ++i) + { + IoBuffer Payload{1024}; + *reinterpret_cast<int*>(Payload.MutableData()) = i; + PayloadHashes.push_back(IoHash::HashBuffer(Payload)); + + std::filesystem::path PayloadPath{TempDir.Path() / fmt::format("payload_{}_{}", w, i)}; + WriteFile(PayloadPath, Payload); + } + } + + std::barrier Sync{kWorkers}; + + auto PopulateAll = [&](int w) { + std::vector<IoBuffer> Buffers; + + for (int i = 0; i < kItemCount; ++i) + { + std::filesystem::path PayloadPath{TempDir.Path() / fmt::format("payload_{}_{}", w, i)}; + IoBuffer Payload = IoBufferBuilder::MakeFromTemporaryFile(PayloadPath); + Buffers.push_back(Payload); + Sync.arrive_and_wait(); + CasStore::InsertResult Result = FileCas.InsertChunk(Payload, PayloadHashes[i]); + } + }; + + std::vector<std::jthread> Threads; + + for (int i = 0; i < kWorkers; ++i) + { + Threads.push_back(std::jthread(PopulateAll, i)); + } + + for (std::jthread& Thread : Threads) + { + Thread.join(); + } + } +# endif +} + +TEST_CASE("cas.file.gc") +{ + // specifying an absolute path here can be helpful when using procmon to dig into things + ScopedTemporaryDirectory TempDir; // {"d:\\filecas_testdir"}; + + GcManager Gc; + FileCasStrategy FileCas(Gc); + FileCas.Initialize(TempDir.Path() / "cas", /* IsNewStore */ true); + + const int kIterationCount = 1000; + std::vector<IoHash> Keys{kIterationCount}; + + auto InsertChunks = [&] { + for (int i = 0; i < kIterationCount; ++i) + { + CbObjectWriter Cbo; + Cbo << "id" << i; + CbObject Obj = Cbo.Save(); + + IoBuffer ObjBuffer = Obj.GetBuffer().AsIoBuffer(); + IoHash Hash = HashBuffer(ObjBuffer); + + FileCas.InsertChunk(ObjBuffer, Hash); + + Keys[i] = Hash; + } + }; + + // Drop everything + + { + InsertChunks(); + + GcContext Ctx(GcClock::Now() - std::chrono::hours(24)); + FileCas.CollectGarbage(Ctx); + + for (const IoHash& Key : Keys) + { + IoBuffer Chunk = FileCas.FindChunk(Key); + + CHECK(!Chunk); + } + } + + // Keep roughly half of the chunks + + { + InsertChunks(); + + GcContext Ctx(GcClock::Now() - std::chrono::hours(24)); + + for (const IoHash& Key : Keys) + { + if (Key.Hash[0] & 1) + { + Ctx.AddRetainedCids(std::vector<IoHash>{Key}); + } + } + + FileCas.CollectGarbage(Ctx); + + for (const IoHash& Key : Keys) + { + if (Key.Hash[0] & 1) + { + CHECK(FileCas.FindChunk(Key)); + } + else + { + CHECK(!FileCas.FindChunk(Key)); + } + } + } +} + +#endif + +void +filecas_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenstore/filecas.h b/src/zenstore/filecas.h new file mode 100644 index 000000000..420b3a634 --- /dev/null +++ b/src/zenstore/filecas.h @@ -0,0 +1,102 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#include <zencore/filesystem.h> +#include <zencore/iobuffer.h> +#include <zencore/iohash.h> +#include <zencore/thread.h> +#include <zenstore/caslog.h> +#include <zenstore/gc.h> + +#include "cas.h" + +#include <atomic> +#include <functional> + +namespace spdlog { +class logger; +} + +namespace zen { + +class BasicFile; + +/** CAS storage strategy using a file-per-chunk storage strategy + */ + +struct FileCasStrategy final : public GcStorage +{ + FileCasStrategy(GcManager& Gc); + ~FileCasStrategy(); + + void Initialize(const std::filesystem::path& RootDirectory, bool IsNewStore); + CasStore::InsertResult InsertChunk(IoBuffer Chunk, + const IoHash& ChunkHash, + CasStore::InsertMode Mode = CasStore::InsertMode::kMayBeMovedInPlace); + IoBuffer FindChunk(const IoHash& ChunkHash); + bool HaveChunk(const IoHash& ChunkHash); + void FilterChunks(HashKeySet& InOutChunks); + void Flush(); + void Scrub(ScrubContext& Ctx); + virtual void CollectGarbage(GcContext& GcCtx) override; + virtual GcStorageSize StorageSize() const override { return {.DiskSize = m_TotalSize.load(std::memory_order::relaxed)}; } + +private: + void MakeIndexSnapshot(); + uint64_t ReadIndexFile(); + uint64_t ReadLog(uint64_t LogPosition); + + struct IndexEntry + { + uint64_t Size = 0; + }; + using IndexMap = tsl::robin_map<IoHash, IndexEntry, IoHash::Hasher>; + + CasStore::InsertResult InsertChunk(const void* ChunkData, size_t ChunkSize, const IoHash& ChunkHash); + + std::filesystem::path m_RootDirectory; + RwLock m_Lock; + IndexMap m_Index; + RwLock m_ShardLocks[256]; // TODO: these should be spaced out so they don't share cache lines + spdlog::logger& m_Log; + spdlog::logger& Log() { return m_Log; } + std::atomic_uint64_t m_TotalSize{}; + bool m_IsInitialized = false; + + struct FileCasIndexEntry + { + static const uint32_t kTombStone = 0x0000'0001; + + bool IsFlagSet(const uint32_t Flag) const { return (Flags & kTombStone) == Flag; } + + IoHash Key; + uint32_t Flags = 0; + uint64_t Size = 0; + }; + static bool ValidateEntry(const FileCasIndexEntry& Entry, std::string& OutReason); + static std::vector<FileCasStrategy::FileCasIndexEntry> ScanFolderForCasFiles(const std::filesystem::path& RootDir); + + static_assert(sizeof(FileCasIndexEntry) == 32); + + TCasLogFile<FileCasIndexEntry> m_CasLog; + uint64_t m_LogFlushPosition = 0; + + inline RwLock& LockForHash(const IoHash& Hash) { return m_ShardLocks[Hash.Hash[19]]; } + void IterateChunks(std::function<void(const IoHash& Hash, IoBuffer&& Payload)>&& Callback); + void DeleteChunk(const IoHash& ChunkHash, std::error_code& Ec); + + struct ShardingHelper + { + ShardingHelper(const std::filesystem::path& RootPath, const IoHash& ChunkHash); + + size_t Shard2len = 0; + ExtendablePathBuilder<128> ShardedPath; + }; +}; + +void filecas_forcelink(); + +} // namespace zen diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp new file mode 100644 index 000000000..370c3c965 --- /dev/null +++ b/src/zenstore/gc.cpp @@ -0,0 +1,1312 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenstore/gc.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinaryvalidation.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> +#include <zencore/string.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> +#include <zencore/timer.h> +#include <zenstore/cidstore.h> + +#include "cas.h" + +#include <fmt/format.h> +#include <filesystem> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <fcntl.h> +# include <sys/file.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +#if ZEN_WITH_TESTS +# include <zencore/compress.h> +# include <algorithm> +# include <random> +#endif + +template<> +struct fmt::formatter<zen::GcClock::TimePoint> : formatter<string_view> +{ + template<typename FormatContext> + auto format(const zen::GcClock::TimePoint& TimePoint, FormatContext& ctx) + { + std::time_t Time = std::chrono::system_clock::to_time_t(TimePoint); + zen::ExtendableStringBuilder<32> String; + String << std::ctime(&Time); + return formatter<string_view>::format(String.ToView(), ctx); + } +}; + +namespace zen { + +using namespace std::literals; +namespace fs = std::filesystem; + +////////////////////////////////////////////////////////////////////////// + +namespace { + std::error_code CreateGCReserve(const std::filesystem::path& Path, uint64_t Size) + { + if (Size == 0) + { + std::filesystem::remove(Path); + return std::error_code{}; + } + CreateDirectories(Path.parent_path()); + if (std::filesystem::is_regular_file(Path) && std::filesystem::file_size(Path) == Size) + { + return std::error_code(); + } +#if ZEN_PLATFORM_WINDOWS + DWORD dwCreationDisposition = CREATE_ALWAYS; + DWORD dwDesiredAccess = GENERIC_READ | GENERIC_WRITE; + + const DWORD dwShareMode = 0; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; + HANDLE hTemplateFile = nullptr; + + HANDLE FileHandle = CreateFile(Path.c_str(), + dwDesiredAccess, + dwShareMode, + /* lpSecurityAttributes */ nullptr, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + return MakeErrorCodeFromLastError(); + } + bool Keep = true; + auto _ = MakeGuard([&]() { + ::CloseHandle(FileHandle); + if (!Keep) + { + ::DeleteFile(Path.c_str()); + } + }); + LARGE_INTEGER liFileSize; + liFileSize.QuadPart = Size; + BOOL OK = ::SetFilePointerEx(FileHandle, liFileSize, 0, FILE_BEGIN); + if (!OK) + { + return MakeErrorCodeFromLastError(); + } + OK = ::SetEndOfFile(FileHandle); + if (!OK) + { + return MakeErrorCodeFromLastError(); + } + Keep = true; +#else + int OpenFlags = O_CLOEXEC | O_RDWR | O_CREAT; + int Fd = open(Path.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + return MakeErrorCodeFromLastError(); + } + + bool Keep = true; + auto _ = MakeGuard([&]() { + close(Fd); + if (!Keep) + { + unlink(Path.c_str()); + } + }); + + if (fchmod(Fd, 0666) < 0) + { + return MakeErrorCodeFromLastError(); + } + +# if ZEN_PLATFORM_MAC + if (ftruncate(Fd, (off_t)Size) < 0) + { + return MakeErrorCodeFromLastError(); + } +# else + if (ftruncate64(Fd, (off64_t)Size) < 0) + { + return MakeErrorCodeFromLastError(); + } + int Error = posix_fallocate64(Fd, 0, (off64_t)Size); + if (Error) + { + return MakeErrorCode(Error); + } +# endif + Keep = true; +#endif + return std::error_code{}; + } + +} // namespace + +////////////////////////////////////////////////////////////////////////// + +CbObject +LoadCompactBinaryObject(const fs::path& Path) +{ + FileContents Result = ReadFile(Path); + + if (!Result.ErrorCode) + { + IoBuffer Buffer = Result.Flatten(); + if (CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); Error == CbValidateError::None) + { + return LoadCompactBinaryObject(Buffer); + } + } + + return CbObject(); +} + +void +SaveCompactBinaryObject(const fs::path& Path, const CbObject& Object) +{ + WriteFile(Path, Object.GetBuffer().AsIoBuffer()); +} + +////////////////////////////////////////////////////////////////////////// + +struct GcContext::GcState +{ + using CacheKeyContexts = std::unordered_map<std::string, std::vector<IoHash>>; + + CacheKeyContexts m_ExpiredCacheKeys; + HashKeySet m_RetainedCids; + HashKeySet m_DeletedCids; + GcClock::TimePoint m_ExpireTime; + bool m_DeletionMode = true; + bool m_CollectSmallObjects = false; + + std::filesystem::path DiskReservePath; +}; + +GcContext::GcContext(const GcClock::TimePoint& ExpireTime) : m_State(std::make_unique<GcState>()) +{ + m_State->m_ExpireTime = ExpireTime; +} + +GcContext::~GcContext() +{ +} + +void +GcContext::AddRetainedCids(std::span<const IoHash> Cids) +{ + m_State->m_RetainedCids.AddHashesToSet(Cids); +} + +void +GcContext::SetExpiredCacheKeys(const std::string& CacheKeyContext, std::vector<IoHash>&& ExpiredKeys) +{ + m_State->m_ExpiredCacheKeys[CacheKeyContext] = std::move(ExpiredKeys); +} + +void +GcContext::IterateCids(std::function<void(const IoHash&)> Callback) +{ + m_State->m_RetainedCids.IterateHashes([&](const IoHash& Hash) { Callback(Hash); }); +} + +void +GcContext::FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&)> KeepFunc) +{ + m_State->m_RetainedCids.FilterHashes(Cid, [&](const IoHash& Hash) { KeepFunc(Hash); }); +} + +void +GcContext::FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&, bool)>&& FilterFunc) +{ + m_State->m_RetainedCids.FilterHashes(Cid, std::move(FilterFunc)); +} + +void +GcContext::AddDeletedCids(std::span<const IoHash> Cas) +{ + m_State->m_DeletedCids.AddHashesToSet(Cas); +} + +const HashKeySet& +GcContext::DeletedCids() +{ + return m_State->m_DeletedCids; +} + +std::span<const IoHash> +GcContext::ExpiredCacheKeys(const std::string& CacheKeyContext) const +{ + return m_State->m_ExpiredCacheKeys[CacheKeyContext]; +} + +bool +GcContext::IsDeletionMode() const +{ + return m_State->m_DeletionMode; +} + +void +GcContext::SetDeletionMode(bool NewState) +{ + m_State->m_DeletionMode = NewState; +} + +bool +GcContext::CollectSmallObjects() const +{ + return m_State->m_CollectSmallObjects; +} + +void +GcContext::CollectSmallObjects(bool NewState) +{ + m_State->m_CollectSmallObjects = NewState; +} + +GcClock::TimePoint +GcContext::ExpireTime() const +{ + return m_State->m_ExpireTime; +} + +void +GcContext::DiskReservePath(const std::filesystem::path& Path) +{ + m_State->DiskReservePath = Path; +} + +uint64_t +GcContext::ClaimGCReserve() +{ + if (!std::filesystem::is_regular_file(m_State->DiskReservePath)) + { + return 0; + } + uint64_t ReclaimedSize = std::filesystem::file_size(m_State->DiskReservePath); + if (std::filesystem::remove(m_State->DiskReservePath)) + { + return ReclaimedSize; + } + return 0; +} + +////////////////////////////////////////////////////////////////////////// + +GcContributor::GcContributor(GcManager& Gc) : m_Gc(Gc) +{ + m_Gc.AddGcContributor(this); +} + +GcContributor::~GcContributor() +{ + m_Gc.RemoveGcContributor(this); +} + +////////////////////////////////////////////////////////////////////////// + +GcStorage::GcStorage(GcManager& Gc) : m_Gc(Gc) +{ + m_Gc.AddGcStorage(this); +} + +GcStorage::~GcStorage() +{ + m_Gc.RemoveGcStorage(this); +} + +////////////////////////////////////////////////////////////////////////// + +GcManager::GcManager() : m_Log(logging::Get("gc")) +{ +} + +GcManager::~GcManager() +{ +} + +void +GcManager::AddGcContributor(GcContributor* Contributor) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_GcContribs.push_back(Contributor); +} + +void +GcManager::RemoveGcContributor(GcContributor* Contributor) +{ + RwLock::ExclusiveLockScope _(m_Lock); + std::erase_if(m_GcContribs, [&](GcContributor* $) { return $ == Contributor; }); +} + +void +GcManager::AddGcStorage(GcStorage* Storage) +{ + ZEN_ASSERT(Storage != nullptr); + RwLock::ExclusiveLockScope _(m_Lock); + m_GcStorage.push_back(Storage); +} + +void +GcManager::RemoveGcStorage(GcStorage* Storage) +{ + RwLock::ExclusiveLockScope _(m_Lock); + std::erase_if(m_GcStorage, [&](GcStorage* $) { return $ == Storage; }); +} + +void +GcManager::CollectGarbage(GcContext& GcCtx) +{ + RwLock::SharedLockScope _(m_Lock); + + // First gather reference set + { + Stopwatch Timer; + const auto Guard = MakeGuard([&] { ZEN_INFO("gathered references in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); }); + for (GcContributor* Contributor : m_GcContribs) + { + Contributor->GatherReferences(GcCtx); + } + } + + // Then trim storage + { + GcStorageSize GCTotalSizeDiff; + Stopwatch Timer; + const auto Guard = MakeGuard([&] { + ZEN_INFO("collected garbage in {}. Removed {} disk space, {} memory", + NiceTimeSpanMs(Timer.GetElapsedTimeMs()), + NiceBytes(GCTotalSizeDiff.DiskSize), + NiceBytes(GCTotalSizeDiff.MemorySize)); + }); + for (GcStorage* Storage : m_GcStorage) + { + const auto PreSize = Storage->StorageSize(); + Storage->CollectGarbage(GcCtx); + const auto PostSize = Storage->StorageSize(); + GCTotalSizeDiff.DiskSize += PreSize.DiskSize > PostSize.DiskSize ? PreSize.DiskSize - PostSize.DiskSize : 0; + GCTotalSizeDiff.MemorySize += PreSize.MemorySize > PostSize.MemorySize ? PreSize.MemorySize - PostSize.MemorySize : 0; + } + } +} + +GcStorageSize +GcManager::TotalStorageSize() const +{ + RwLock::SharedLockScope _(m_Lock); + + GcStorageSize TotalSize; + + for (GcStorage* Storage : m_GcStorage) + { + const auto Size = Storage->StorageSize(); + TotalSize.DiskSize += Size.DiskSize; + TotalSize.MemorySize += Size.MemorySize; + } + + return TotalSize; +} + +#if ZEN_USE_REF_TRACKING +void +GcManager::OnNewCidReferences(std::span<IoHash> Hashes) +{ + ZEN_UNUSED(Hashes); +} + +void +GcManager::OnCommittedCidReferences(std::span<IoHash> Hashes) +{ + ZEN_UNUSED(Hashes); +} + +void +GcManager::OnDroppedCidReferences(std::span<IoHash> Hashes) +{ + ZEN_UNUSED(Hashes); +} +#endif + +////////////////////////////////////////////////////////////////////////// +void +DiskUsageWindow::KeepRange(GcClock::Tick StartTick, GcClock::Tick EndTick) +{ + auto It = m_LogWindow.begin(); + if (It == m_LogWindow.end()) + { + return; + } + while (It->SampleTime < StartTick) + { + ++It; + if (It == m_LogWindow.end()) + { + m_LogWindow.clear(); + return; + } + } + m_LogWindow.erase(m_LogWindow.begin(), It); + + It = m_LogWindow.begin(); + while (It != m_LogWindow.end()) + { + if (It->SampleTime >= EndTick) + { + m_LogWindow.erase(It, m_LogWindow.end()); + return; + } + It++; + } +} + +std::vector<uint64_t> +DiskUsageWindow::GetDiskDeltas(GcClock::Tick StartTick, GcClock::Tick EndTick, GcClock::Tick DeltaWidth, uint64_t& OutMaxDelta) const +{ + ZEN_ASSERT(StartTick != -1); + ZEN_ASSERT(DeltaWidth > 0); + + std::vector<uint64_t> Result; + Result.reserve((EndTick - StartTick + DeltaWidth - 1) / DeltaWidth); + + size_t WindowSize = m_LogWindow.size(); + GcClock::Tick FirstWindowTick = WindowSize < 2 ? EndTick : m_LogWindow[1].SampleTime; + + GcClock::Tick RangeStart = StartTick; + while (FirstWindowTick >= RangeStart + DeltaWidth && RangeStart < EndTick) + { + Result.push_back(0); + RangeStart += DeltaWidth; + } + + uint64_t DeltaSum = 0; + size_t WindowIndex = 1; + while (WindowIndex < WindowSize && RangeStart < EndTick) + { + const DiskUsageEntry& Entry = m_LogWindow[WindowIndex]; + if (Entry.SampleTime < RangeStart) + { + ++WindowIndex; + continue; + } + GcClock::Tick RangeEnd = Min(EndTick, RangeStart + DeltaWidth); + ZEN_ASSERT(Entry.SampleTime >= RangeStart); + if (Entry.SampleTime >= RangeEnd) + { + Result.push_back(DeltaSum); + OutMaxDelta = Max(DeltaSum, OutMaxDelta); + DeltaSum = 0; + RangeStart = RangeEnd; + continue; + } + const DiskUsageEntry& PrevEntry = m_LogWindow[WindowIndex - 1]; + if (Entry.DiskUsage > PrevEntry.DiskUsage) + { + uint64_t Delta = Entry.DiskUsage - PrevEntry.DiskUsage; + DeltaSum += Delta; + } + WindowIndex++; + } + + while (RangeStart < EndTick) + { + Result.push_back(DeltaSum); + OutMaxDelta = Max(DeltaSum, OutMaxDelta); + DeltaSum = 0; + RangeStart += DeltaWidth; + } + return Result; +} + +GcClock::Tick +DiskUsageWindow::FindTimepointThatRemoves(uint64_t Amount, GcClock::Tick EndTick) const +{ + ZEN_ASSERT(Amount > 0); + uint64_t RemainingToFind = Amount; + size_t Offset = 1; + while (Offset < m_LogWindow.size()) + { + const DiskUsageEntry& Entry = m_LogWindow[Offset]; + if (Entry.SampleTime >= EndTick) + { + return EndTick; + } + const DiskUsageEntry& PreviousEntry = m_LogWindow[Offset - 1]; + uint64_t Delta = Entry.DiskUsage > PreviousEntry.DiskUsage ? Entry.DiskUsage - PreviousEntry.DiskUsage : 0; + if (Delta >= RemainingToFind) + { + return m_LogWindow[Offset].SampleTime + 1; + } + RemainingToFind -= Delta; + Offset++; + } + return EndTick; +} + +////////////////////////////////////////////////////////////////////////// + +GcScheduler::GcScheduler(GcManager& GcManager) : m_Log(logging::Get("gc")), m_GcManager(GcManager) +{ +} + +GcScheduler::~GcScheduler() +{ + Shutdown(); +} + +void +GcScheduler::Initialize(const GcSchedulerConfig& Config) +{ + using namespace std::chrono; + + m_Config = Config; + + if (m_Config.Interval.count() && m_Config.Interval < m_Config.MonitorInterval) + { + m_Config.Interval = m_Config.MonitorInterval; + } + + std::filesystem::create_directories(Config.RootDirectory); + + std::error_code Ec = CreateGCReserve(m_Config.RootDirectory / "reserve.gc", m_Config.DiskReserveSize); + if (Ec) + { + ZEN_WARN("unable to create GC reserve at '{}' with size {}, reason '{}'", + m_Config.RootDirectory / "reserve.gc", + NiceBytes(m_Config.DiskReserveSize), + Ec.message()); + } + + m_LastGcTime = GcClock::Now(); + m_LastGcExpireTime = GcClock::TimePoint::min(); + + if (CbObject SchedulerState = LoadCompactBinaryObject(Config.RootDirectory / "gc_state")) + { + m_LastGcTime = GcClock::TimePoint(GcClock::Duration(SchedulerState["LastGcTime"sv].AsInt64())); + m_LastGcExpireTime = + GcClock::TimePoint(GcClock::Duration(SchedulerState["LastGcExpireTime"].AsInt64(GcClock::Duration::min().count()))); + if (m_LastGcTime + m_Config.Interval < GcClock::Now()) + { + // TODO: Trigger GC? + m_LastGcTime = GcClock::Now(); + } + } + + m_DiskUsageLog.Open(m_Config.RootDirectory / "gc.dlog", CasLogFile::Mode::kWrite); + m_DiskUsageLog.Initialize(); + const GcClock::Tick LastGCTick = m_LastGcTime.time_since_epoch().count(); + m_DiskUsageLog.Replay( + [this, LastGCTick](const DiskUsageWindow::DiskUsageEntry& Entry) { + if (Entry.SampleTime >= m_LastGcExpireTime.time_since_epoch().count()) + { + m_DiskUsageWindow.Append(Entry); + } + }, + 0); + + m_NextGcTime = NextGcTime(m_LastGcTime); + m_GcThread = std::thread(&GcScheduler::SchedulerThread, this); +} + +void +GcScheduler::Shutdown() +{ + if (static_cast<uint32_t>(GcSchedulerStatus::kStopped) != m_Status) + { + bool GcIsRunning = m_Status == static_cast<uint32_t>(GcSchedulerStatus::kRunning); + m_Status = static_cast<uint32_t>(GcSchedulerStatus::kStopped); + m_GcSignal.notify_one(); + + if (m_GcThread.joinable()) + { + if (GcIsRunning) + { + ZEN_INFO("Waiting for garbage collection to complete"); + } + m_GcThread.join(); + } + } + m_DiskUsageLog.Flush(); + m_DiskUsageLog.Close(); +} + +bool +GcScheduler::Trigger(const GcScheduler::TriggerParams& Params) +{ + if (m_Config.Enabled) + { + std::unique_lock Lock(m_GcMutex); + if (static_cast<uint32_t>(GcSchedulerStatus::kIdle) == m_Status) + { + m_TriggerParams = Params; + uint32_t IdleState = static_cast<uint32_t>(GcSchedulerStatus::kIdle); + if (m_Status.compare_exchange_strong(IdleState, static_cast<uint32_t>(GcSchedulerStatus::kRunning))) + { + m_GcSignal.notify_one(); + return true; + } + } + } + + return false; +} + +void +GcScheduler::SchedulerThread() +{ + std::chrono::seconds WaitTime{0}; + + for (;;) + { + bool Timeout = false; + { + ZEN_ASSERT(WaitTime.count() >= 0); + std::unique_lock Lock(m_GcMutex); + Timeout = std::cv_status::timeout == m_GcSignal.wait_for(Lock, WaitTime); + } + + if (Status() == GcSchedulerStatus::kStopped) + { + break; + } + + if (!m_Config.Enabled) + { + WaitTime = std::chrono::seconds::max(); + continue; + } + + if (!Timeout && Status() == GcSchedulerStatus::kIdle) + { + continue; + } + + bool Delete = true; + bool CollectSmallObjects = m_Config.CollectSmallObjects; + std::chrono::seconds MaxCacheDuration = m_Config.MaxCacheDuration; + uint64_t DiskSizeSoftLimit = m_Config.DiskSizeSoftLimit; + GcClock::TimePoint Now = GcClock::Now(); + if (m_TriggerParams) + { + const auto TriggerParams = m_TriggerParams.value(); + m_TriggerParams.reset(); + + CollectSmallObjects = TriggerParams.CollectSmallObjects; + if (TriggerParams.MaxCacheDuration != std::chrono::seconds::max()) + { + MaxCacheDuration = TriggerParams.MaxCacheDuration; + } + if (TriggerParams.DiskSizeSoftLimit != 0) + { + DiskSizeSoftLimit = TriggerParams.DiskSizeSoftLimit; + } + } + + GcClock::TimePoint ExpireTime = MaxCacheDuration == GcClock::Duration::max() ? GcClock::TimePoint::min() : Now - MaxCacheDuration; + + std::error_code Ec; + const GcStorageSize TotalSize = m_GcManager.TotalStorageSize(); + + if (Timeout && Status() == GcSchedulerStatus::kIdle) + { + DiskSpace Space = DiskSpaceInfo(m_Config.RootDirectory, Ec); + if (Ec) + { + ZEN_WARN("get disk space info FAILED, reason: '{}'", Ec.message()); + } + + const int64_t PressureGraphLength = 30; + const std::chrono::duration LoadGraphTime = PressureGraphLength * m_Config.MonitorInterval; + std::vector<uint64_t> DiskDeltas; + uint64_t MaxLoad = 0; + { + const GcClock::Tick EpochTickCount = GcClock::Now().time_since_epoch().count(); + std::unique_lock Lock(m_GcMutex); + m_DiskUsageWindow.Append({.SampleTime = EpochTickCount, .DiskUsage = TotalSize.DiskSize}); + m_DiskUsageLog.Append({.SampleTime = EpochTickCount, .DiskUsage = TotalSize.DiskSize}); + const GcClock::TimePoint LoadGraphStartTime = Now - LoadGraphTime; + GcClock::Tick Start = LoadGraphStartTime.time_since_epoch().count(); + GcClock::Tick End = Now.time_since_epoch().count(); + DiskDeltas = m_DiskUsageWindow.GetDiskDeltas(Start, + End, + Max(1, (End - Start + PressureGraphLength - 1) / PressureGraphLength), + MaxLoad); + } + + std::string LoadGraph; + LoadGraph.resize(DiskDeltas.size(), '0'); + if (DiskDeltas.size() > 0 && MaxLoad > 0) + { + char LoadIndicator[11] = "0123456789"; + for (size_t Index = 0; Index < DiskDeltas.size(); ++Index) + { + size_t LoadIndex = (9 * DiskDeltas[Index] + MaxLoad - 1) / MaxLoad; + LoadGraph[Index] = LoadIndicator[LoadIndex]; + } + } + + uint64_t GcDiskSpaceGoal = 0; + if (DiskSizeSoftLimit != 0 && TotalSize.DiskSize > DiskSizeSoftLimit) + { + GcDiskSpaceGoal = TotalSize.DiskSize - DiskSizeSoftLimit; + std::unique_lock Lock(m_GcMutex); + GcClock::Tick AgeTick = m_DiskUsageWindow.FindTimepointThatRemoves(GcDiskSpaceGoal, Now.time_since_epoch().count()); + GcClock::TimePoint SizeBasedExpireTime = GcClock::TimePointFromTick(AgeTick); + if (SizeBasedExpireTime > ExpireTime) + { + ExpireTime = SizeBasedExpireTime; + } + } + + bool DiskSpaceGCTriggered = GcDiskSpaceGoal > 0; + + std::chrono::seconds RemaingTime = std::chrono::duration_cast<std::chrono::seconds>(m_NextGcTime - GcClock::Now()); + + if (RemaingTime < std::chrono::seconds::zero()) + { + RemaingTime = std::chrono::seconds::zero(); + } + + bool TimeBasedGCTriggered = !DiskSpaceGCTriggered && RemaingTime.count() == 0; + ZEN_INFO( + "{} in use,{} {} of total {} free disk space, disk writes last {} per {} [{}], peak {}/s. {}", + NiceBytes(TotalSize.DiskSize), + DiskSizeSoftLimit == 0 ? "" : fmt::format(" {} soft limit,", NiceBytes(DiskSizeSoftLimit)), + NiceBytes(Space.Free), + NiceBytes(Space.Total), + NiceTimeSpanMs(uint64_t(std::chrono::milliseconds(LoadGraphTime).count())), + NiceTimeSpanMs(uint64_t(std::chrono::milliseconds(LoadGraphTime).count() / PressureGraphLength)), + LoadGraph, + NiceBytes(MaxLoad * uint64_t(std::chrono::seconds(1).count()) / uint64_t(std::chrono::seconds(LoadGraphTime).count())), + DiskSpaceGCTriggered ? fmt::format("Disk use threshold triggered, trying to reclaim {}. ", NiceBytes(GcDiskSpaceGoal)) + : TimeBasedGCTriggered ? "GC schedule triggered." + : m_NextGcTime == GcClock::TimePoint::max() + ? "" + : fmt::format("{} until next scheduled GC.", NiceTimeSpanMs(uint64_t(std::chrono::milliseconds(RemaingTime).count())))); + + if (!DiskSpaceGCTriggered && !TimeBasedGCTriggered) + { + WaitTime = m_Config.MonitorInterval < RemaingTime ? m_Config.MonitorInterval : RemaingTime; + continue; + } + + WaitTime = m_Config.MonitorInterval; + uint32_t IdleState = static_cast<uint32_t>(GcSchedulerStatus::kIdle); + if (!m_Status.compare_exchange_strong(IdleState, static_cast<uint32_t>(GcSchedulerStatus::kRunning))) + { + continue; + } + } + + CollectGarbage(ExpireTime, Delete, CollectSmallObjects); + + uint32_t RunningState = static_cast<uint32_t>(GcSchedulerStatus::kRunning); + if (!m_Status.compare_exchange_strong(RunningState, static_cast<uint32_t>(GcSchedulerStatus::kIdle))) + { + ZEN_ASSERT(m_Status == static_cast<uint32_t>(GcSchedulerStatus::kStopped)); + break; + } + + WaitTime = m_Config.MonitorInterval; + } +} + +GcClock::TimePoint +GcScheduler::NextGcTime(GcClock::TimePoint CurrentTime) +{ + if (m_Config.Interval.count()) + { + return CurrentTime + m_Config.Interval; + } + else + { + return GcClock::TimePoint::max(); + } +} + +void +GcScheduler::CollectGarbage(const GcClock::TimePoint& ExpireTime, bool Delete, bool CollectSmallObjects) +{ + GcContext GcCtx(ExpireTime); + GcCtx.SetDeletionMode(Delete); + GcCtx.CollectSmallObjects(CollectSmallObjects); + // GcCtx.MaxCacheDuration(MaxCacheDuration); + GcCtx.DiskReservePath(m_Config.RootDirectory / "reserve.gc"); + + ZEN_INFO("garbage collection STARTING, small objects gc {}, cutoff time {}", + GcCtx.CollectSmallObjects() ? "ENABLED"sv : "DISABLED"sv, + ExpireTime); + { + Stopwatch Timer; + const auto __ = MakeGuard([&] { ZEN_INFO("garbage collection DONE in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); }); + + m_GcManager.CollectGarbage(GcCtx); + + if (Delete) + { + m_LastGcExpireTime = ExpireTime; + std::unique_lock Lock(m_GcMutex); + m_DiskUsageWindow.KeepRange(ExpireTime.time_since_epoch().count(), GcClock::Duration::max().count()); + } + + m_LastGcTime = GcClock::Now(); + m_NextGcTime = NextGcTime(m_LastGcTime); + + { + const fs::path Path = m_Config.RootDirectory / "gc_state"; + ZEN_DEBUG("saving scheduler state to '{}'", Path); + CbObjectWriter SchedulerState; + SchedulerState << "LastGcTime"sv << static_cast<int64_t>(m_LastGcTime.time_since_epoch().count()); + SchedulerState << "LastGcExpireTime"sv << static_cast<int64_t>(m_LastGcExpireTime.time_since_epoch().count()); + SaveCompactBinaryObject(Path, SchedulerState.Save()); + } + + std::error_code Ec = CreateGCReserve(m_Config.RootDirectory / "reserve.gc", m_Config.DiskReserveSize); + if (Ec) + { + ZEN_WARN("unable to create GC reserve at '{}' with size {}, reason: '{}'", + m_Config.RootDirectory / "reserve.gc", + NiceBytes(m_Config.DiskReserveSize), + Ec.message()); + } + } +} + +////////////////////////////////////////////////////////////////////////// + +#if ZEN_WITH_TESTS + +namespace gc::impl { + static IoBuffer CreateChunk(uint64_t Size) + { + static std::random_device rd; + static std::mt19937 g(rd()); + + std::vector<uint8_t> Values; + Values.resize(Size); + for (size_t Idx = 0; Idx < Size; ++Idx) + { + Values[Idx] = static_cast<uint8_t>(Idx); + } + std::shuffle(Values.begin(), Values.end(), g); + + return IoBufferBuilder::MakeCloneFromMemory(Values.data(), Values.size()); + } + + static CompressedBuffer Compress(IoBuffer Buffer) + { + return CompressedBuffer::Compress(SharedBuffer::MakeView(Buffer.GetData(), Buffer.GetSize())); + } +} // namespace gc::impl + +TEST_CASE("gc.basic") +{ + using namespace gc::impl; + + ScopedTemporaryDirectory TempDir; + + CidStoreConfiguration CasConfig; + CasConfig.RootDirectory = TempDir.Path() / "cas"; + + GcManager Gc; + CidStore CidStore(Gc); + + CidStore.Initialize(CasConfig); + + IoBuffer Chunk = CreateChunk(128); + auto CompressedChunk = Compress(Chunk); + + const auto InsertResult = CidStore.AddChunk(CompressedChunk.GetCompressed().Flatten().AsIoBuffer(), CompressedChunk.DecodeRawHash()); + CHECK(InsertResult.New); + + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + + CidStore.Flush(); + Gc.CollectGarbage(GcCtx); + + CHECK(!CidStore.ContainsChunk(CompressedChunk.DecodeRawHash())); +} + +TEST_CASE("gc.full") +{ + using namespace gc::impl; + + ScopedTemporaryDirectory TempDir; + + CidStoreConfiguration CasConfig; + CasConfig.RootDirectory = TempDir.Path() / "cas"; + + GcManager Gc; + std::unique_ptr<CasStore> CasStore = CreateCasStore(Gc); + + CasStore->Initialize(CasConfig); + + uint64_t ChunkSizes[9] = {128, 541, 1023, 781, 218, 37, 4, 997, 5}; + IoBuffer Chunks[9] = {CreateChunk(ChunkSizes[0]), + CreateChunk(ChunkSizes[1]), + CreateChunk(ChunkSizes[2]), + CreateChunk(ChunkSizes[3]), + CreateChunk(ChunkSizes[4]), + CreateChunk(ChunkSizes[5]), + CreateChunk(ChunkSizes[6]), + CreateChunk(ChunkSizes[7]), + CreateChunk(ChunkSizes[8])}; + IoHash ChunkHashes[9] = { + IoHash::HashBuffer(Chunks[0].Data(), Chunks[0].Size()), + IoHash::HashBuffer(Chunks[1].Data(), Chunks[1].Size()), + IoHash::HashBuffer(Chunks[2].Data(), Chunks[2].Size()), + IoHash::HashBuffer(Chunks[3].Data(), Chunks[3].Size()), + IoHash::HashBuffer(Chunks[4].Data(), Chunks[4].Size()), + IoHash::HashBuffer(Chunks[5].Data(), Chunks[5].Size()), + IoHash::HashBuffer(Chunks[6].Data(), Chunks[6].Size()), + IoHash::HashBuffer(Chunks[7].Data(), Chunks[7].Size()), + IoHash::HashBuffer(Chunks[8].Data(), Chunks[8].Size()), + }; + + CasStore->InsertChunk(Chunks[0], ChunkHashes[0]); + CasStore->InsertChunk(Chunks[1], ChunkHashes[1]); + CasStore->InsertChunk(Chunks[2], ChunkHashes[2]); + CasStore->InsertChunk(Chunks[3], ChunkHashes[3]); + CasStore->InsertChunk(Chunks[4], ChunkHashes[4]); + CasStore->InsertChunk(Chunks[5], ChunkHashes[5]); + CasStore->InsertChunk(Chunks[6], ChunkHashes[6]); + CasStore->InsertChunk(Chunks[7], ChunkHashes[7]); + CasStore->InsertChunk(Chunks[8], ChunkHashes[8]); + + CidStoreSize InitialSize = CasStore->TotalSize(); + + // Keep first and last + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[0]); + KeepChunks.push_back(ChunkHashes[8]); + GcCtx.AddRetainedCids(KeepChunks); + + CasStore->Flush(); + Gc.CollectGarbage(GcCtx); + + CHECK(CasStore->ContainsChunk(ChunkHashes[0])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[1])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[2])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[3])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[4])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[5])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[6])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[7])); + CHECK(CasStore->ContainsChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[0] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[0]))); + CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8]))); + } + + CasStore->InsertChunk(Chunks[1], ChunkHashes[1]); + CasStore->InsertChunk(Chunks[2], ChunkHashes[2]); + CasStore->InsertChunk(Chunks[3], ChunkHashes[3]); + CasStore->InsertChunk(Chunks[4], ChunkHashes[4]); + CasStore->InsertChunk(Chunks[5], ChunkHashes[5]); + CasStore->InsertChunk(Chunks[6], ChunkHashes[6]); + CasStore->InsertChunk(Chunks[7], ChunkHashes[7]); + + // Keep last + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[8]); + GcCtx.AddRetainedCids(KeepChunks); + + CasStore->Flush(); + Gc.CollectGarbage(GcCtx); + + CHECK(!CasStore->ContainsChunk(ChunkHashes[0])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[1])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[2])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[3])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[4])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[5])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[6])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[7])); + CHECK(CasStore->ContainsChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8]))); + + CasStore->InsertChunk(Chunks[1], ChunkHashes[1]); + CasStore->InsertChunk(Chunks[2], ChunkHashes[2]); + CasStore->InsertChunk(Chunks[3], ChunkHashes[3]); + CasStore->InsertChunk(Chunks[4], ChunkHashes[4]); + CasStore->InsertChunk(Chunks[5], ChunkHashes[5]); + CasStore->InsertChunk(Chunks[6], ChunkHashes[6]); + CasStore->InsertChunk(Chunks[7], ChunkHashes[7]); + } + + // Keep mixed + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[1]); + KeepChunks.push_back(ChunkHashes[4]); + KeepChunks.push_back(ChunkHashes[7]); + GcCtx.AddRetainedCids(KeepChunks); + + CasStore->Flush(); + Gc.CollectGarbage(GcCtx); + + CHECK(!CasStore->ContainsChunk(ChunkHashes[0])); + CHECK(CasStore->ContainsChunk(ChunkHashes[1])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[2])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[3])); + CHECK(CasStore->ContainsChunk(ChunkHashes[4])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[5])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[6])); + CHECK(CasStore->ContainsChunk(ChunkHashes[7])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[1] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[1]))); + CHECK(ChunkHashes[4] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[4]))); + CHECK(ChunkHashes[7] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[7]))); + + CasStore->InsertChunk(Chunks[0], ChunkHashes[0]); + CasStore->InsertChunk(Chunks[2], ChunkHashes[2]); + CasStore->InsertChunk(Chunks[3], ChunkHashes[3]); + CasStore->InsertChunk(Chunks[5], ChunkHashes[5]); + CasStore->InsertChunk(Chunks[6], ChunkHashes[6]); + CasStore->InsertChunk(Chunks[8], ChunkHashes[8]); + } + + // Keep multiple at end + { + GcContext GcCtx(GcClock::Now() - std::chrono::hours(24)); + GcCtx.CollectSmallObjects(true); + std::vector<IoHash> KeepChunks; + KeepChunks.push_back(ChunkHashes[6]); + KeepChunks.push_back(ChunkHashes[7]); + KeepChunks.push_back(ChunkHashes[8]); + GcCtx.AddRetainedCids(KeepChunks); + + CasStore->Flush(); + Gc.CollectGarbage(GcCtx); + + CHECK(!CasStore->ContainsChunk(ChunkHashes[0])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[1])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[2])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[3])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[4])); + CHECK(!CasStore->ContainsChunk(ChunkHashes[5])); + CHECK(CasStore->ContainsChunk(ChunkHashes[6])); + CHECK(CasStore->ContainsChunk(ChunkHashes[7])); + CHECK(CasStore->ContainsChunk(ChunkHashes[8])); + + CHECK(ChunkHashes[6] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[6]))); + CHECK(ChunkHashes[7] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[7]))); + CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8]))); + + CasStore->InsertChunk(Chunks[0], ChunkHashes[0]); + CasStore->InsertChunk(Chunks[1], ChunkHashes[1]); + CasStore->InsertChunk(Chunks[2], ChunkHashes[2]); + CasStore->InsertChunk(Chunks[3], ChunkHashes[3]); + CasStore->InsertChunk(Chunks[4], ChunkHashes[4]); + CasStore->InsertChunk(Chunks[5], ChunkHashes[5]); + } + + // Verify that we nicely appended blocks even after all GC operations + CHECK(ChunkHashes[0] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[0]))); + CHECK(ChunkHashes[1] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[1]))); + CHECK(ChunkHashes[2] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[2]))); + CHECK(ChunkHashes[3] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[3]))); + CHECK(ChunkHashes[4] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[4]))); + CHECK(ChunkHashes[5] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[5]))); + CHECK(ChunkHashes[6] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[6]))); + CHECK(ChunkHashes[7] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[7]))); + CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8]))); + + auto FinalSize = CasStore->TotalSize(); + + CHECK_LE(InitialSize.TinySize, FinalSize.TinySize); + CHECK_GE(InitialSize.TinySize + (1u << 28), FinalSize.TinySize); +} + +TEST_CASE("gc.diskusagewindow") +{ + using namespace gc::impl; + + DiskUsageWindow Stats; + Stats.Append({.SampleTime = 0, .DiskUsage = 0}); // 0 0 + Stats.Append({.SampleTime = 10, .DiskUsage = 10}); // 1 10 + Stats.Append({.SampleTime = 20, .DiskUsage = 20}); // 2 10 + Stats.Append({.SampleTime = 30, .DiskUsage = 20}); // 3 0 + Stats.Append({.SampleTime = 40, .DiskUsage = 15}); // 4 0 + Stats.Append({.SampleTime = 50, .DiskUsage = 25}); // 5 10 + Stats.Append({.SampleTime = 60, .DiskUsage = 30}); // 6 5 + Stats.Append({.SampleTime = 70, .DiskUsage = 45}); // 7 15 + + SUBCASE("Truncate start") + { + Stats.KeepRange(-15, 31); + CHECK(Stats.m_LogWindow.size() == 4); + CHECK(Stats.m_LogWindow[0].SampleTime == 0); + CHECK(Stats.m_LogWindow[3].SampleTime == 30); + } + + SUBCASE("Truncate end") + { + Stats.KeepRange(70, 71); + CHECK(Stats.m_LogWindow.size() == 1); + CHECK(Stats.m_LogWindow[0].SampleTime == 70); + } + + SUBCASE("Truncate middle") + { + Stats.KeepRange(29, 69); + CHECK(Stats.m_LogWindow.size() == 4); + CHECK(Stats.m_LogWindow[0].SampleTime == 30); + CHECK(Stats.m_LogWindow[3].SampleTime == 60); + } + + SUBCASE("Full range") + { + uint64_t MaxDelta = 0; + // 0-10, 10-20, 20-30, 30-40, 40-50, 50-60, 60-70, 70-80 + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(0, 80, 10, MaxDelta); + CHECK(DiskDeltas.size() == 8); + CHECK(MaxDelta == 15); + CHECK(DiskDeltas[0] == 0); + CHECK(DiskDeltas[1] == 10); + CHECK(DiskDeltas[2] == 10); + CHECK(DiskDeltas[3] == 0); + CHECK(DiskDeltas[4] == 0); + CHECK(DiskDeltas[5] == 10); + CHECK(DiskDeltas[6] == 5); + CHECK(DiskDeltas[7] == 15); + } + + SUBCASE("Sub range") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(20, 40, 10, MaxDelta); + CHECK(DiskDeltas.size() == 2); + CHECK(MaxDelta == 10); + CHECK(DiskDeltas[0] == 10); // [20:30] + CHECK(DiskDeltas[1] == 0); // [30:40] + } + SUBCASE("Unaligned sub range 1") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(21, 51, 10, MaxDelta); + CHECK(DiskDeltas.size() == 3); + CHECK(MaxDelta == 10); + CHECK(DiskDeltas[0] == 0); // [21:31] + CHECK(DiskDeltas[1] == 0); // [31:41] + CHECK(DiskDeltas[2] == 10); // [41:51] + } + SUBCASE("Unaligned end range") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(29, 79, 10, MaxDelta); + CHECK(DiskDeltas.size() == 5); + CHECK(MaxDelta == 15); + CHECK(DiskDeltas[0] == 0); // [29:39] + CHECK(DiskDeltas[1] == 0); // [39:49] + CHECK(DiskDeltas[2] == 10); // [49:59] + CHECK(DiskDeltas[3] == 5); // [59:69] + CHECK(DiskDeltas[4] == 15); // [69:79] + } + SUBCASE("Ahead of window") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(-40, 0, 10, MaxDelta); + CHECK(DiskDeltas.size() == 4); + CHECK(MaxDelta == 0); + CHECK(DiskDeltas[0] == 0); // [-40:-30] + CHECK(DiskDeltas[1] == 0); // [-30:-20] + CHECK(DiskDeltas[2] == 0); // [-20:-10] + CHECK(DiskDeltas[3] == 0); // [-10:0] + } + SUBCASE("After of window") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(90, 120, 10, MaxDelta); + CHECK(DiskDeltas.size() == 3); + CHECK(MaxDelta == 0); + CHECK(DiskDeltas[0] == 0); // [90:100] + CHECK(DiskDeltas[1] == 0); // [100:110] + CHECK(DiskDeltas[2] == 0); // [110:120] + } + SUBCASE("Encapsulating window") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(-20, 100, 10, MaxDelta); + CHECK(DiskDeltas.size() == 12); + CHECK(MaxDelta == 15); + CHECK(DiskDeltas[0] == 0); // [-20:-10] + CHECK(DiskDeltas[1] == 0); // [ -10:0] + CHECK(DiskDeltas[2] == 0); // [0:10] + CHECK(DiskDeltas[3] == 10); // [10:20] + CHECK(DiskDeltas[4] == 10); // [20:30] + CHECK(DiskDeltas[5] == 0); // [30:40] + CHECK(DiskDeltas[6] == 0); // [40:50] + CHECK(DiskDeltas[7] == 10); // [50:60] + CHECK(DiskDeltas[8] == 5); // [60:70] + CHECK(DiskDeltas[9] == 15); // [70:80] + CHECK(DiskDeltas[10] == 0); // [80:90] + CHECK(DiskDeltas[11] == 0); // [90:100] + } + + SUBCASE("Full range half stride") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(0, 80, 20, MaxDelta); + CHECK(DiskDeltas.size() == 4); + CHECK(MaxDelta == 20); + CHECK(DiskDeltas[0] == 10); // [0:20] + CHECK(DiskDeltas[1] == 10); // [20:40] + CHECK(DiskDeltas[2] == 10); // [40:60] + CHECK(DiskDeltas[3] == 20); // [60:80] + } + + SUBCASE("Partial odd stride") + { + uint64_t MaxDelta = 0; + std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(13, 67, 18, MaxDelta); + CHECK(DiskDeltas.size() == 3); + CHECK(MaxDelta == 15); + CHECK(DiskDeltas[0] == 10); // [13:31] + CHECK(DiskDeltas[1] == 0); // [31:49] + CHECK(DiskDeltas[2] == 15); // [49:67] + } + + SUBCASE("Find size window") + { + DiskUsageWindow Empty; + CHECK(Empty.FindTimepointThatRemoves(15u, 10000) == 10000); + + CHECK(Stats.FindTimepointThatRemoves(15u, 40) == 21); + CHECK(Stats.FindTimepointThatRemoves(15u, 20) == 20); + CHECK(Stats.FindTimepointThatRemoves(100000u, 50) == 50); + CHECK(Stats.FindTimepointThatRemoves(100000u, 1000)); + } +} +#endif + +void +gc_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenstore/hashkeyset.cpp b/src/zenstore/hashkeyset.cpp new file mode 100644 index 000000000..a5436f5cb --- /dev/null +++ b/src/zenstore/hashkeyset.cpp @@ -0,0 +1,60 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenstore/hashkeyset.h> + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +void +HashKeySet::AddHashToSet(const IoHash& HashToAdd) +{ + m_HashSet.insert(HashToAdd); +} + +void +HashKeySet::AddHashesToSet(std::span<const IoHash> HashesToAdd) +{ + m_HashSet.insert(HashesToAdd.begin(), HashesToAdd.end()); +} + +void +HashKeySet::RemoveHashesIf(std::function<bool(const IoHash& CandidateHash)>&& Predicate) +{ + for (auto It = begin(m_HashSet), ItEnd = end(m_HashSet); It != ItEnd;) + { + if (Predicate(*It)) + { + It = m_HashSet.erase(It); + } + else + { + ++It; + } + } +} + +void +HashKeySet::IterateHashes(std::function<void(const IoHash& Hash)>&& Callback) const +{ + for (auto It = begin(m_HashSet), ItEnd = end(m_HashSet); It != ItEnd; ++It) + { + Callback(*It); + } +} + +////////////////////////////////////////////////////////////////////////// +// +// Testing related code follows... +// + +#if ZEN_WITH_TESTS + +void +hashkeyset_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenstore/include/zenstore/blockstore.h b/src/zenstore/include/zenstore/blockstore.h new file mode 100644 index 000000000..857ccae38 --- /dev/null +++ b/src/zenstore/include/zenstore/blockstore.h @@ -0,0 +1,175 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/filesystem.h> +#include <zencore/zencore.h> +#include <zenutil/basicfile.h> + +#include <unordered_map> +#include <unordered_set> + +namespace zen { + +////////////////////////////////////////////////////////////////////////// + +struct BlockStoreLocation +{ + uint32_t BlockIndex; + uint64_t Offset; + uint64_t Size; + + inline auto operator<=>(const BlockStoreLocation& Rhs) const = default; +}; + +#pragma pack(push) +#pragma pack(1) + +struct BlockStoreDiskLocation +{ + constexpr static uint32_t MaxBlockIndexBits = 20; + constexpr static uint32_t MaxOffsetBits = 28; + constexpr static uint32_t MaxBlockIndex = (1ul << BlockStoreDiskLocation::MaxBlockIndexBits) - 1ul; + constexpr static uint32_t MaxOffset = (1ul << BlockStoreDiskLocation::MaxOffsetBits) - 1ul; + + BlockStoreDiskLocation(const BlockStoreLocation& Location, uint64_t OffsetAlignment) + { + Init(Location.BlockIndex, Location.Offset / OffsetAlignment, Location.Size); + } + + BlockStoreDiskLocation() = default; + + inline BlockStoreLocation Get(uint64_t OffsetAlignment) const + { + uint64_t PackedOffset = 0; + memcpy(&PackedOffset, &m_Offset, sizeof m_Offset); + return {.BlockIndex = static_cast<std::uint32_t>(PackedOffset >> MaxOffsetBits), + .Offset = (PackedOffset & MaxOffset) * OffsetAlignment, + .Size = GetSize()}; + } + + inline uint32_t GetBlockIndex() const + { + uint64_t PackedOffset = 0; + memcpy(&PackedOffset, &m_Offset, sizeof m_Offset); + return static_cast<std::uint32_t>(PackedOffset >> MaxOffsetBits); + } + + inline uint64_t GetOffset(uint64_t OffsetAlignment) const + { + uint64_t PackedOffset = 0; + memcpy(&PackedOffset, &m_Offset, sizeof m_Offset); + return (PackedOffset & MaxOffset) * OffsetAlignment; + } + + inline uint64_t GetSize() const { return m_Size; } + + inline auto operator<=>(const BlockStoreDiskLocation& Rhs) const = default; + +private: + inline void Init(uint32_t BlockIndex, uint64_t Offset, uint64_t Size) + { + ZEN_ASSERT(BlockIndex <= MaxBlockIndex); + ZEN_ASSERT(Offset <= MaxOffset); + ZEN_ASSERT(Size <= std::numeric_limits<std::uint32_t>::max()); + + m_Size = static_cast<uint32_t>(Size); + uint64_t PackedOffset = (static_cast<uint64_t>(BlockIndex) << MaxOffsetBits) + Offset; + memcpy(&m_Offset[0], &PackedOffset, sizeof m_Offset); + } + + uint32_t m_Size; + uint8_t m_Offset[6]; +}; + +#pragma pack(pop) + +struct BlockStoreFile : public RefCounted +{ + explicit BlockStoreFile(const std::filesystem::path& BlockPath); + ~BlockStoreFile(); + const std::filesystem::path& GetPath() const; + void Open(); + void Create(uint64_t InitialSize); + void MarkAsDeleteOnClose(); + uint64_t FileSize(); + IoBuffer GetChunk(uint64_t Offset, uint64_t Size); + void Read(void* Data, uint64_t Size, uint64_t FileOffset); + void Write(const void* Data, uint64_t Size, uint64_t FileOffset); + void Flush(); + BasicFile& GetBasicFile(); + void StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun); + +private: + const std::filesystem::path m_Path; + IoBuffer m_IoBuffer; + BasicFile m_File; +}; + +class BlockStore +{ +public: + struct ReclaimSnapshotState + { + std::unordered_set<uint32_t> m_ActiveWriteBlocks; + size_t BlockCount; + }; + + typedef std::vector<std::pair<size_t, BlockStoreLocation>> MovedChunksArray; + typedef std::vector<size_t> ChunkIndexArray; + + typedef std::function<void(const MovedChunksArray& MovedChunks, const ChunkIndexArray& RemovedChunks)> ReclaimCallback; + typedef std::function<uint64_t()> ClaimDiskReserveCallback; + typedef std::function<void(size_t ChunkIndex, const void* Data, uint64_t Size)> IterateChunksSmallSizeCallback; + typedef std::function<void(size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size)> IterateChunksLargeSizeCallback; + typedef std::function<void(const BlockStoreLocation& Location)> WriteChunkCallback; + + void Initialize(const std::filesystem::path& BlocksBasePath, + uint64_t MaxBlockSize, + uint64_t MaxBlockCount, + const std::vector<BlockStoreLocation>& KnownLocations); + void Close(); + + void WriteChunk(const void* Data, uint64_t Size, uint64_t Alignment, const WriteChunkCallback& Callback); + + IoBuffer TryGetChunk(const BlockStoreLocation& Location) const; + void Flush(); + + ReclaimSnapshotState GetReclaimSnapshotState(); + void ReclaimSpace( + const ReclaimSnapshotState& Snapshot, + const std::vector<BlockStoreLocation>& ChunkLocations, + const ChunkIndexArray& KeepChunkIndexes, + uint64_t PayloadAlignment, + bool DryRun, + const ReclaimCallback& ChangeCallback = [](const MovedChunksArray&, const ChunkIndexArray&) {}, + const ClaimDiskReserveCallback& DiskReserveCallback = []() { return 0; }); + + void IterateChunks(const std::vector<BlockStoreLocation>& ChunkLocations, + const IterateChunksSmallSizeCallback& SmallSizeCallback, + const IterateChunksLargeSizeCallback& LargeSizeCallback); + + static const char* GetBlockFileExtension(); + static std::filesystem::path GetBlockPath(const std::filesystem::path& BlocksBasePath, const uint32_t BlockIndex); + + inline uint64_t TotalSize() const { return m_TotalSize.load(std::memory_order::relaxed); } + +private: + std::unordered_map<uint32_t, Ref<BlockStoreFile>> m_ChunkBlocks; + + mutable RwLock m_InsertLock; // used to serialize inserts + Ref<BlockStoreFile> m_WriteBlock; + std::uint64_t m_CurrentInsertOffset = 0; + std::atomic_uint32_t m_WriteBlockIndex{}; + std::vector<uint32_t> m_ActiveWriteBlocks; + + uint64_t m_MaxBlockSize = 1u << 28; + uint64_t m_MaxBlockCount = BlockStoreDiskLocation::MaxBlockIndex + 1; + std::filesystem::path m_BlocksBasePath; + + std::atomic_uint64_t m_TotalSize{}; +}; + +void blockstore_forcelink(); + +} // namespace zen diff --git a/src/zenstore/include/zenstore/caslog.h b/src/zenstore/include/zenstore/caslog.h new file mode 100644 index 000000000..d8c3f22f3 --- /dev/null +++ b/src/zenstore/include/zenstore/caslog.h @@ -0,0 +1,91 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/uid.h> +#include <zenutil/basicfile.h> + +namespace zen { + +class CasLogFile +{ +public: + CasLogFile(); + ~CasLogFile(); + + enum class Mode + { + kRead, + kWrite, + kTruncate + }; + + static bool IsValid(std::filesystem::path FileName, size_t RecordSize); + void Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode); + void Append(const void* DataPointer, uint64_t DataSize); + void Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntryCount); + void Flush(); + void Close(); + uint64_t GetLogSize(); + uint64_t GetLogCount(); + +private: + struct FileHeader + { + uint8_t Magic[16]; + uint32_t RecordSize = 0; + Oid LogId; + uint32_t ValidatedTail = 0; + uint32_t Pad[6]; + uint32_t Checksum = 0; + + static const inline uint8_t MagicSequence[16] = {'.', '-', '=', ' ', 'C', 'A', 'S', 'L', 'O', 'G', 'v', '1', ' ', '=', '-', '.'}; + + ZENCORE_API uint32_t ComputeChecksum(); + void Finalize() { Checksum = ComputeChecksum(); } + }; + + static_assert(sizeof(FileHeader) == 64); + +private: + void Open(std::filesystem::path FileName, size_t RecordSize, BasicFile::Mode Mode); + + BasicFile m_File; + FileHeader m_Header; + size_t m_RecordSize = 1; + std::atomic<uint64_t> m_AppendOffset = 0; +}; + +template<typename T> +class TCasLogFile : public CasLogFile +{ +public: + static bool IsValid(std::filesystem::path FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); } + void Open(std::filesystem::path FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); } + + // This should be called before the Replay() is called to do some basic sanity checking + bool Initialize() { return true; } + + void Replay(Invocable<const T&> auto Handler, uint64_t SkipEntryCount) + { + CasLogFile::Replay( + [&](const void* VoidPtr) { + const T& Record = *reinterpret_cast<const T*>(VoidPtr); + + Handler(Record); + }, + SkipEntryCount); + } + + void Append(const T& Record) + { + // TODO: implement some more efficent path here so we don't end up with + // a syscall per append + + CasLogFile::Append(&Record, sizeof Record); + } + + void Append(const std::span<T>& Records) { CasLogFile::Append(Records.data(), sizeof(T) * Records.size()); } +}; + +} // namespace zen diff --git a/src/zenstore/include/zenstore/cidstore.h b/src/zenstore/include/zenstore/cidstore.h new file mode 100644 index 000000000..16ca78225 --- /dev/null +++ b/src/zenstore/include/zenstore/cidstore.h @@ -0,0 +1,87 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenstore.h" + +#include <zencore/iohash.h> +#include <zenstore/hashkeyset.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <tsl/robin_map.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <filesystem> + +namespace zen { + +class GcManager; +class CasStore; +class CompressedBuffer; +class IoBuffer; +class ScrubContext; + +/** Content Store + * + * Data in the content store is referenced by content identifiers (CIDs), it works + * with compressed buffers so the CID is expected to be the RAW hash. It stores the + * chunk directly under the RAW hash. + * This class maps uncompressed hashes (CIDs) to compressed hashes and may + * be used to deal with other kinds of indirections in the future. For example, if we want + * to support chunking then a CID may represent a list of chunks which could be concatenated + * to form the referenced chunk. + * + */ + +struct CidStoreSize +{ + uint64_t TinySize = 0; + uint64_t SmallSize = 0; + uint64_t LargeSize = 0; + uint64_t TotalSize = 0; +}; + +struct CidStoreConfiguration +{ + // Root directory for CAS store + std::filesystem::path RootDirectory; + + // Threshold below which values are considered 'tiny' and managed using the 'tiny values' strategy + uint64_t TinyValueThreshold = 1024; + + // Threshold above which values are considered 'huge' and managed using the 'huge values' strategy + uint64_t HugeValueThreshold = 1024 * 1024; +}; + +class CidStore +{ +public: + CidStore(GcManager& Gc); + ~CidStore(); + + struct InsertResult + { + bool New = false; + }; + enum class InsertMode + { + kCopyOnly, + kMayBeMovedInPlace + }; + + void Initialize(const CidStoreConfiguration& Config); + InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode = InsertMode::kMayBeMovedInPlace); + IoBuffer FindChunkByCid(const IoHash& DecompressedId); + bool ContainsChunk(const IoHash& DecompressedId); + void FilterChunks(HashKeySet& InOutChunks); + void Flush(); + void Scrub(ScrubContext& Ctx); + CidStoreSize TotalSize() const; + +private: + struct Impl; + std::unique_ptr<CasStore> m_CasStore; + std::unique_ptr<Impl> m_Impl; +}; + +} // namespace zen diff --git a/src/zenstore/include/zenstore/gc.h b/src/zenstore/include/zenstore/gc.h new file mode 100644 index 000000000..e0354b331 --- /dev/null +++ b/src/zenstore/include/zenstore/gc.h @@ -0,0 +1,242 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/thread.h> +#include <zenstore/caslog.h> + +#include <atomic> +#include <chrono> +#include <condition_variable> +#include <filesystem> +#include <functional> +#include <optional> +#include <span> +#include <thread> + +#define ZEN_USE_REF_TRACKING 0 // This is not currently functional + +namespace spdlog { +class logger; +} + +namespace zen { + +class HashKeySet; +class GcManager; +class CidStore; +struct IoHash; + +/** GC clock + */ +class GcClock +{ +public: + using Clock = std::chrono::system_clock; + using TimePoint = Clock::time_point; + using Duration = Clock::duration; + using Tick = int64_t; + + static Tick TickCount() { return Now().time_since_epoch().count(); } + static TimePoint Now() { return Clock::now(); } + static TimePoint TimePointFromTick(const Tick TickCount) { return TimePoint{Duration{TickCount}}; } +}; + +/** Garbage Collection context object + */ +class GcContext +{ +public: + GcContext(const GcClock::TimePoint& ExpireTime); + ~GcContext(); + + void AddRetainedCids(std::span<const IoHash> Cid); + void SetExpiredCacheKeys(const std::string& CacheKeyContext, std::vector<IoHash>&& ExpiredKeys); + + void IterateCids(std::function<void(const IoHash&)> Callback); + + void FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&)> KeepFunc); + void FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&, bool)>&& FilterFunc); + + void AddDeletedCids(std::span<const IoHash> Cas); + const HashKeySet& DeletedCids(); + + std::span<const IoHash> ExpiredCacheKeys(const std::string& CacheKeyContext) const; + + bool IsDeletionMode() const; + void SetDeletionMode(bool NewState); + + bool CollectSmallObjects() const; + void CollectSmallObjects(bool NewState); + + GcClock::TimePoint ExpireTime() const; + + void DiskReservePath(const std::filesystem::path& Path); + uint64_t ClaimGCReserve(); + +private: + struct GcState; + + std::unique_ptr<GcState> m_State; +}; + +/** GC root contributor + + Higher level data structures provide roots for the garbage collector, + which ultimately determine what is garbage and what data we need to + retain. + + */ +class GcContributor +{ +public: + GcContributor(GcManager& Gc); + ~GcContributor(); + + virtual void GatherReferences(GcContext& GcCtx) = 0; + +protected: + GcManager& m_Gc; +}; + +struct GcStorageSize +{ + uint64_t DiskSize{}; + uint64_t MemorySize{}; +}; + +/** GC storage provider + */ +class GcStorage +{ +public: + GcStorage(GcManager& Gc); + ~GcStorage(); + + virtual void CollectGarbage(GcContext& GcCtx) = 0; + virtual GcStorageSize StorageSize() const = 0; + +private: + GcManager& m_Gc; +}; + +/** GC orchestrator + */ +class GcManager +{ +public: + GcManager(); + ~GcManager(); + + void AddGcContributor(GcContributor* Contributor); + void RemoveGcContributor(GcContributor* Contributor); + + void AddGcStorage(GcStorage* Contributor); + void RemoveGcStorage(GcStorage* Contributor); + + void CollectGarbage(GcContext& GcCtx); + + GcStorageSize TotalStorageSize() const; + +#if ZEN_USE_REF_TRACKING + void OnNewCidReferences(std::span<IoHash> Hashes); + void OnCommittedCidReferences(std::span<IoHash> Hashes); + void OnDroppedCidReferences(std::span<IoHash> Hashes); +#endif + +private: + spdlog::logger& Log() { return m_Log; } + spdlog::logger& m_Log; + mutable RwLock m_Lock; + std::vector<GcContributor*> m_GcContribs; + std::vector<GcStorage*> m_GcStorage; + CidStore* m_CidStore = nullptr; +}; + +enum class GcSchedulerStatus : uint32_t +{ + kIdle, + kRunning, + kStopped +}; + +struct GcSchedulerConfig +{ + std::filesystem::path RootDirectory; + std::chrono::seconds MonitorInterval{30}; + std::chrono::seconds Interval{}; + std::chrono::seconds MaxCacheDuration{86400}; + bool CollectSmallObjects = true; + bool Enabled = true; + uint64_t DiskReserveSize = 1ul << 28; + uint64_t DiskSizeSoftLimit = 0; +}; + +class DiskUsageWindow +{ +public: + struct DiskUsageEntry + { + GcClock::Tick SampleTime; + uint64_t DiskUsage; + }; + + std::vector<DiskUsageEntry> m_LogWindow; + inline void Append(const DiskUsageEntry& Entry) { m_LogWindow.push_back(Entry); } + inline void Append(DiskUsageEntry&& Entry) { m_LogWindow.emplace_back(std::move(Entry)); } + void KeepRange(GcClock::Tick StartTick, GcClock::Tick EndTick); + std::vector<uint64_t> GetDiskDeltas(GcClock::Tick StartTick, + GcClock::Tick EndTick, + GcClock::Tick DeltaWidth, + uint64_t& OutMaxDelta) const; + GcClock::Tick FindTimepointThatRemoves(uint64_t Amount, GcClock::Tick EndTick) const; +}; + +/** + * GC scheduler + */ +class GcScheduler +{ +public: + GcScheduler(GcManager& GcManager); + ~GcScheduler(); + + void Initialize(const GcSchedulerConfig& Config); + void Shutdown(); + GcSchedulerStatus Status() const { return static_cast<GcSchedulerStatus>(m_Status.load()); } + + struct TriggerParams + { + bool CollectSmallObjects = false; + std::chrono::seconds MaxCacheDuration = std::chrono::seconds::max(); + uint64_t DiskSizeSoftLimit = 0; + }; + + bool Trigger(const TriggerParams& Params); + +private: + void SchedulerThread(); + void CollectGarbage(const GcClock::TimePoint& ExpireTime, bool Delete, bool CollectSmallObjects); + GcClock::TimePoint NextGcTime(GcClock::TimePoint CurrentTime); + spdlog::logger& Log() { return m_Log; } + + spdlog::logger& m_Log; + GcManager& m_GcManager; + GcSchedulerConfig m_Config; + GcClock::TimePoint m_LastGcTime{}; + GcClock::TimePoint m_LastGcExpireTime{}; + GcClock::TimePoint m_NextGcTime{}; + std::atomic_uint32_t m_Status{}; + std::thread m_GcThread; + std::mutex m_GcMutex; + std::condition_variable m_GcSignal; + std::optional<TriggerParams> m_TriggerParams; + + TCasLogFile<DiskUsageWindow::DiskUsageEntry> m_DiskUsageLog; + DiskUsageWindow m_DiskUsageWindow; +}; + +void gc_forcelink(); + +} // namespace zen diff --git a/src/zenstore/include/zenstore/hashkeyset.h b/src/zenstore/include/zenstore/hashkeyset.h new file mode 100644 index 000000000..411a6256e --- /dev/null +++ b/src/zenstore/include/zenstore/hashkeyset.h @@ -0,0 +1,54 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include "zenstore.h" + +#include <zencore/iohash.h> + +#include <functional> +#include <unordered_set> + +namespace zen { + +/** Manage a set of IoHash values + */ + +class HashKeySet +{ +public: + void AddHashToSet(const IoHash& HashToAdd); + void AddHashesToSet(std::span<const IoHash> HashesToAdd); + void RemoveHashesIf(std::function<bool(const IoHash& CandidateHash)>&& Predicate); + void IterateHashes(std::function<void(const IoHash& Hash)>&& Callback) const; + [[nodiscard]] inline bool ContainsHash(const IoHash& Hash) const { return m_HashSet.find(Hash) != m_HashSet.end(); } + [[nodiscard]] inline bool IsEmpty() const { return m_HashSet.empty(); } + [[nodiscard]] inline size_t GetSize() const { return m_HashSet.size(); } + + inline void FilterHashes(std::span<const IoHash> Candidates, Invocable<const IoHash&> auto MatchFunc) const + { + for (const IoHash& Candidate : Candidates) + { + if (ContainsHash(Candidate)) + { + MatchFunc(Candidate); + } + } + } + + inline void FilterHashes(std::span<const IoHash> Candidates, Invocable<const IoHash&, bool> auto MatchFunc) const + { + for (const IoHash& Candidate : Candidates) + { + MatchFunc(Candidate, ContainsHash(Candidate)); + } + } + +private: + // Q: should we protect this with a lock, or is that a higher level concern? + std::unordered_set<IoHash, IoHash::Hasher> m_HashSet; +}; + +void hashkeyset_forcelink(); + +} // namespace zen diff --git a/src/zenstore/include/zenstore/scrubcontext.h b/src/zenstore/include/zenstore/scrubcontext.h new file mode 100644 index 000000000..0b884fcc6 --- /dev/null +++ b/src/zenstore/include/zenstore/scrubcontext.h @@ -0,0 +1,41 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/timer.h> +#include <zenstore/hashkeyset.h> + +namespace zen { + +/** Context object for data scrubbing + * + * Data scrubbing is when we traverse stored data to validate it and + * optionally correct/recover + */ + +class ScrubContext +{ +public: + virtual void ReportBadCidChunks(std::span<IoHash> BadCasChunks) { m_BadCid.AddHashesToSet(BadCasChunks); } + inline uint64_t ScrubTimestamp() const { return m_ScrubTime; } + inline bool RunRecovery() const { return m_Recover; } + void ReportScrubbed(uint64_t ChunkCount, uint64_t ChunkBytes) + { + m_ChunkCount.fetch_add(ChunkCount); + m_ByteCount.fetch_add(ChunkBytes); + } + + inline uint64_t ScrubbedChunks() const { return m_ChunkCount; } + inline uint64_t ScrubbedBytes() const { return m_ByteCount; } + + const HashKeySet BadCids() const { return m_BadCid; } + +private: + uint64_t m_ScrubTime = GetHifreqTimerValue(); + bool m_Recover = true; + std::atomic<uint64_t> m_ChunkCount{0}; + std::atomic<uint64_t> m_ByteCount{0}; + HashKeySet m_BadCid; +}; + +} // namespace zen diff --git a/src/zenstore/include/zenstore/zenstore.h b/src/zenstore/include/zenstore/zenstore.h new file mode 100644 index 000000000..46d62029d --- /dev/null +++ b/src/zenstore/include/zenstore/zenstore.h @@ -0,0 +1,13 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/zencore.h> + +#define ZENSTORE_API + +namespace zen { + +ZENSTORE_API void zenstore_forcelinktests(); + +} diff --git a/src/zenstore/xmake.lua b/src/zenstore/xmake.lua new file mode 100644 index 000000000..4469c5650 --- /dev/null +++ b/src/zenstore/xmake.lua @@ -0,0 +1,9 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenstore') + set_kind("static") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore", "zenutil") + add_packages("vcpkg::robin-map") diff --git a/src/zenstore/zenstore.cpp b/src/zenstore/zenstore.cpp new file mode 100644 index 000000000..d87652fde --- /dev/null +++ b/src/zenstore/zenstore.cpp @@ -0,0 +1,32 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenstore/zenstore.h" + +#if ZEN_WITH_TESTS + +# include <zenstore/blockstore.h> +# include <zenstore/gc.h> +# include <zenstore/hashkeyset.h> +# include <zenutil/basicfile.h> + +# include "cas.h" +# include "compactcas.h" +# include "filecas.h" + +namespace zen { + +void +zenstore_forcelinktests() +{ + basicfile_forcelink(); + CAS_forcelink(); + filecas_forcelink(); + blockstore_forcelink(); + compactcas_forcelink(); + gc_forcelink(); + hashkeyset_forcelink(); +} + +} // namespace zen + +#endif diff --git a/src/zentest-appstub/xmake.lua b/src/zentest-appstub/xmake.lua new file mode 100644 index 000000000..d8e0283c1 --- /dev/null +++ b/src/zentest-appstub/xmake.lua @@ -0,0 +1,16 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("zentest-appstub") + set_kind("binary") + add_headerfiles("**.h") + add_files("*.cpp") + + if is_os("linux") then + add_syslinks("pthread") + end + + if is_plat("macosx") then + add_ldflags("-framework CoreFoundation") + add_ldflags("-framework Security") + add_ldflags("-framework SystemConfiguration") + end diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp new file mode 100644 index 000000000..66e6e03fd --- /dev/null +++ b/src/zentest-appstub/zentest-appstub.cpp @@ -0,0 +1,34 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <stdio.h> +#include <cstdlib> +#include <cstring> +#include <thread> + +using namespace std::chrono_literals; + +int +main(int argc, char* argv[]) +{ + int ExitCode = 0; + + for (int i = 0; i < argc; ++i) + { + if (std::strncmp(argv[i], "-t=", 3) == 0) + { + const int SleepTime = std::atoi(argv[i] + 3); + + printf("[zentest] sleeping for %ds...\n", SleepTime); + + std::this_thread::sleep_for(SleepTime * 1s); + } + else if (std::strncmp(argv[i], "-f=", 3) == 0) + { + ExitCode = std::atoi(argv[i] + 3); + } + } + + printf("[zentest] exiting with exit code: %d\n", ExitCode); + + return ExitCode; +} diff --git a/src/zenutil/basicfile.cpp b/src/zenutil/basicfile.cpp new file mode 100644 index 000000000..1e6043d7e --- /dev/null +++ b/src/zenutil/basicfile.cpp @@ -0,0 +1,575 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenutil/basicfile.h" + +#include <zencore/compactbinary.h> +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/testing.h> +#include <zencore/testutils.h> + +#if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +#else +# include <fcntl.h> +# include <sys/file.h> +# include <sys/stat.h> +# include <unistd.h> +#endif + +#include <fmt/format.h> +#include <gsl/gsl-lite.hpp> + +namespace zen { + +BasicFile::~BasicFile() +{ + Close(); +} + +void +BasicFile::Open(const std::filesystem::path& FileName, Mode Mode) +{ + std::error_code Ec; + Open(FileName, Mode, Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("failed to open file '{}'", FileName)); + } +} + +void +BasicFile::Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec) +{ + Ec.clear(); + +#if ZEN_PLATFORM_WINDOWS + DWORD dwCreationDisposition = 0; + DWORD dwDesiredAccess = 0; + switch (Mode) + { + case Mode::kRead: + dwCreationDisposition |= OPEN_EXISTING; + dwDesiredAccess |= GENERIC_READ; + break; + case Mode::kWrite: + dwCreationDisposition |= OPEN_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE); + break; + case Mode::kDelete: + dwCreationDisposition |= OPEN_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE); + break; + case Mode::kTruncate: + dwCreationDisposition |= CREATE_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE); + break; + case Mode::kTruncateDelete: + dwCreationDisposition |= CREATE_ALWAYS; + dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE); + break; + } + + const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL; + HANDLE hTemplateFile = nullptr; + + HANDLE FileHandle = CreateFile(FileName.c_str(), + dwDesiredAccess, + dwShareMode, + /* lpSecurityAttributes */ nullptr, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + Ec = MakeErrorCodeFromLastError(); + + return; + } +#else + int OpenFlags = O_CLOEXEC; + switch (Mode) + { + case Mode::kRead: + OpenFlags |= O_RDONLY; + break; + case Mode::kWrite: + case Mode::kDelete: + OpenFlags |= (O_RDWR | O_CREAT); + break; + case Mode::kTruncate: + case Mode::kTruncateDelete: + OpenFlags |= (O_RDWR | O_CREAT | O_TRUNC); + break; + } + + int Fd = open(FileName.c_str(), OpenFlags, 0666); + if (Fd < 0) + { + Ec = MakeErrorCodeFromLastError(); + return; + } + if (Mode != Mode::kRead) + { + fchmod(Fd, 0666); + } + + void* FileHandle = (void*)(uintptr_t(Fd)); +#endif + + m_FileHandle = FileHandle; +} + +void +BasicFile::Close() +{ + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + ::CloseHandle(m_FileHandle); +#else + int Fd = int(uintptr_t(m_FileHandle)); + close(Fd); +#endif + m_FileHandle = nullptr; + } +} + +void +BasicFile::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset) +{ + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + + while (BytesToRead) + { + const uint64_t NumberOfBytesToRead = Min(BytesToRead, MaxChunkSize); + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD dwNumberOfBytesRead = 0; + BOOL Success = ::ReadFile(m_FileHandle, Data, DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl); + + ZEN_ASSERT(dwNumberOfBytesRead == NumberOfBytesToRead); +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesRead = pread(Fd, Data, NumberOfBytesToRead, FileOffset); + bool Success = (BytesRead > 0); +#endif + + if (!Success) + { + ThrowLastError(fmt::format("Failed to read from file '{}'", zen::PathFromHandle(m_FileHandle))); + } + + BytesToRead -= NumberOfBytesToRead; + FileOffset += NumberOfBytesToRead; + Data = reinterpret_cast<uint8_t*>(Data) + NumberOfBytesToRead; + } +} + +IoBuffer +BasicFile::ReadAll() +{ + IoBuffer Buffer(FileSize()); + Read(Buffer.MutableData(), Buffer.Size(), 0); + return Buffer; +} + +void +BasicFile::StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun) +{ + StreamByteRange(0, FileSize(), std::move(ChunkFun)); +} + +void +BasicFile::StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun) +{ + const uint64_t ChunkSize = 128 * 1024; + IoBuffer ReadBuffer{ChunkSize}; + void* BufferPtr = ReadBuffer.MutableData(); + + uint64_t RemainBytes = Size; + uint64_t CurrentOffset = FileOffset; + + while (RemainBytes) + { + const uint64_t ThisChunkBytes = zen::Min(ChunkSize, RemainBytes); + + Read(BufferPtr, ThisChunkBytes, CurrentOffset); + + ChunkFun(BufferPtr, ThisChunkBytes); + + CurrentOffset += ThisChunkBytes; + RemainBytes -= ThisChunkBytes; + } +} + +void +BasicFile::Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec) +{ + Write(Data.GetData(), Data.GetSize(), FileOffset, Ec); +} + +void +BasicFile::Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec) +{ + Ec.clear(); + + const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024; + + while (Size) + { + const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize); + +#if ZEN_PLATFORM_WINDOWS + OVERLAPPED Ovl{}; + + Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu); + Ovl.OffsetHigh = DWORD(FileOffset >> 32); + + DWORD dwNumberOfBytesWritten = 0; + + BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl); +#else + static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files"); + int Fd = int(uintptr_t(m_FileHandle)); + int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, FileOffset); + bool Success = (BytesWritten > 0); +#endif + + if (!Success) + { + Ec = MakeErrorCodeFromLastError(); + + return; + } + + Size -= NumberOfBytesToWrite; + FileOffset += NumberOfBytesToWrite; + Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesToWrite; + } +} + +void +BasicFile::Write(MemoryView Data, uint64_t FileOffset) +{ + Write(Data.GetData(), Data.GetSize(), FileOffset); +} + +void +BasicFile::Write(const void* Data, uint64_t Size, uint64_t Offset) +{ + std::error_code Ec; + Write(Data, Size, Offset, Ec); + + if (Ec) + { + throw std::system_error(Ec, fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_FileHandle))); + } +} + +void +BasicFile::WriteAll(IoBuffer Data, std::error_code& Ec) +{ + Write(Data.Data(), Data.Size(), 0, Ec); +} + +void +BasicFile::Flush() +{ +#if ZEN_PLATFORM_WINDOWS + FlushFileBuffers(m_FileHandle); +#else + int Fd = int(uintptr_t(m_FileHandle)); + fsync(Fd); +#endif +} + +uint64_t +BasicFile::FileSize() +{ +#if ZEN_PLATFORM_WINDOWS + ULARGE_INTEGER liFileSize; + liFileSize.LowPart = ::GetFileSize(m_FileHandle, &liFileSize.HighPart); + if (liFileSize.LowPart == INVALID_FILE_SIZE) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to get file size from file '{}'", PathFromHandle(m_FileHandle))); + } + } + return uint64_t(liFileSize.QuadPart); +#else + int Fd = int(uintptr_t(m_FileHandle)); + static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files"); + struct stat Stat; + fstat(Fd, &Stat); + return uint64_t(Stat.st_size); +#endif +} + +void +BasicFile::SetFileSize(uint64_t FileSize) +{ +#if ZEN_PLATFORM_WINDOWS + LARGE_INTEGER liFileSize; + liFileSize.QuadPart = FileSize; + BOOL OK = ::SetFilePointerEx(m_FileHandle, liFileSize, 0, FILE_BEGIN); + if (OK == FALSE) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set file pointer to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } + OK = ::SetEndOfFile(m_FileHandle); + if (OK == FALSE) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set end of file to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } +#elif ZEN_PLATFORM_MAC + int Fd = int(intptr_t(m_FileHandle)); + if (ftruncate(Fd, (off_t)FileSize) < 0) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } +#else + int Fd = int(intptr_t(m_FileHandle)); + if (ftruncate64(Fd, (off64_t)FileSize) < 0) + { + int Error = zen::GetLastError(); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } + if (FileSize > 0) + { + int Error = posix_fallocate64(Fd, 0, (off64_t)FileSize); + if (Error) + { + ThrowSystemError(Error, fmt::format("Failed to allocate space of {} for file {}", FileSize, PathFromHandle(m_FileHandle))); + } + } +#endif +} + +void* +BasicFile::Detach() +{ + void* FileHandle = m_FileHandle; + m_FileHandle = 0; + return FileHandle; +} + +////////////////////////////////////////////////////////////////////////// + +TemporaryFile::~TemporaryFile() +{ + Close(); +} + +void +TemporaryFile::Close() +{ + if (m_FileHandle) + { +#if ZEN_PLATFORM_WINDOWS + // Mark file for deletion when final handle is closed + + FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE}; + + SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi); +#else + std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle); + unlink(FilePath.c_str()); +#endif + + BasicFile::Close(); + } +} + +void +TemporaryFile::CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec) +{ + StringBuilder<64> TempName; + Oid::NewOid().ToString(TempName); + + m_TempPath = TempDirName / TempName.c_str(); + + Open(m_TempPath, BasicFile::Mode::kTruncateDelete, Ec); +} + +void +TemporaryFile::MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec) +{ + // We intentionally call the base class Close() since otherwise we'll end up + // deleting the temporary file + BasicFile::Close(); + + std::filesystem::rename(m_TempPath, FinalFileName, Ec); +} + +////////////////////////////////////////////////////////////////////////// + +LockFile::LockFile() +{ +} + +LockFile::~LockFile() +{ +#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = int(intptr_t(m_FileHandle)); + flock(Fd, LOCK_UN | LOCK_NB); +#endif +} + +void +LockFile::Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec) +{ +#if ZEN_PLATFORM_WINDOWS + Ec.clear(); + + const DWORD dwCreationDisposition = CREATE_ALWAYS; + DWORD dwDesiredAccess = GENERIC_READ | GENERIC_WRITE | DELETE; + const DWORD dwShareMode = FILE_SHARE_READ; + const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE; + HANDLE hTemplateFile = nullptr; + + HANDLE FileHandle = CreateFile(FileName.c_str(), + dwDesiredAccess, + dwShareMode, + /* lpSecurityAttributes */ nullptr, + dwCreationDisposition, + dwFlagsAndAttributes, + hTemplateFile); + + if (FileHandle == INVALID_HANDLE_VALUE) + { + Ec = zen::MakeErrorCodeFromLastError(); + + return; + } +#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC + int Fd = open(FileName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Fd < 0) + { + Ec = zen::MakeErrorCodeFromLastError(); + return; + } + fchmod(Fd, 0666); + + int LockRet = flock(Fd, LOCK_EX | LOCK_NB); + if (LockRet < 0) + { + Ec = zen::MakeErrorCodeFromLastError(); + close(Fd); + return; + } + + void* FileHandle = (void*)uintptr_t(Fd); +#endif + + m_FileHandle = FileHandle; + + BasicFile::Write(Payload.GetBuffer(), 0, Ec); +} + +void +LockFile::Update(CbObject Payload, std::error_code& Ec) +{ + BasicFile::Write(Payload.GetBuffer(), 0, Ec); +} + +/* + ___________ __ + \__ ___/___ _______/ |_ ______ + | |_/ __ \ / ___/\ __\/ ___/ + | |\ ___/ \___ \ | | \___ \ + |____| \___ >____ > |__| /____ > + \/ \/ \/ +*/ + +#if ZEN_WITH_TESTS + +TEST_CASE("BasicFile") +{ + ScopedCurrentDirectoryChange _; + + BasicFile File1; + CHECK_THROWS(File1.Open("zonk", BasicFile::Mode::kRead)); + CHECK_NOTHROW(File1.Open("zonk", BasicFile::Mode::kTruncate)); + CHECK_NOTHROW(File1.Write("abcd", 4, 0)); + CHECK(File1.FileSize() == 4); + { + IoBuffer Data = File1.ReadAll(); + CHECK(Data.Size() == 4); + CHECK_EQ(memcmp(Data.Data(), "abcd", 4), 0); + } + CHECK_NOTHROW(File1.Write("efgh", 4, 2)); + CHECK(File1.FileSize() == 6); + { + IoBuffer Data = File1.ReadAll(); + CHECK(Data.Size() == 6); + CHECK_EQ(memcmp(Data.Data(), "abefgh", 6), 0); + } +} + +TEST_CASE("TemporaryFile") +{ + ScopedCurrentDirectoryChange _; + + SUBCASE("DeleteOnClose") + { + TemporaryFile TmpFile; + std::error_code Ec; + TmpFile.CreateTemporary(std::filesystem::current_path(), Ec); + CHECK(!Ec); + CHECK(std::filesystem::exists(TmpFile.GetPath())); + TmpFile.Close(); + CHECK(std::filesystem::exists(TmpFile.GetPath()) == false); + } + + SUBCASE("MoveIntoPlace") + { + TemporaryFile TmpFile; + std::error_code Ec; + TmpFile.CreateTemporary(std::filesystem::current_path(), Ec); + CHECK(!Ec); + std::filesystem::path TempPath = TmpFile.GetPath(); + std::filesystem::path FinalPath = std::filesystem::current_path() / "final"; + CHECK(std::filesystem::exists(TempPath)); + TmpFile.MoveTemporaryIntoPlace(FinalPath, Ec); + CHECK(!Ec); + CHECK(std::filesystem::exists(TempPath) == false); + CHECK(std::filesystem::exists(FinalPath)); + } +} + +void +basicfile_forcelink() +{ +} + +#endif + +} // namespace zen diff --git a/src/zenutil/cache/cachekey.cpp b/src/zenutil/cache/cachekey.cpp new file mode 100644 index 000000000..545b47f11 --- /dev/null +++ b/src/zenutil/cache/cachekey.cpp @@ -0,0 +1,9 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cachekey.h> + +namespace zen { + +const CacheKey CacheKey::Empty = CacheKey{.Bucket = std::string(), .Hash = IoHash()}; + +} // namespace zen diff --git a/src/zenutil/cache/cachepolicy.cpp b/src/zenutil/cache/cachepolicy.cpp new file mode 100644 index 000000000..3bca363bb --- /dev/null +++ b/src/zenutil/cache/cachepolicy.cpp @@ -0,0 +1,282 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cachepolicy.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/enumflags.h> +#include <zencore/string.h> + +#include <algorithm> +#include <unordered_map> + +namespace zen::Private { +class CacheRecordPolicyShared; +} + +namespace zen { + +using namespace std::literals; + +namespace DerivedData::Private { + + constinit char CachePolicyDelimiter = ','; + + struct CachePolicyToTextData + { + CachePolicy Policy; + std::string_view Text; + }; + + constinit CachePolicyToTextData CachePolicyToText[]{ + // Flags with multiple bits are ordered by bit count to minimize token count in the text format. + {CachePolicy::Default, "Default"sv}, + {CachePolicy::Remote, "Remote"sv}, + {CachePolicy::Local, "Local"sv}, + {CachePolicy::Store, "Store"sv}, + {CachePolicy::Query, "Query"sv}, + // Flags with only one bit can be in any order. Match the order in CachePolicy. + {CachePolicy::QueryLocal, "QueryLocal"sv}, + {CachePolicy::QueryRemote, "QueryRemote"sv}, + {CachePolicy::StoreLocal, "StoreLocal"sv}, + {CachePolicy::StoreRemote, "StoreRemote"sv}, + {CachePolicy::SkipMeta, "SkipMeta"sv}, + {CachePolicy::SkipData, "SkipData"sv}, + {CachePolicy::PartialRecord, "PartialRecord"sv}, + {CachePolicy::KeepAlive, "KeepAlive"sv}, + // None must be last because it matches every policy. + {CachePolicy::None, "None"sv}, + }; + + constinit CachePolicy CachePolicyKnownFlags = + CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::PartialRecord | CachePolicy::KeepAlive; + + StringBuilderBase& CachePolicyToString(StringBuilderBase& Builder, CachePolicy Policy) + { + // Mask out unknown flags. None will be written if no flags are known. + Policy &= CachePolicyKnownFlags; + for (const CachePolicyToTextData& Pair : CachePolicyToText) + { + if (EnumHasAllFlags(Policy, Pair.Policy)) + { + EnumRemoveFlags(Policy, Pair.Policy); + Builder << Pair.Text << CachePolicyDelimiter; + if (Policy == CachePolicy::None) + { + break; + } + } + } + Builder.RemoveSuffix(1); + return Builder; + } + + CachePolicy ParseCachePolicy(const std::string_view Text) + { + ZEN_ASSERT(!Text.empty()); // ParseCachePolicy requires a non-empty string + CachePolicy Policy = CachePolicy::None; + ForEachStrTok(Text, CachePolicyDelimiter, [&Policy, Index = int32_t(0)](const std::string_view& Token) mutable { + const int32_t EndIndex = Index; + for (; size_t(Index) < sizeof(CachePolicyToText) / sizeof(CachePolicyToText[0]); ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + for (Index = 0; Index < EndIndex; ++Index) + { + if (CachePolicyToText[Index].Text == Token) + { + Policy |= CachePolicyToText[Index].Policy; + ++Index; + return true; + } + } + return true; + }); + return Policy; + } + +} // namespace DerivedData::Private + +StringBuilderBase& +operator<<(StringBuilderBase& Builder, CachePolicy Policy) +{ + return DerivedData::Private::CachePolicyToString(Builder, Policy); +} + +CachePolicy +ParseCachePolicy(std::string_view Text) +{ + return DerivedData::Private::ParseCachePolicy(Text); +} + +CachePolicy +ConvertToUpstream(CachePolicy Policy) +{ + // Set Local flags equal to downstream's Remote flags. + // Delete Skip flags if StoreLocal is true, otherwise use the downstream value. + // Use the downstream value for all other flags. + + CachePolicy UpstreamPolicy = CachePolicy::None; + + if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote)) + { + UpstreamPolicy |= CachePolicy::QueryLocal; + } + + if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote)) + { + UpstreamPolicy |= CachePolicy::StoreLocal; + } + + if (!EnumHasAllFlags(Policy, CachePolicy::StoreLocal)) + { + UpstreamPolicy |= (Policy & (CachePolicy::SkipData | CachePolicy::SkipMeta)); + } + + UpstreamPolicy |= Policy & ~(CachePolicy::Local | CachePolicy::SkipData | CachePolicy::SkipMeta); + + return UpstreamPolicy; +} + +class Private::CacheRecordPolicyShared final : public Private::ICacheRecordPolicyShared +{ +public: + inline void AddValuePolicy(const CacheValuePolicy& Value) final + { + ZEN_ASSERT(Value.Id); // Failed to add value policy because the ID is null. + const auto Insert = + std::lower_bound(Values.begin(), Values.end(), Value, [](const CacheValuePolicy& Existing, const CacheValuePolicy& New) { + return Existing.Id < New.Id; + }); + ZEN_ASSERT( + !(Insert < Values.end() && + Insert->Id == Value.Id)); // Failed to add value policy with ID %s because it has an existing value policy with that ID. ") + Values.insert(Insert, Value); + } + + inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; } + +private: + std::vector<CacheValuePolicy> Values; +}; + +CachePolicy +CacheRecordPolicy::GetValuePolicy(const Oid& Id) const +{ + if (Shared) + { + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + const auto Iter = + std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; }); + if (Iter != Values.end() && Iter->Id == Id) + { + return Iter->Policy; + } + } + return DefaultValuePolicy; +} + +void +CacheRecordPolicy::Save(CbWriter& Writer) const +{ + Writer.BeginObject(); + // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately. + Writer.AddString("BasePolicy"sv, WriteToString<128>(GetBasePolicy())); + if (!IsUniform()) + { + Writer.BeginArray("ValuePolicies"sv); + for (const CacheValuePolicy& Value : GetValuePolicies()) + { + Writer.BeginObject(); + Writer.AddObjectId("Id"sv, Value.Id); + Writer.AddString("Policy"sv, WriteToString<128>(Value.Policy)); + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); +} + +OptionalCacheRecordPolicy +CacheRecordPolicy::Load(const CbObjectView Object) +{ + std::string_view BasePolicyText = Object["BasePolicy"sv].AsString(); + if (BasePolicyText.empty()) + { + return {}; + } + + CacheRecordPolicyBuilder Builder(ParseCachePolicy(BasePolicyText)); + for (CbFieldView ValueField : Object["ValuePolicies"sv]) + { + const CbObjectView Value = ValueField.AsObjectView(); + const Oid Id = Value["Id"sv].AsObjectId(); + const std::string_view PolicyText = Value["Policy"sv].AsString(); + if (!Id || PolicyText.empty()) + { + return {}; + } + CachePolicy Policy = ParseCachePolicy(PolicyText); + if (EnumHasAnyFlags(Policy, ~CacheValuePolicy::PolicyMask)) + { + return {}; + } + Builder.AddValuePolicy(Id, Policy); + } + + return Builder.Build(); +} + +CacheRecordPolicy +CacheRecordPolicy::ConvertToUpstream() const +{ + CacheRecordPolicyBuilder Builder(zen::ConvertToUpstream(GetBasePolicy())); + for (const CacheValuePolicy& ValuePolicy : GetValuePolicies()) + { + Builder.AddValuePolicy(ValuePolicy.Id, zen::ConvertToUpstream(ValuePolicy.Policy)); + } + return Builder.Build(); +} + +void +CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Value) +{ + ZEN_ASSERT(!EnumHasAnyFlags(Value.Policy, + ~Value.PolicyMask)); // Value policy contains flags that only make sense on the record policy. Policy: %s + if (Value.Policy == (BasePolicy & Value.PolicyMask)) + { + return; + } + if (!Shared) + { + Shared = new Private::CacheRecordPolicyShared; + } + Shared->AddValuePolicy(Value); +} + +CacheRecordPolicy +CacheRecordPolicyBuilder::Build() +{ + CacheRecordPolicy Policy(BasePolicy); + if (Shared) + { + const auto Add = [](const CachePolicy A, const CachePolicy B) { + return ((A | B) & ~CachePolicy::SkipData) | ((A & B) & CachePolicy::SkipData); + }; + const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies(); + Policy.RecordPolicy = BasePolicy; + for (const CacheValuePolicy& ValuePolicy : Values) + { + Policy.RecordPolicy = Add(Policy.RecordPolicy, ValuePolicy.Policy); + } + Policy.Shared = std::move(Shared); + } + return Policy; +} + +} // namespace zen diff --git a/src/zenutil/cache/cacherequests.cpp b/src/zenutil/cache/cacherequests.cpp new file mode 100644 index 000000000..4c865ec22 --- /dev/null +++ b/src/zenutil/cache/cacherequests.cpp @@ -0,0 +1,1643 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/cache/cacherequests.h> + +#include <zencore/compactbinary.h> +#include <zencore/compactbinarybuilder.h> +#include <zencore/compactbinarypackage.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/zencore.h> + +#include <string> +#include <string_view> + +#if ZEN_WITH_TESTS +# include <zencore/testing.h> +#endif + +namespace zen { + +namespace cacherequests { + + namespace { + constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"}; + + std::optional<std::string> GetValidNamespaceName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Namespace is invalid, empty namespace is not allowed"); + return {}; + } + + if (Name.length() > 64) + { + ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet)) + { + ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<std::string> GetValidBucketName(std::string_view Name) + { + if (Name.empty()) + { + ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed"); + return {}; + } + + if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet)) + { + ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name); + return {}; + } + + return ToLower(Name); + } + + std::optional<IoHash> GetValidIoHash(std::string_view Hash) + { + if (Hash.length() != IoHash::StringLength) + { + return {}; + } + + IoHash KeyHash; + if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash)) + { + return {}; + } + return KeyHash; + } + + std::optional<CacheRecordPolicy> Convert(const OptionalCacheRecordPolicy& Policy) + { + return Policy.IsValid() ? Policy.Get() : std::optional<CacheRecordPolicy>{}; + }; + } // namespace + + std::optional<std::string> GetRequestNamespace(const CbObjectView& Params) + { + CbFieldView NamespaceField = Params["Namespace"]; + if (!NamespaceField) + { + return std::string("!default!"); // ZenCacheStore::DefaultNamespace); + } + + if (NamespaceField.HasError()) + { + return {}; + } + if (!NamespaceField.IsString()) + { + return {}; + } + return GetValidNamespaceName(NamespaceField.AsString()); + } + + bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key) + { + CbFieldView BucketField = KeyView["Bucket"]; + if (BucketField.HasError()) + { + return false; + } + if (!BucketField.IsString()) + { + return false; + } + std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString()); + if (!Bucket.has_value()) + { + return false; + } + CbFieldView HashField = KeyView["Hash"]; + if (HashField.HasError()) + { + return false; + } + if (!HashField.IsHash()) + { + return false; + } + Key.Bucket = *Bucket; + Key.Hash = HashField.AsHash(); + return true; + } + + void WriteCacheRequestKey(CbObjectWriter& Writer, const CacheKey& Value) + { + Writer.BeginObject("Key"); + { + Writer << "Bucket" << Value.Bucket; + Writer << "Hash" << Value.Hash; + } + Writer.EndObject(); + } + + void WriteOptionalCacheRequestPolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CacheRecordPolicy>& Policy) + { + if (Policy) + { + Writer.SetName(FieldName); + Policy->Save(Writer); + } + } + + std::optional<CachePolicy> GetCachePolicy(CbObjectView ObjectView, std::string_view FieldName) + { + std::string_view DefaultPolicyText = ObjectView[FieldName].AsString(); + if (DefaultPolicyText.empty()) + { + return {}; + } + return ParseCachePolicy(DefaultPolicyText); + } + + void WriteCachePolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CachePolicy>& Policy) + { + if (Policy) + { + Writer << FieldName << WriteToString<128>(*Policy); + } + } + + bool PutCacheRecordsRequest::Parse(const CbPackage& Package) + { + CbObjectView BatchObject = Package.GetObject(); + ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheRecords"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestFieldArray = Params["Requests"].AsArrayView(); + Requests.resize(RequestFieldArray.Num()); + for (size_t RequestIndex = 0; CbFieldView RequestField : RequestFieldArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView RecordObject = RequestObject["Record"].AsObjectView(); + CbObjectView KeyView = RecordObject["Key"].AsObjectView(); + + PutCacheRecordRequest& Request = Requests[RequestIndex++]; + + if (!GetRequestCacheKey(KeyView, Request.Key)) + { + return false; + } + + Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView())); + + std::unordered_map<IoHash, size_t, IoHash::Hasher> RawHashToAttachmentIndex; + + CbArrayView ValuesArray = RecordObject["Values"].AsArrayView(); + Request.Values.resize(ValuesArray.Num()); + RawHashToAttachmentIndex.reserve(ValuesArray.Num()); + for (size_t Index = 0; CbFieldView Value : ValuesArray) + { + CbObjectView ObjectView = Value.AsObjectView(); + IoHash AttachmentHash = ObjectView["RawHash"].AsHash(); + RawHashToAttachmentIndex[AttachmentHash] = Index; + Request.Values[Index++] = {.Id = ObjectView["Id"].AsObjectId(), .RawHash = AttachmentHash}; + } + + RecordObject.IterateAttachments([&](CbFieldView HashView) { + const IoHash ValueHash = HashView.AsHash(); + if (const CbAttachment* Attachment = Package.FindAttachment(ValueHash)) + { + if (Attachment->IsCompressedBinary()) + { + auto It = RawHashToAttachmentIndex.find(ValueHash); + ZEN_ASSERT(It != RawHashToAttachmentIndex.end()); + PutCacheRecordRequestValue& Value = Request.Values[It->second]; + ZEN_ASSERT(Value.RawHash == ValueHash); + Value.Body = Attachment->AsCompressedBinary(); + ZEN_ASSERT_SLOW(Value.Body.DecodeRawHash() == Value.RawHash); + } + } + }); + } + + return true; + } + + bool PutCacheRecordsRequest::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + Writer << "Method" + << "PutCacheRecords"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("Requests"); + for (const PutCacheRecordRequest& RecordRequest : Requests) + { + Writer.BeginObject(); + { + Writer.BeginObject("Record"); + { + WriteCacheRequestKey(Writer, RecordRequest.Key); + Writer.BeginArray("Values"); + for (const PutCacheRecordRequestValue& Value : RecordRequest.Values) + { + Writer.BeginObject(); + { + Writer.AddObjectId("Id", Value.Id); + const CompressedBuffer& Buffer = Value.Body; + if (Buffer) + { + IoHash AttachmentHash = Buffer.DecodeRawHash(); // TODO: Slow! + Writer.AddBinaryAttachment("RawHash", AttachmentHash); + OutPackage.AddAttachment(CbAttachment(Buffer, AttachmentHash)); + Writer.AddInteger("RawSize", Buffer.DecodeRawSize()); // TODO: Slow! + } + else + { + if (Value.RawHash == IoHash::Zero) + { + return false; + } + Writer.AddBinaryAttachment("RawHash", Value.RawHash); + } + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy); + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + OutPackage.SetObject(Writer.Save()); + + return true; + } + + bool PutCacheRecordsResult::Parse(const CbPackage& Package) + { + CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + CbFieldViewIterator It = ResultsArray.CreateViewIterator(); + while (It.HasValue()) + { + Success.push_back(It.AsBool()); + It++; + } + return true; + } + + bool PutCacheRecordsResult::Format(CbPackage& OutPackage) const + { + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"); + for (bool Value : Success) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + OutPackage.SetObject(ResponseObject.Save()); + return true; + } + + bool GetCacheRecordsRequest::Parse(const CbObjectView& RpcRequest) + { + ZEN_ASSERT(RpcRequest["Method"].AsString() == "GetCacheRecords"); + AcceptMagic = RpcRequest["AcceptType"].AsUInt32(0); + AcceptOptions = RpcRequest["AcceptFlags"].AsUInt16(0); + ProcessPid = RpcRequest["Pid"].AsInt32(0); + + CbObjectView Params = RpcRequest["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["Requests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + GetCacheRecordRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView())); + } + return true; + } + + bool GetCacheRecordsRequest::Parse(const CbPackage& RpcRequest) { return Parse(RpcRequest.GetObject()); } + + bool GetCacheRecordsRequest::Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter) const + { + Writer << "Method" + << "GetCacheRecords"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + if (AcceptOptions != 0) + { + Writer << "AcceptFlags" << AcceptOptions; + } + if (ProcessPid != 0) + { + Writer << "Pid" << ProcessPid; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + Writer.BeginArray("Requests"); + if (OptionalRecordFilter.empty()) + { + for (const GetCacheRecordRequest& RecordRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, RecordRequest.Key); + WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy); + } + Writer.EndObject(); + } + } + else + { + for (size_t Index : OptionalRecordFilter) + { + const GetCacheRecordRequest& RecordRequest = Requests[Index]; + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, RecordRequest.Key); + WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy); + } + Writer.EndObject(); + } + } + Writer.EndArray(); + } + Writer.EndObject(); + + return true; + } + + bool GetCacheRecordsRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter) const + { + CbObjectWriter Writer; + if (!Format(Writer, OptionalRecordFilter)) + { + return false; + } + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool GetCacheRecordsResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes) + { + CbObject ResponseObject = Package.GetObject(); + CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + + Results.reserve(ResultsArray.Num()); + if (!OptionalRecordResultIndexes.empty() && ResultsArray.Num() != OptionalRecordResultIndexes.size()) + { + return false; + } + for (size_t Index = 0; CbFieldView RecordView : ResultsArray) + { + size_t ResultIndex = OptionalRecordResultIndexes.empty() ? Index : OptionalRecordResultIndexes[Index]; + Index++; + + if (Results.size() <= ResultIndex) + { + Results.resize(ResultIndex + 1); + } + if (RecordView.IsNull()) + { + continue; + } + Results[ResultIndex] = GetCacheRecordResult{}; + GetCacheRecordResult& Request = Results[ResultIndex].value(); + CbObjectView RecordObject = RecordView.AsObjectView(); + CbObjectView KeyObject = RecordObject["Key"].AsObjectView(); + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + CbArrayView ValuesArray = RecordObject["Values"].AsArrayView(); + Request.Values.reserve(ValuesArray.Num()); + for (CbFieldView Value : ValuesArray) + { + CbObjectView ValueObject = Value.AsObjectView(); + IoHash RawHash = ValueObject["RawHash"].AsHash(); + uint64_t RawSize = ValueObject["RawSize"].AsUInt64(); + Oid Id = ValueObject["Id"].AsObjectId(); + const CbAttachment* Attachment = Package.FindAttachment(RawHash); + if (!Attachment) + { + Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = {}}); + continue; + } + if (!Attachment->IsCompressedBinary()) + { + return false; + } + Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = Attachment->AsCompressedBinary()}); + } + } + return true; + } + + bool GetCacheRecordsResult::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + + Writer.BeginArray("Result"); + for (const std::optional<GetCacheRecordResult>& RecordResult : Results) + { + if (!RecordResult.has_value()) + { + Writer.AddNull(); + continue; + } + Writer.BeginObject(); + WriteCacheRequestKey(Writer, RecordResult->Key); + + Writer.BeginArray("Values"); + for (const GetCacheRecordResultValue& Value : RecordResult->Values) + { + IoHash AttachmentHash = Value.Body ? Value.Body.DecodeRawHash() : Value.RawHash; + Writer.BeginObject(); + { + Writer.AddObjectId("Id", Value.Id); + Writer.AddHash("RawHash", AttachmentHash); + Writer.AddInteger("RawSize", Value.Body ? Value.Body.DecodeRawSize() : Value.RawSize); + } + Writer.EndObject(); + if (Value.Body) + { + OutPackage.AddAttachment(CbAttachment(Value.Body, AttachmentHash)); + } + } + + Writer.EndArray(); + Writer.EndObject(); + } + Writer.EndArray(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool PutCacheValuesRequest::Parse(const CbPackage& Package) + { + CbObjectView BatchObject = Package.GetObject(); + ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheValues"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["Requests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + PutCacheValueRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.RawHash = RequestObject["RawHash"].AsBinaryAttachment(); + Request.Policy = GetCachePolicy(RequestObject, "Policy"); + + if (const CbAttachment* Attachment = Package.FindAttachment(Request.RawHash)) + { + if (!Attachment->IsCompressedBinary()) + { + return false; + } + Request.Body = Attachment->AsCompressedBinary(); + } + } + return true; + } + + bool PutCacheValuesRequest::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + Writer << "Method" + << "PutCacheValues"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("Requests"); + for (const PutCacheValueRequest& ValueRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + if (ValueRequest.Body) + { + IoHash AttachmentHash = ValueRequest.Body.DecodeRawHash(); + if (ValueRequest.RawHash != IoHash::Zero && AttachmentHash != ValueRequest.RawHash) + { + return false; + } + Writer.AddBinaryAttachment("RawHash", AttachmentHash); + OutPackage.AddAttachment(CbAttachment(ValueRequest.Body, AttachmentHash)); + } + else if (ValueRequest.RawHash != IoHash::Zero) + { + Writer.AddBinaryAttachment("RawHash", ValueRequest.RawHash); + } + else + { + return false; + } + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool PutCacheValuesResult::Parse(const CbPackage& Package) + { + CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + CbFieldViewIterator It = ResultsArray.CreateViewIterator(); + while (It.HasValue()) + { + Success.push_back(It.AsBool()); + It++; + } + return true; + } + + bool PutCacheValuesResult::Format(CbPackage& OutPackage) const + { + if (Success.empty()) + { + return false; + } + CbObjectWriter ResponseObject; + ResponseObject.BeginArray("Result"); + for (bool Value : Success) + { + ResponseObject.AddBool(Value); + } + ResponseObject.EndArray(); + + OutPackage.SetObject(ResponseObject.Save()); + return true; + } + + bool GetCacheValuesRequest::Parse(const CbObjectView& BatchObject) + { + ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheValues"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0); + ProcessPid = BatchObject["Pid"].AsInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["Requests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + GetCacheValueRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.Policy = GetCachePolicy(RequestObject, "Policy"); + } + return true; + } + + bool GetCacheValuesRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter) const + { + CbObjectWriter Writer; + Writer << "Method" + << "GetCacheValues"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + if (AcceptOptions != 0) + { + Writer << "AcceptFlags" << AcceptOptions; + } + if (ProcessPid != 0) + { + Writer << "Pid" << ProcessPid; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("Requests"); + if (OptionalValueFilter.empty()) + { + for (const GetCacheValueRequest& ValueRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + } + else + { + for (size_t Index : OptionalValueFilter) + { + const GetCacheValueRequest& ValueRequest = Requests[Index]; + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + } + Writer.EndArray(); + } + Writer.EndObject(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool CacheValuesResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes) + { + CbObject ResponseObject = Package.GetObject(); + CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView(); + if (!ResultsArray) + { + return false; + } + Results.reserve(ResultsArray.Num()); + if (!OptionalValueResultIndexes.empty() && ResultsArray.Num() != OptionalValueResultIndexes.size()) + { + return false; + } + for (size_t Index = 0; CbFieldView RecordView : ResultsArray) + { + size_t ResultIndex = OptionalValueResultIndexes.empty() ? Index : OptionalValueResultIndexes[Index]; + Index++; + + if (Results.size() <= ResultIndex) + { + Results.resize(ResultIndex + 1); + } + if (RecordView.IsNull()) + { + continue; + } + + CacheValueResult& ValueResult = Results[ResultIndex]; + CbObjectView RecordObject = RecordView.AsObjectView(); + + CbFieldView RawHashField = RecordObject["RawHash"]; + ValueResult.RawHash = RawHashField.AsHash(); + bool Succeeded = !RawHashField.HasError(); + if (Succeeded) + { + const CbAttachment* Attachment = Package.FindAttachment(ValueResult.RawHash); + ValueResult.Body = Attachment ? Attachment->AsCompressedBinary() : CompressedBuffer(); + if (ValueResult.Body) + { + ValueResult.RawSize = ValueResult.Body.DecodeRawSize(); + } + else + { + ValueResult.RawSize = RecordObject["RawSize"].AsUInt64(UINT64_MAX); + } + } + } + return true; + } + + bool CacheValuesResult::Format(CbPackage& OutPackage) const + { + CbObjectWriter ResponseObject; + + ResponseObject.BeginArray("Result"); + for (const CacheValueResult& ValueResult : Results) + { + ResponseObject.BeginObject(); + if (ValueResult.RawHash != IoHash::Zero) + { + ResponseObject.AddHash("RawHash", ValueResult.RawHash); + if (ValueResult.Body) + { + OutPackage.AddAttachment(CbAttachment(ValueResult.Body, ValueResult.RawHash)); + } + else + { + ResponseObject.AddInteger("RawSize", ValueResult.RawSize); + } + } + ResponseObject.EndObject(); + } + ResponseObject.EndArray(); + + OutPackage.SetObject(ResponseObject.Save()); + return true; + } + + bool GetCacheChunksRequest::Parse(const CbObjectView& BatchObject) + { + ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheChunks"); + AcceptMagic = BatchObject["AcceptType"].AsUInt32(0); + AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0); + ProcessPid = BatchObject["Pid"].AsInt32(0); + + CbObjectView Params = BatchObject["Params"].AsObjectView(); + std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params); + if (!RequestNamespace) + { + return false; + } + + Namespace = *RequestNamespace; + DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default); + + CbArrayView RequestsArray = Params["ChunkRequests"].AsArrayView(); + Requests.reserve(RequestsArray.Num()); + for (CbFieldView RequestField : RequestsArray) + { + CbObjectView RequestObject = RequestField.AsObjectView(); + CbObjectView KeyObject = RequestObject["Key"].AsObjectView(); + + GetCacheChunkRequest& Request = Requests.emplace_back(); + + if (!GetRequestCacheKey(KeyObject, Request.Key)) + { + return false; + } + + Request.ValueId = RequestObject["ValueId"].AsObjectId(); + Request.ChunkId = RequestObject["ChunkId"].AsHash(); + Request.RawOffset = RequestObject["RawOffset"].AsUInt64(); + Request.RawSize = RequestObject["RawSize"].AsUInt64(UINT64_MAX); + + Request.Policy = GetCachePolicy(RequestObject, "Policy"); + } + return true; + } + + bool GetCacheChunksRequest::Format(CbPackage& OutPackage) const + { + CbObjectWriter Writer; + Writer << "Method" + << "GetCacheChunks"; + if (AcceptMagic != 0) + { + Writer << "Accept" << AcceptMagic; + } + if (AcceptOptions != 0) + { + Writer << "AcceptFlags" << AcceptOptions; + } + if (ProcessPid != 0) + { + Writer << "Pid" << ProcessPid; + } + + Writer.BeginObject("Params"); + { + Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy); + Writer << "Namespace" << Namespace; + + Writer.BeginArray("ChunkRequests"); + for (const GetCacheChunkRequest& ValueRequest : Requests) + { + Writer.BeginObject(); + { + WriteCacheRequestKey(Writer, ValueRequest.Key); + + Writer.AddObjectId("ValueId", ValueRequest.ValueId); + Writer.AddHash("ChunkId", ValueRequest.ChunkId); + Writer.AddInteger("RawOffset", ValueRequest.RawOffset); + Writer.AddInteger("RawSize", ValueRequest.RawSize); + + WriteCachePolicy(Writer, "Policy", ValueRequest.Policy); + } + Writer.EndObject(); + } + Writer.EndArray(); + } + Writer.EndObject(); + + OutPackage.SetObject(Writer.Save()); + return true; + } + + bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data) + { + std::vector<std::string_view> Tokens; + uint32_t TokenCount = zen::ForEachStrTok(Key, '/', [&](const std::string_view& Token) { + Tokens.push_back(Token); + return true; + }); + + switch (TokenCount) + { + case 1: + Data.Namespace = GetValidNamespaceName(Tokens[0]); + return Data.Namespace.has_value(); + case 2: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + return true; + } + case 3: + { + std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]); + if (PossibleHashKey.has_value()) + { + // Legacy bucket/key/valueid request + Data.Bucket = GetValidBucketName(Tokens[0]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = PossibleHashKey; + Data.ValueContentId = GetValidIoHash(Tokens[2]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + return true; + } + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey) + { + return false; + } + return true; + } + case 4: + { + Data.Namespace = GetValidNamespaceName(Tokens[0]); + if (!Data.Namespace.has_value()) + { + return false; + } + + Data.Bucket = GetValidBucketName(Tokens[1]); + if (!Data.Bucket.has_value()) + { + return false; + } + + Data.HashKey = GetValidIoHash(Tokens[2]); + if (!Data.HashKey.has_value()) + { + return false; + } + + Data.ValueContentId = GetValidIoHash(Tokens[3]); + if (!Data.ValueContentId.has_value()) + { + return false; + } + return true; + } + default: + return false; + } + } + + // bool CacheRecord::Parse(CbObjectView& Reader) + // { + // CbObjectView KeyView = Reader["Key"].AsObjectView(); + // + // if (!GetRequestCacheKey(KeyView, Key)) + // { + // return false; + // } + // CbArrayView ValuesArray = Reader["Values"].AsArrayView(); + // Values.reserve(ValuesArray.Num()); + // for (CbFieldView Value : ValuesArray) + // { + // CbObjectView ObjectView = Value.AsObjectView(); + // Values.push_back({.Id = ObjectView["Id"].AsObjectId(), + // .RawHash = ObjectView["RawHash"].AsHash(), + // .RawSize = ObjectView["RawSize"].AsUInt64()}); + // } + // return true; + // } + // + // bool CacheRecord::Format(CbObjectWriter& Writer) const + // { + // WriteCacheRequestKey(Writer, Key); + // Writer.BeginArray("Values"); + // for (const CacheRecordValue& Value : Values) + // { + // Writer.BeginObject(); + // { + // Writer.AddObjectId("Id", Value.Id); + // Writer.AddHash("RawHash", Value.RawHash); + // Writer.AddInteger("RawSize", Value.RawSize); + // } + // Writer.EndObject(); + // } + // Writer.EndArray(); + // return true; + // } + +#if ZEN_WITH_TESTS + + static bool operator==(const PutCacheRecordRequestValue& Lhs, const PutCacheRecordRequestValue& Rhs) + { + const IoHash LhsRawHash = Lhs.RawHash != IoHash::Zero ? Lhs.RawHash : Lhs.Body.DecodeRawHash(); + const IoHash RhsRawHash = Rhs.RawHash != IoHash::Zero ? Rhs.RawHash : Rhs.Body.DecodeRawHash(); + return Lhs.Id == Rhs.Id && LhsRawHash == RhsRawHash && + Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView()); + } + + static bool operator==(const zen::CacheValuePolicy& Lhs, const zen::CacheValuePolicy& Rhs) + { + return (Lhs.Id == Rhs.Id) && (Lhs.Policy == Rhs.Policy); + } + + static bool operator==(const std::span<const zen::CacheValuePolicy>& Lhs, const std::span<const zen::CacheValuePolicy>& Rhs) + { + if (Lhs.size() != Lhs.size()) + { + return false; + } + for (size_t Idx = 0; Idx < Lhs.size(); ++Idx) + { + if (Lhs[Idx] != Rhs[Idx]) + { + return false; + } + } + return true; + } + + static bool operator==(const zen::CacheRecordPolicy& Lhs, const zen::CacheRecordPolicy& Rhs) + { + return (Lhs.GetRecordPolicy() == Rhs.GetRecordPolicy()) && (Lhs.GetBasePolicy() == Rhs.GetBasePolicy()) && + (Lhs.GetValuePolicies() == Rhs.GetValuePolicies()); + } + + static bool operator==(const std::optional<CacheRecordPolicy>& Lhs, const std::optional<CacheRecordPolicy>& Rhs) + { + return (Lhs.has_value() == Rhs.has_value()) && (!Lhs || (*Lhs == *Rhs)); + } + + static bool operator==(const PutCacheRecordRequest& Lhs, const PutCacheRecordRequest& Rhs) + { + return (Lhs.Key == Rhs.Key) && (Lhs.Values == Rhs.Values) && (Lhs.Policy == Rhs.Policy); + } + + static bool operator==(const PutCacheRecordsRequest& Lhs, const PutCacheRecordsRequest& Rhs) + { + return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests); + } + + static bool operator==(const PutCacheRecordsResult& Lhs, const PutCacheRecordsResult& Rhs) { return (Lhs.Success == Rhs.Success); } + + static bool operator==(const GetCacheRecordRequest& Lhs, const GetCacheRecordRequest& Rhs) + { + return (Lhs.Key == Rhs.Key) && (Lhs.Policy == Rhs.Policy); + } + + static bool operator==(const GetCacheRecordsRequest& Lhs, const GetCacheRecordsRequest& Rhs) + { + return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests); + } + + static bool operator==(const GetCacheRecordResultValue& Lhs, const GetCacheRecordResultValue& Rhs) + { + if ((Lhs.Id != Rhs.Id) || (Lhs.RawHash != Rhs.RawHash) || (Lhs.RawSize != Rhs.RawSize)) + { + return false; + } + if (bool(Lhs.Body) != bool(Rhs.Body)) + { + return false; + } + if (bool(Lhs.Body) && Lhs.Body.DecodeRawHash() != Rhs.Body.DecodeRawHash()) + { + return false; + } + return true; + } + + static bool operator==(const GetCacheRecordResult& Lhs, const GetCacheRecordResult& Rhs) + { + return Lhs.Key == Rhs.Key && Lhs.Values == Rhs.Values; + } + + static bool operator==(const std::optional<GetCacheRecordResult>& Lhs, const std::optional<GetCacheRecordResult>& Rhs) + { + if (Lhs.has_value() != Rhs.has_value()) + { + return false; + } + return *Lhs == Rhs; + } + + static bool operator==(const GetCacheRecordsResult& Lhs, const GetCacheRecordsResult& Rhs) { return Lhs.Results == Rhs.Results; } + + static bool operator==(const PutCacheValueRequest& Lhs, const PutCacheValueRequest& Rhs) + { + if ((Lhs.Key != Rhs.Key) || (Lhs.RawHash != Rhs.RawHash)) + { + return false; + } + + if (bool(Lhs.Body) != bool(Rhs.Body)) + { + return false; + } + if (!Lhs.Body) + { + return true; + } + return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView()); + } + + static bool operator==(const PutCacheValuesRequest& Lhs, const PutCacheValuesRequest& Rhs) + { + return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests); + } + + static bool operator==(const PutCacheValuesResult& Lhs, const PutCacheValuesResult& Rhs) { return (Lhs.Success == Rhs.Success); } + + static bool operator==(const GetCacheValueRequest& Lhs, const GetCacheValueRequest& Rhs) + { + return Lhs.Key == Rhs.Key && Lhs.Policy == Rhs.Policy; + } + + static bool operator==(const GetCacheValuesRequest& Lhs, const GetCacheValuesRequest& Rhs) + { + return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests; + } + + static bool operator==(const CacheValueResult& Lhs, const CacheValueResult& Rhs) + { + if (Lhs.RawHash != Rhs.RawHash) + { + return false; + }; + if (Lhs.Body) + { + if (!Rhs.Body) + { + return false; + } + return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView()); + } + return Lhs.RawSize == Rhs.RawSize; + } + + static bool operator==(const CacheValuesResult& Lhs, const CacheValuesResult& Rhs) { return Lhs.Results == Rhs.Results; } + + static bool operator==(const GetCacheChunkRequest& Lhs, const GetCacheChunkRequest& Rhs) + { + return Lhs.Key == Rhs.Key && Lhs.ValueId == Rhs.ValueId && Lhs.ChunkId == Rhs.ChunkId && Lhs.RawOffset == Rhs.RawOffset && + Lhs.RawSize == Rhs.RawSize && Lhs.Policy == Rhs.Policy; + } + + static bool operator==(const GetCacheChunksRequest& Lhs, const GetCacheChunksRequest& Rhs) + { + return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests; + } + + static CompressedBuffer MakeCompressedBuffer(size_t Size) { return CompressedBuffer::Compress(SharedBuffer(IoBuffer(Size))); }; + + TEST_CASE("cacherequests.put.cache.records") + { + PutCacheRecordsRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + PutCacheRecordsRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required + + PutCacheRecordsRequest FullRequest = { + .DefaultPolicy = CachePolicy::Remote, + .Namespace = "the_namespace", + .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}}, + .Policy = CachePolicy::StoreLocal}, + {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}}, + .Policy = CachePolicy::Store}, + {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + PutCacheRecordsRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage)); + CHECK(FullRequest == FullRequestCopy); + + PutCacheRecordsResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + PutCacheRecordsResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + PutCacheRecordsResult FullResult = {.Success = {true, false, true, true, false}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + PutCacheRecordsResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + } + + TEST_CASE("cacherequests.get.cache.records") + { + GetCacheRecordsRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + GetCacheRecordsRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required + + GetCacheRecordsRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Policy = CachePolicy::Remote}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + GetCacheRecordsRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage)); + CHECK(FullRequest == FullRequestCopy); + + CbPackage PartialRequestPackage; + CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2})); + GetCacheRecordsRequest PartialRequest = FullRequest; + PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1); + GetCacheRecordsRequest PartialRequestCopy; + CHECK(PartialRequestCopy.Parse(PartialRequestPackage)); + CHECK(PartialRequest == PartialRequestCopy); + + GetCacheRecordsResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + GetCacheRecordsResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + PutCacheRecordsRequest FullPutRequest = { + .DefaultPolicy = CachePolicy::Remote, + .Namespace = "the_namespace", + .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}}, + .Policy = CachePolicy::StoreLocal}, + {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}}, + .Policy = CachePolicy::Store}, + {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)}, + {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}}; + + CbPackage FullPutRequestPackage; + CHECK(FullPutRequest.Format(FullPutRequestPackage)); + PutCacheRecordsRequest FullPutRequestCopy; + CHECK(FullPutRequestCopy.Parse(FullPutRequestPackage)); + + GetCacheRecordsResult FullResult = { + {GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[0].Key, + .Values = {{.Id = FullPutRequestCopy.Requests[0].Values[0].Id, + .RawHash = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[0].Values[0].Body}, + {.Id = FullPutRequestCopy.Requests[0].Values[1].Id, + + .RawHash = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[0].Values[1].Body}, + {.Id = FullPutRequestCopy.Requests[0].Values[2].Id, + .RawHash = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[0].Values[2].Body}}}, + {}, // Simulate not have! + GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[2].Key, + .Values = {{.Id = FullPutRequestCopy.Requests[2].Values[0].Id, + .RawHash = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[2].Values[0].Body}, + {.Id = FullPutRequestCopy.Requests[2].Values[1].Id, + .RawHash = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawSize(), + .Body = {}}, // Simulate not have + {.Id = FullPutRequestCopy.Requests[2].Values[2].Id, + .RawHash = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawHash(), + .RawSize = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawSize(), + .Body = FullPutRequestCopy.Requests[2].Values[2].Body}}}}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + GetCacheRecordsResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult.Results[0] == FullResultCopy.Results[0]); + CHECK(!FullResultCopy.Results[1]); + CHECK(FullResult.Results[2] == FullResultCopy.Results[2]); + + GetCacheRecordsResult PartialResultCopy; + CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4})); + CHECK(FullResult.Results[0] == PartialResultCopy.Results[0]); + CHECK(!PartialResultCopy.Results[1]); + CHECK(!PartialResultCopy.Results[2]); + CHECK(!PartialResultCopy.Results[3]); + CHECK(FullResult.Results[2] == PartialResultCopy.Results[4]); + } + + TEST_CASE("cacherequests.put.cache.values") + { + PutCacheValuesRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + PutCacheValuesRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required + + CompressedBuffer Buffers[3] = {MakeCompressedBuffer(969), MakeCompressedBuffer(3469), MakeCompressedBuffer(9)}; + PutCacheValuesRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .RawHash = Buffers[0].DecodeRawHash(), + .Body = Buffers[0], + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .RawHash = Buffers[1].DecodeRawHash(), + .Body = Buffers[1], + .Policy = CachePolicy::Remote}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .RawHash = Buffers[2].DecodeRawHash()}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + PutCacheValuesRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage)); + CHECK(FullRequest == FullRequestCopy); + + PutCacheValuesResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(!EmptyResult.Format(EmptyResponsePackage)); + + PutCacheValuesResult FullResult = {.Success = {true, false, true}}; + + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + PutCacheValuesResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + } + + TEST_CASE("cacherequests.get.cache.values") + { + GetCacheValuesRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + GetCacheValuesRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required + + GetCacheValuesRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .Policy = CachePolicy::Remote}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + GetCacheValuesRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject())); + CHECK(FullRequest == FullRequestCopy); + + CbPackage PartialRequestPackage; + CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2})); + GetCacheValuesRequest PartialRequest = FullRequest; + PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1); + GetCacheValuesRequest PartialRequestCopy; + CHECK(PartialRequestCopy.Parse(PartialRequestPackage.GetObject())); + CHECK(PartialRequest == PartialRequestCopy); + + CacheValuesResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + CacheValuesResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)}, + {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)}, + {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}}; + CacheValuesResult FullResult = { + .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]}, + CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + CacheValuesResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + + CacheValuesResult PartialResultCopy; + CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4, 5, 6, 9})); + CHECK(PartialResultCopy.Results[0] == FullResult.Results[0]); + CHECK(PartialResultCopy.Results[1].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[2].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[3] == FullResult.Results[1]); + CHECK(PartialResultCopy.Results[4] == FullResult.Results[2]); + CHECK(PartialResultCopy.Results[5] == FullResult.Results[3]); + CHECK(PartialResultCopy.Results[6] == FullResult.Results[4]); + CHECK(PartialResultCopy.Results[7].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[8].RawHash == IoHash::Zero); + CHECK(PartialResultCopy.Results[9] == FullResult.Results[5]); + } + + TEST_CASE("cacherequests.get.cache.chunks") + { + GetCacheChunksRequest EmptyRequest; + CbPackage EmptyRequestPackage; + CHECK(EmptyRequest.Format(EmptyRequestPackage)); + GetCacheChunksRequest EmptyRequestCopy; + CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required + + GetCacheChunksRequest FullRequest = { + .DefaultPolicy = CachePolicy::StoreLocal, + .Namespace = "other_namespace", + .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")}, + .ValueId = Oid::NewOid(), + .ChunkId = IoHash::FromHexString("ab3917854bfef7e7af2c372d795bb907a15cab15"), + .RawOffset = 77, + .RawSize = 33, + .Policy = CachePolicy::Local}, + {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")}, + .ValueId = Oid::NewOid(), + .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"), + .Policy = CachePolicy::Remote}, + { + .Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}, + .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"), + }}}; + + CbPackage FullRequestPackage; + CHECK(FullRequest.Format(FullRequestPackage)); + GetCacheChunksRequest FullRequestCopy; + CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject())); + CHECK(FullRequest == FullRequestCopy); + + GetCacheChunksResult EmptyResult; + CbPackage EmptyResponsePackage; + CHECK(EmptyResult.Format(EmptyResponsePackage)); + GetCacheChunksResult EmptyResultCopy; + CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage)); + CHECK(EmptyResult == EmptyResultCopy); + + CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)}, + {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)}, + {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}}; + GetCacheChunksResult FullResult = { + .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]}, + CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]}, + CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}}; + CbPackage FullResponsePackage; + CHECK(FullResult.Format(FullResponsePackage)); + GetCacheChunksResult FullResultCopy; + CHECK(FullResultCopy.Parse(FullResponsePackage)); + CHECK(FullResult == FullResultCopy); + } + + TEST_CASE("z$service.parse.relative.Uri") + { + HttpRequestData LegacyBucketRequestBecomesNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest)); + CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test"); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value()); + CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData LegacyHashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest)); + CHECK(!LegacyHashKeyRequest.Namespace); + CHECK(LegacyHashKeyRequest.Bucket == "test"); + CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(!LegacyHashKeyRequest.ValueContentId.has_value()); + + HttpRequestData LegacyValueContentIdRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + LegacyValueContentIdRequest)); + CHECK(!LegacyValueContentIdRequest.Namespace); + CHECK(LegacyValueContentIdRequest.Bucket == "test"); + CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789")); + + HttpRequestData V2DefaultNamespaceRequest; + CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest)); + CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc"); + CHECK(!V2DefaultNamespaceRequest.Bucket.has_value()); + CHECK(!V2DefaultNamespaceRequest.HashKey.has_value()); + CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2NamespaceRequest; + CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest)); + CHECK(V2NamespaceRequest.Namespace == "nicenamespace"); + CHECK(!V2NamespaceRequest.Bucket.has_value()); + CHECK(!V2NamespaceRequest.HashKey.has_value()); + CHECK(!V2NamespaceRequest.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithDefaultNamespace; + CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace)); + CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc"); + CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test"); + CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value()); + + HttpRequestData V2BucketRequestWithNamespace; + CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace)); + CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace"); + CHECK(V2BucketRequestWithNamespace.Bucket == "test"); + CHECK(!V2BucketRequestWithNamespace.HashKey.has_value()); + CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value()); + + HttpRequestData V2HashKeyRequest; + CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest)); + CHECK(!V2HashKeyRequest.Namespace); + CHECK(V2HashKeyRequest.Bucket == "test"); + CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(!V2HashKeyRequest.ValueContentId.has_value()); + + HttpRequestData V2ValueContentIdRequest; + CHECK(HttpRequestParseRelativeUri( + "nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789", + V2ValueContentIdRequest)); + CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace"); + CHECK(V2ValueContentIdRequest.Bucket == "test"); + CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234")); + CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789")); + + HttpRequestData Invalid; + CHECK(!HttpRequestParseRelativeUri("", Invalid)); + CHECK(!HttpRequestParseRelativeUri("/", Invalid)); + CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid)); + CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid)); + CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid)); + CHECK(!HttpRequestParseRelativeUri( + "namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789", + Invalid)); + } +#endif +} // namespace cacherequests + +void +cacherequests_forcelink() +{ +} + +} // namespace zen diff --git a/src/zenutil/cache/rpcrecording.cpp b/src/zenutil/cache/rpcrecording.cpp new file mode 100644 index 000000000..4958a27f6 --- /dev/null +++ b/src/zenutil/cache/rpcrecording.cpp @@ -0,0 +1,210 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenutil/basicfile.h> +#include <zenutil/cache/rpcrecording.h> + +ZEN_THIRD_PARTY_INCLUDES_START +#include <fmt/format.h> +#include <gsl/gsl-lite.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +namespace zen::cache { +struct RecordedRequest +{ + uint64_t Offset; + uint64_t Length; + ZenContentType ContentType; + ZenContentType AcceptType; +}; + +const uint64_t RecordedRequestBlockSize = 1ull << 31u; + +struct RecordedRequestsWriter +{ + void BeginWrite(const std::filesystem::path& BasePath) + { + m_BasePath = BasePath; + std::filesystem::create_directories(m_BasePath); + } + + void EndWrite() + { + RwLock::ExclusiveLockScope _(m_Lock); + m_BlockFiles.clear(); + + IoBuffer IndexBuffer(IoBuffer::Wrap, m_Entries.data(), m_Entries.size() * sizeof(RecordedRequest)); + BasicFile IndexFile; + IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kTruncate); + std::error_code Ec; + IndexFile.WriteAll(IndexBuffer, Ec); + IndexFile.Close(); + m_Entries.clear(); + } + + uint64_t WriteRequest(ZenContentType ContentType, ZenContentType AcceptType, const IoBuffer& RequestBuffer) + { + RwLock::ExclusiveLockScope Lock(m_Lock); + uint64_t RequestIndex = m_Entries.size(); + RecordedRequest& Entry = m_Entries.emplace_back( + RecordedRequest{.Offset = ~0ull, .Length = RequestBuffer.Size(), .ContentType = ContentType, .AcceptType = AcceptType}); + if (Entry.Length < 1 * 1024 * 1024) + { + uint32_t BlockIndex = gsl::narrow<uint32_t>((m_ChunkOffset + Entry.Length) / RecordedRequestBlockSize); + if (BlockIndex == m_BlockFiles.size()) + { + std::unique_ptr<BasicFile>& NewBlockFile = m_BlockFiles.emplace_back(std::make_unique<BasicFile>()); + NewBlockFile->Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kTruncate); + m_ChunkOffset = BlockIndex * RecordedRequestBlockSize; + } + ZEN_ASSERT(BlockIndex < m_BlockFiles.size()); + BasicFile* BlockFile = m_BlockFiles[BlockIndex].get(); + ZEN_ASSERT(BlockFile != nullptr); + + Entry.Offset = m_ChunkOffset; + m_ChunkOffset = RoundUp(m_ChunkOffset + Entry.Length, 1u << 4u); + Lock.ReleaseNow(); + + std::error_code Ec; + BlockFile->Write(RequestBuffer.Data(), RequestBuffer.Size(), Entry.Offset - BlockIndex * RecordedRequestBlockSize, Ec); + if (Ec) + { + Entry.Length = 0; + return ~0ull; + } + return RequestIndex; + } + Lock.ReleaseNow(); + + BasicFile RequestFile; + RequestFile.Open(m_BasePath / fmt::format("request{}.bin", RequestIndex), BasicFile::Mode::kTruncate); + std::error_code Ec; + RequestFile.WriteAll(RequestBuffer, Ec); + if (Ec) + { + Entry.Length = 0; + return ~0ull; + } + return RequestIndex; + } + + std::filesystem::path m_BasePath; + mutable RwLock m_Lock; + std::vector<RecordedRequest> m_Entries; + std::vector<std::unique_ptr<BasicFile>> m_BlockFiles; + uint64_t m_ChunkOffset; +}; + +struct RecordedRequestsReader +{ + uint64_t BeginRead(const std::filesystem::path& BasePath, bool InMemory) + { + m_BasePath = BasePath; + BasicFile IndexFile; + IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kRead); + m_Entries.resize(IndexFile.FileSize() / sizeof(RecordedRequest)); + IndexFile.Read(m_Entries.data(), IndexFile.FileSize(), 0); + uint64_t MaxChunkPosition = 0; + for (const RecordedRequest& R : m_Entries) + { + if (R.Offset != ~0ull) + { + MaxChunkPosition = Max(MaxChunkPosition, R.Offset + R.Length); + } + } + uint32_t BlockCount = gsl::narrow<uint32_t>(MaxChunkPosition / RecordedRequestBlockSize) + 1; + m_BlockFiles.resize(BlockCount); + for (uint32_t BlockIndex = 0; BlockIndex < BlockCount; ++BlockIndex) + { + if (InMemory) + { + BasicFile Chunk; + Chunk.Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kRead); + m_BlockFiles[BlockIndex] = Chunk.ReadAll(); + continue; + } + m_BlockFiles[BlockIndex] = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("chunks{}.bin", BlockIndex)); + } + return m_Entries.size(); + } + void EndRead() { m_BlockFiles.clear(); } + + std::pair<ZenContentType, ZenContentType> ReadRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) const + { + if (RequestIndex >= m_Entries.size()) + { + return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType}; + } + const RecordedRequest& Entry = m_Entries[RequestIndex]; + if (Entry.Length == 0) + { + return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType}; + } + if (Entry.Offset != ~0ull) + { + uint32_t BlockIndex = gsl::narrow<uint32_t>((Entry.Offset + Entry.Length) / RecordedRequestBlockSize); + uint64_t ChunkOffset = Entry.Offset - (BlockIndex * RecordedRequestBlockSize); + OutBuffer = IoBuffer(m_BlockFiles[BlockIndex], ChunkOffset, Entry.Length); + return {Entry.ContentType, Entry.AcceptType}; + } + OutBuffer = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("request{}.bin", RequestIndex)); + return {Entry.ContentType, Entry.AcceptType}; + } + + std::filesystem::path m_BasePath; + std::vector<RecordedRequest> m_Entries; + std::vector<IoBuffer> m_BlockFiles; +}; + +class DiskRequestRecorder : public IRpcRequestRecorder +{ +public: + DiskRequestRecorder(const std::filesystem::path& BasePath) { m_RecordedRequests.BeginWrite(BasePath); } + virtual ~DiskRequestRecorder() { m_RecordedRequests.EndWrite(); } + +private: + virtual uint64_t RecordRequest(const ZenContentType ContentType, + const ZenContentType AcceptType, + const IoBuffer& RequestBuffer) override + { + return m_RecordedRequests.WriteRequest(ContentType, AcceptType, RequestBuffer); + } + virtual void RecordResponse(uint64_t, const ZenContentType, const IoBuffer&) override {} + virtual void RecordResponse(uint64_t, const ZenContentType, const CompositeBuffer&) override {} + RecordedRequestsWriter m_RecordedRequests; +}; + +class DiskRequestReplayer : public IRpcRequestReplayer +{ +public: + DiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory) + { + m_RequestCount = m_RequestBuffer.BeginRead(BasePath, InMemory); + } + virtual ~DiskRequestReplayer() { m_RequestBuffer.EndRead(); } + +private: + virtual uint64_t GetRequestCount() const override { return m_RequestCount; } + + virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) override + { + return m_RequestBuffer.ReadRequest(RequestIndex, OutBuffer); + } + virtual ZenContentType GetResponse(uint64_t, IoBuffer&) override { return ZenContentType::kUnknownContentType; } + + std::uint64_t m_RequestCount; + RecordedRequestsReader m_RequestBuffer; +}; + +std::unique_ptr<cache::IRpcRequestRecorder> +MakeDiskRequestRecorder(const std::filesystem::path& BasePath) +{ + return std::make_unique<DiskRequestRecorder>(BasePath); +} + +std::unique_ptr<cache::IRpcRequestReplayer> +MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory) +{ + return std::make_unique<DiskRequestReplayer>(BasePath, InMemory); +} + +} // namespace zen::cache diff --git a/src/zenutil/include/zenutil/basicfile.h b/src/zenutil/include/zenutil/basicfile.h new file mode 100644 index 000000000..877df0f92 --- /dev/null +++ b/src/zenutil/include/zenutil/basicfile.h @@ -0,0 +1,113 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iobuffer.h> + +#include <filesystem> +#include <functional> + +namespace zen { + +class CbObject; + +/** + * Probably the most basic file abstraction in the universe + * + * One thing of note is that there is no notion of a "current file position" + * in this API -- all reads and writes are done from explicit offsets in + * the file. This avoids concurrency issues which can occur otherwise. + * + */ + +class BasicFile +{ +public: + BasicFile() = default; + ~BasicFile(); + + BasicFile(const BasicFile&) = delete; + BasicFile& operator=(const BasicFile&) = delete; + + enum class Mode : uint32_t + { + kRead = 0, // Opens a existing file for read only + kWrite = 1, // Opens (or creates) a file for read and write + kTruncate = 2, // Opens (or creates) a file for read and write and sets the size to zero + kDelete = 3, // Opens (or creates) a file for read and write allowing .DeleteFile file disposition to be set + kTruncateDelete = + 4 // Opens (or creates) a file for read and write and sets the size to zero allowing .DeleteFile file disposition to be set + }; + + void Open(const std::filesystem::path& FileName, Mode Mode); + void Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec); + void Close(); + void Read(void* Data, uint64_t Size, uint64_t FileOffset); + void StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun); + void StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun); + void Write(MemoryView Data, uint64_t FileOffset); + void Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec); + void Write(const void* Data, uint64_t Size, uint64_t FileOffset); + void Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec); + void Flush(); + uint64_t FileSize(); + void SetFileSize(uint64_t FileSize); + IoBuffer ReadAll(); + void WriteAll(IoBuffer Data, std::error_code& Ec); + void* Detach(); + + inline void* Handle() { return m_FileHandle; } + +protected: + void* m_FileHandle = nullptr; // This is either null or valid +private: +}; + +/** + * Simple abstraction for a temporary file + * + * Works like a regular BasicFile but implements a simple mechanism to allow creating + * a temporary file for writing in a directory which may later be moved atomically + * into the intended location after it has been fully written to. + * + */ + +class TemporaryFile : public BasicFile +{ +public: + TemporaryFile() = default; + ~TemporaryFile(); + + TemporaryFile(const TemporaryFile&) = delete; + TemporaryFile& operator=(const TemporaryFile&) = delete; + + void Close(); + void CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec); + void MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec); + const std::filesystem::path& GetPath() const { return m_TempPath; } + +private: + std::filesystem::path m_TempPath; + + using BasicFile::Open; +}; + +/** Lock file abstraction + + */ + +class LockFile : protected BasicFile +{ +public: + LockFile(); + ~LockFile(); + + void Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec); + void Update(CbObject Payload, std::error_code& Ec); + +private: +}; + +ZENCORE_API void basicfile_forcelink(); + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/cache.h b/src/zenutil/include/zenutil/cache/cache.h new file mode 100644 index 000000000..1a1dd9386 --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cache.h @@ -0,0 +1,6 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenutil/cache/cachekey.h> +#include <zenutil/cache/cachepolicy.h> diff --git a/src/zenutil/include/zenutil/cache/cachekey.h b/src/zenutil/include/zenutil/cache/cachekey.h new file mode 100644 index 000000000..741375946 --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cachekey.h @@ -0,0 +1,86 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/iohash.h> +#include <zencore/string.h> +#include <zencore/uid.h> + +#include <zenutil/cache/cachepolicy.h> + +namespace zen { + +struct CacheKey +{ + std::string Bucket; + IoHash Hash; + + static CacheKey Create(std::string_view Bucket, const IoHash& Hash) { return {.Bucket = ToLower(Bucket), .Hash = Hash}; } + + auto operator<=>(const CacheKey& that) const + { + if (auto b = caseSensitiveCompareStrings(Bucket, that.Bucket); b != std::strong_ordering::equal) + { + return b; + } + return Hash <=> that.Hash; + } + + auto operator==(const CacheKey& that) const { return (*this <=> that) == std::strong_ordering::equal; } + + static const CacheKey Empty; +}; + +struct CacheChunkRequest +{ + CacheKey Key; + IoHash ChunkId; + Oid ValueId; + uint64_t RawOffset = 0ull; + uint64_t RawSize = ~uint64_t(0); + CachePolicy Policy = CachePolicy::Default; +}; + +struct CacheKeyRequest +{ + CacheKey Key; + CacheRecordPolicy Policy; +}; + +struct CacheValueRequest +{ + CacheKey Key; + CachePolicy Policy = CachePolicy::Default; +}; + +inline bool +operator<(const CacheChunkRequest& A, const CacheChunkRequest& B) +{ + if (A.Key < B.Key) + { + return true; + } + if (B.Key < A.Key) + { + return false; + } + if (A.ChunkId < B.ChunkId) + { + return true; + } + if (B.ChunkId < A.ChunkId) + { + return false; + } + if (A.ValueId < B.ValueId) + { + return true; + } + if (B.ValueId < A.ValueId) + { + return false; + } + return A.RawOffset < B.RawOffset; +} + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/cachepolicy.h b/src/zenutil/include/zenutil/cache/cachepolicy.h new file mode 100644 index 000000000..9a745e42c --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cachepolicy.h @@ -0,0 +1,227 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compactbinary.h> +#include <zencore/enumflags.h> +#include <zencore/refcount.h> +#include <zencore/string.h> +#include <zencore/uid.h> + +#include <gsl/gsl-lite.hpp> +#include <span> +namespace zen::Private { +class ICacheRecordPolicyShared; +} +namespace zen { + +class CbObjectView; +class CbWriter; + +class OptionalCacheRecordPolicy; + +enum class CachePolicy : uint32_t +{ + /** A value with no flags. Disables access to the cache unless combined with other flags. */ + None = 0, + + /** Allow a cache request to query local caches. */ + QueryLocal = 1 << 0, + /** Allow a cache request to query remote caches. */ + QueryRemote = 1 << 1, + /** Allow a cache request to query any caches. */ + Query = QueryLocal | QueryRemote, + + /** Allow cache requests to query and store records and values in local caches. */ + StoreLocal = 1 << 2, + /** Allow cache records and values to be stored in remote caches. */ + StoreRemote = 1 << 3, + /** Allow cache records and values to be stored in any caches. */ + Store = StoreLocal | StoreRemote, + + /** Allow cache requests to query and store records and values in local caches. */ + Local = QueryLocal | StoreLocal, + /** Allow cache requests to query and store records and values in remote caches. */ + Remote = QueryRemote | StoreRemote, + + /** Allow cache requests to query and store records and values in any caches. */ + Default = Query | Store, + + /** Skip fetching the data for values. */ + SkipData = 1 << 4, + + /** Skip fetching the metadata for record requests. */ + SkipMeta = 1 << 5, + + /** + * Partial output will be provided with the error status when a required value is missing. + * + * This is meant for cases when the missing values can be individually recovered, or rebuilt, + * without rebuilding the whole record. The cache automatically adds this flag when there are + * other cache stores that it may be able to recover missing values from. + * + * Missing values will be returned in the records, but with only the hash and size. + * + * Applying this flag for a put of a record allows a partial record to be stored. + */ + PartialRecord = 1 << 6, + + /** + * Keep records in the cache for at least the duration of the session. + * + * This is a hint that the record may be accessed again in this session. This is mainly meant + * to be used when subsequent accesses will not tolerate a cache miss. + */ + KeepAlive = 1 << 7, +}; + +gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy); +/** Append a non-empty text version of the policy to the builder. */ +StringBuilderBase& operator<<(StringBuilderBase& Builder, CachePolicy Policy); +/** Parse non-empty text written by operator<< into a policy. */ +CachePolicy ParseCachePolicy(std::string_view Text); +/** Return input converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. */ +CachePolicy ConvertToUpstream(CachePolicy Policy); + +inline CachePolicy +Union(CachePolicy A, CachePolicy B) +{ + constexpr CachePolicy InvertedFlags = CachePolicy::SkipData | CachePolicy::SkipMeta; + return (A & ~(InvertedFlags)) | (B & ~(InvertedFlags)) | (A & B & InvertedFlags); +} + +/** A value ID and the cache policy to use for that value. */ +struct CacheValuePolicy +{ + Oid Id; + CachePolicy Policy = CachePolicy::Default; + + /** Flags that are valid on a value policy. */ + static constexpr CachePolicy PolicyMask = CachePolicy::Default | CachePolicy::SkipData; +}; + +/** Interface for the private implementation of the cache record policy. */ +class Private::ICacheRecordPolicyShared : public RefCounted +{ +public: + virtual ~ICacheRecordPolicyShared() = default; + virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0; + virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0; +}; + +/** + * Flags to control the behavior of cache record requests, with optional overrides by value. + * + * Examples: + * - A base policy of None with value policy overrides of Default will fetch those values if they + * exist in the record, and skip data for any other values. + * - A base policy of Default, with value policy overrides of (Query | SkipData), will skip those + * values, but still check if they exist, and will load any other values. + */ +class CacheRecordPolicy +{ +public: + /** Construct a cache record policy that uses the default policy. */ + CacheRecordPolicy() = default; + + /** Construct a cache record policy with a uniform policy for the record and every value. */ + inline CacheRecordPolicy(CachePolicy BasePolicy) + : RecordPolicy(BasePolicy) + , DefaultValuePolicy(BasePolicy & CacheValuePolicy::PolicyMask) + { + } + + /** Returns true if the record and every value use the same cache policy. */ + inline bool IsUniform() const { return !Shared; } + + /** Returns the cache policy to use for the record. */ + inline CachePolicy GetRecordPolicy() const { return RecordPolicy; } + + /** Returns the base cache policy that this was constructed from. */ + inline CachePolicy GetBasePolicy() const { return DefaultValuePolicy | (RecordPolicy & ~CacheValuePolicy::PolicyMask); } + + /** Returns the cache policy to use for the value. */ + CachePolicy GetValuePolicy(const Oid& Id) const; + + /** Returns the array of cache policy overrides for values, sorted by ID. */ + inline std::span<const CacheValuePolicy> GetValuePolicies() const + { + return Shared ? Shared->GetValuePolicies() : std::span<const CacheValuePolicy>(); + } + + /** Saves the cache record policy to a compact binary object. */ + void Save(CbWriter& Writer) const; + + /** Loads a cache record policy from an object. */ + static OptionalCacheRecordPolicy Load(CbObjectView Object); + + /** Return *this converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. + */ + CacheRecordPolicy ConvertToUpstream() const; + +private: + friend class CacheRecordPolicyBuilder; + friend class OptionalCacheRecordPolicy; + + CachePolicy RecordPolicy = CachePolicy::Default; + CachePolicy DefaultValuePolicy = CachePolicy::Default; + RefPtr<const Private::ICacheRecordPolicyShared> Shared; +}; + +/** A cache record policy builder is used to construct a cache record policy. */ +class CacheRecordPolicyBuilder +{ +public: + /** Construct a policy builder that uses the default policy as its base policy. */ + CacheRecordPolicyBuilder() = default; + + /** Construct a policy builder that uses the provided policy for the record and values with no override. */ + inline explicit CacheRecordPolicyBuilder(CachePolicy Policy) : BasePolicy(Policy) {} + + /** Adds a cache policy override for a value. */ + void AddValuePolicy(const CacheValuePolicy& Value); + inline void AddValuePolicy(const Oid& Id, CachePolicy Policy) { AddValuePolicy({Id, Policy}); } + + /** Build a cache record policy, which makes this builder subsequently unusable. */ + CacheRecordPolicy Build(); + +private: + CachePolicy BasePolicy = CachePolicy::Default; + RefPtr<Private::ICacheRecordPolicyShared> Shared; +}; + +/** + * A cache record policy that can be null. + * + * @see CacheRecordPolicy + */ +class OptionalCacheRecordPolicy : private CacheRecordPolicy +{ +public: + inline OptionalCacheRecordPolicy() : CacheRecordPolicy(~CachePolicy::None) {} + + inline OptionalCacheRecordPolicy(CacheRecordPolicy&& InOutput) : CacheRecordPolicy(std::move(InOutput)) {} + inline OptionalCacheRecordPolicy(const CacheRecordPolicy& InOutput) : CacheRecordPolicy(InOutput) {} + inline OptionalCacheRecordPolicy& operator=(CacheRecordPolicy&& InOutput) + { + CacheRecordPolicy::operator=(std::move(InOutput)); + return *this; + } + inline OptionalCacheRecordPolicy& operator=(const CacheRecordPolicy& InOutput) + { + CacheRecordPolicy::operator=(InOutput); + return *this; + } + + /** Returns the cache record policy. The caller must check for null before using this accessor. */ + inline const CacheRecordPolicy& Get() const& { return *this; } + inline CacheRecordPolicy Get() && { return std::move(*this); } + + inline bool IsNull() const { return RecordPolicy == ~CachePolicy::None; } + inline bool IsValid() const { return !IsNull(); } + inline explicit operator bool() const { return !IsNull(); } + + inline void Reset() { *this = OptionalCacheRecordPolicy(); } +}; + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/cacherequests.h b/src/zenutil/include/zenutil/cache/cacherequests.h new file mode 100644 index 000000000..f1999ebfe --- /dev/null +++ b/src/zenutil/include/zenutil/cache/cacherequests.h @@ -0,0 +1,279 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compress.h> + +#include "cachekey.h" +#include "cachepolicy.h" + +#include <functional> + +namespace zen { + +class CbPackage; +class CbObjectWriter; +class CbObjectView; + +namespace cacherequests { + // I'd really like to get rid of std::optional<CacheRecordPolicy> (or really the class CacheRecordPolicy) + // + // CacheRecordPolicy has a record level policy but it can also contain policies for individual + // values inside the record. + // + // However, when we do a "PutCacheRecords" we already list the individual Values with their Id + // so we can just as well use an optional plain CachePolicy for each value. + // + // In "GetCacheRecords" we do not currently as for the individual values but you can add + // a policy on a per-value level in the std::optional<CacheRecordPolicy> Policy for each record. + // + // But as we already need to know the Ids of the values we want to set the policy for + // it would be simpler to add an array of requested values which each has an optional policy. + // + // We could add: + // struct GetCacheRecordValueRequest + // { + // Oid Id; + // std::optional<CachePolicy> Policy; + // }; + // + // and change GetCacheRecordRequest to + // struct GetCacheRecordRequest + // { + // CacheKey Key = CacheKey::Empty; + // std::vector<GetCacheRecordValueRequest> ValueRequests; + // std::optional<CachePolicy> Policy; + // }; + // + // This way we don't need the complex CacheRecordPolicy class and the request becomes + // more uniform and easier to understand. + // + // Would need to decide what the ValueRequests actually mean: + // Do they dictate which values to fetch or just a change of the policy? + // If they dictate the values to fetch you need to know all the value ids to set them + // and that is unlikely what we want - we want to be able to get a cache record with + // all its values without knowing all the Ids, right? + // + + ////////////////////////////////////////////////////////////////////////// + // Put 1..n structured cache records with optional attachments + + struct PutCacheRecordRequestValue + { + Oid Id = Oid::Zero; + IoHash RawHash = IoHash::Zero; // If Body is not set, this must be set and the value must already exist in cache + CompressedBuffer Body = CompressedBuffer::Null; + }; + + struct PutCacheRecordRequest + { + CacheKey Key = CacheKey::Empty; + std::vector<PutCacheRecordRequestValue> Values; + std::optional<CacheRecordPolicy> Policy; + }; + + struct PutCacheRecordsRequest + { + uint32_t AcceptMagic = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<PutCacheRecordRequest> Requests; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + struct PutCacheRecordsResult + { + std::vector<bool> Success; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + ////////////////////////////////////////////////////////////////////////// + // Get 1..n structured cache records with optional attachments + // We can get requests for a cache record where we want care about a particular + // value id which we now of, but we don't know the ids of the other values and + // we still want them. + // Not sure if in that case we want different policies for the different attachemnts? + + struct GetCacheRecordRequest + { + CacheKey Key = CacheKey::Empty; + std::optional<CacheRecordPolicy> Policy; + }; + + struct GetCacheRecordsRequest + { + uint32_t AcceptMagic = 0; + uint16_t AcceptOptions = 0; + int32_t ProcessPid = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<GetCacheRecordRequest> Requests; + + bool Parse(const CbPackage& RpcRequest); + bool Parse(const CbObjectView& RpcRequest); + bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter = {}) const; + bool Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter = {}) const; + }; + + struct GetCacheRecordResultValue + { + Oid Id = Oid::Zero; + IoHash RawHash = IoHash::Zero; + uint64_t RawSize = 0; + CompressedBuffer Body = CompressedBuffer::Null; + }; + + struct GetCacheRecordResult + { + CacheKey Key = CacheKey::Empty; + std::vector<GetCacheRecordResultValue> Values; + }; + + struct GetCacheRecordsResult + { + std::vector<std::optional<GetCacheRecordResult>> Results; + + bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes = {}); + bool Format(CbPackage& OutPackage) const; + }; + + ////////////////////////////////////////////////////////////////////////// + // Put 1..n unstructured cache objects + + struct PutCacheValueRequest + { + CacheKey Key = CacheKey::Empty; + IoHash RawHash = IoHash::Zero; + CompressedBuffer Body = CompressedBuffer::Null; // If not set the value is expected to already exist in cache store + std::optional<CachePolicy> Policy; + }; + + struct PutCacheValuesRequest + { + uint32_t AcceptMagic = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<PutCacheValueRequest> Requests; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + struct PutCacheValuesResult + { + std::vector<bool> Success; + + bool Parse(const CbPackage& Package); + bool Format(CbPackage& OutPackage) const; + }; + + ////////////////////////////////////////////////////////////////////////// + // Get 1..n unstructured cache objects (stored data may be structured or unstructured) + + struct GetCacheValueRequest + { + CacheKey Key = CacheKey::Empty; + std::optional<CachePolicy> Policy; + }; + + struct GetCacheValuesRequest + { + uint32_t AcceptMagic = 0; + uint16_t AcceptOptions = 0; + int32_t ProcessPid = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<GetCacheValueRequest> Requests; + + bool Parse(const CbObjectView& BatchObject); + bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter = {}) const; + }; + + struct CacheValueResult + { + uint64_t RawSize = 0; + IoHash RawHash = IoHash::Zero; + CompressedBuffer Body = CompressedBuffer::Null; + }; + + struct CacheValuesResult + { + std::vector<CacheValueResult> Results; + + bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes = {}); + bool Format(CbPackage& OutPackage) const; + }; + + typedef CacheValuesResult GetCacheValuesResult; + + ////////////////////////////////////////////////////////////////////////// + // Get 1..n cache record values (attachments) for 1..n records + + struct GetCacheChunkRequest + { + CacheKey Key; + Oid ValueId = Oid::Zero; // Set if ChunkId is not known at request time + IoHash ChunkId = IoHash::Zero; + uint64_t RawOffset = 0ull; + uint64_t RawSize = ~uint64_t(0); + std::optional<CachePolicy> Policy; + }; + + struct GetCacheChunksRequest + { + uint32_t AcceptMagic = 0; + uint16_t AcceptOptions = 0; + int32_t ProcessPid = 0; + CachePolicy DefaultPolicy = CachePolicy::Default; + std::string Namespace; + std::vector<GetCacheChunkRequest> Requests; + + bool Parse(const CbObjectView& BatchObject); + bool Format(CbPackage& OutPackage) const; + }; + + typedef CacheValuesResult GetCacheChunksResult; + + ////////////////////////////////////////////////////////////////////////// + + struct HttpRequestData + { + std::optional<std::string> Namespace; + std::optional<std::string> Bucket; + std::optional<IoHash> HashKey; + std::optional<IoHash> ValueContentId; + }; + + bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data); + + // Temporarily public + std::optional<std::string> GetRequestNamespace(const CbObjectView& Params); + bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key); + + ////////////////////////////////////////////////////////////////////////// + + // struct CacheRecordValue + // { + // Oid Id = Oid::Zero; + // IoHash RawHash = IoHash::Zero; + // uint64_t RawSize = 0; + // }; + // + // struct CacheRecord + // { + // CacheKey Key = CacheKey::Empty; + // std::vector<CacheRecordValue> Values; + // + // bool Parse(CbObjectView& Reader); + // bool Format(CbObjectWriter& Writer) const; + // }; + +} // namespace cacherequests + +void cacherequests_forcelink(); // internal + +} // namespace zen diff --git a/src/zenutil/include/zenutil/cache/rpcrecording.h b/src/zenutil/include/zenutil/cache/rpcrecording.h new file mode 100644 index 000000000..6d65a532a --- /dev/null +++ b/src/zenutil/include/zenutil/cache/rpcrecording.h @@ -0,0 +1,29 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/compositebuffer.h> +#include <zencore/iobuffer.h> + +namespace zen::cache { +class IRpcRequestRecorder +{ +public: + virtual ~IRpcRequestRecorder() {} + virtual uint64_t RecordRequest(const ZenContentType ContentType, const ZenContentType AcceptType, const IoBuffer& RequestBuffer) = 0; + virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const IoBuffer& ResponseBuffer) = 0; + virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const CompositeBuffer& ResponseBuffer) = 0; +}; +class IRpcRequestReplayer +{ +public: + virtual ~IRpcRequestReplayer() {} + virtual uint64_t GetRequestCount() const = 0; + virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0; + virtual ZenContentType GetResponse(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0; +}; + +std::unique_ptr<cache::IRpcRequestRecorder> MakeDiskRequestRecorder(const std::filesystem::path& BasePath); +std::unique_ptr<cache::IRpcRequestReplayer> MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory); + +} // namespace zen::cache diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h new file mode 100644 index 000000000..1c204c144 --- /dev/null +++ b/src/zenutil/include/zenutil/zenserverprocess.h @@ -0,0 +1,141 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/enumflags.h> +#include <zencore/logging.h> +#include <zencore/thread.h> +#include <zencore/uid.h> + +#include <atomic> +#include <filesystem> +#include <optional> + +namespace zen { + +class ZenServerEnvironment +{ +public: + ZenServerEnvironment(); + ~ZenServerEnvironment(); + + void Initialize(std::filesystem::path ProgramBaseDir); + void InitializeForTest(std::filesystem::path ProgramBaseDir, std::filesystem::path TestBaseDir, std::string_view ServerClass = ""); + + std::filesystem::path CreateNewTestDir(); + std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; } + std::filesystem::path GetTestRootDir(std::string_view Path); + inline bool IsInitialized() const { return m_IsInitialized; } + inline bool IsTestEnvironment() const { return m_IsTestInstance; } + inline std::string_view GetServerClass() const { return m_ServerClass; } + +private: + std::filesystem::path m_ProgramBaseDir; + std::filesystem::path m_TestBaseDir; + bool m_IsInitialized = false; + bool m_IsTestInstance = false; + std::string m_ServerClass; +}; + +struct ZenServerInstance +{ + ZenServerInstance(ZenServerEnvironment& TestEnvironment); + ~ZenServerInstance(); + + void Shutdown(); + void SignalShutdown(); + void WaitUntilReady(); + [[nodiscard]] bool WaitUntilReady(int Timeout); + void EnableTermination() { m_Terminate = true; } + void Detach(); + inline int GetPid() { return m_Process.Pid(); } + inline void SetOwnerPid(int Pid) { m_OwnerPid = Pid; } + + void SetTestDir(std::filesystem::path TestDir) + { + ZEN_ASSERT(!m_Process.IsValid()); + m_TestDir = TestDir; + } + + void SpawnServer(int BasePort = 0, std::string_view AdditionalServerArgs = std::string_view()); + + void AttachToRunningServer(int BasePort = 0); + + std::string GetBaseUri() const; + +private: + ZenServerEnvironment& m_Env; + ProcessHandle m_Process; + NamedEvent m_ReadyEvent; + NamedEvent m_ShutdownEvent; + bool m_Terminate = false; + std::filesystem::path m_TestDir; + int m_BasePort = 0; + std::optional<int> m_OwnerPid; + + void CreateShutdownEvent(int BasePort); +}; + +/** Shared system state + * + * Used as a scratchpad to identify running instances etc + * + * The state lives in a memory-mapped file backed by the swapfile + * + */ + +class ZenServerState +{ +public: + ZenServerState(); + ~ZenServerState(); + + struct ZenServerEntry + { + // NOTE: any changes to this should consider backwards compatibility + // which means you should not rearrange members only potentially + // add something to the end or use a different mechanism for + // additional state. For example, you can use the session ID + // to introduce additional named objects + std::atomic<uint32_t> Pid; + std::atomic<uint16_t> DesiredListenPort; + std::atomic<uint16_t> Flags; + uint8_t SessionId[12]; + std::atomic<uint32_t> SponsorPids[8]; + std::atomic<uint16_t> EffectiveListenPort; + uint8_t Padding[10]; + + enum class FlagsEnum : uint16_t + { + kShutdownPlease = 1 << 0, + kIsReady = 1 << 1, + }; + + FRIEND_ENUM_CLASS_FLAGS(FlagsEnum); + + Oid GetSessionId() const { return Oid::FromMemory(SessionId); } + void Reset(); + void SignalShutdownRequest(); + void SignalReady(); + bool AddSponsorProcess(uint32_t Pid); + }; + + static_assert(sizeof(ZenServerEntry) == 64); + + void Initialize(); + [[nodiscard]] bool InitializeReadOnly(); + [[nodiscard]] ZenServerEntry* Lookup(int DesiredListenPort); + ZenServerEntry* Register(int DesiredListenPort); + void Sweep(); + void Snapshot(std::function<void(const ZenServerEntry&)>&& Callback); + inline bool IsReadOnly() const { return m_IsReadOnly; } + +private: + void* m_hMapFile = nullptr; + ZenServerEntry* m_Data = nullptr; + int m_MaxEntryCount = 65536 / sizeof(ZenServerEntry); + ZenServerEntry* m_OurEntry = nullptr; + bool m_IsReadOnly = true; +}; + +} // namespace zen diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua new file mode 100644 index 000000000..e7d849bb2 --- /dev/null +++ b/src/zenutil/xmake.lua @@ -0,0 +1,9 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('zenutil') + set_kind("static") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs("include", {public=true}) + add_deps("zencore") + add_packages("vcpkg::spdlog") diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp new file mode 100644 index 000000000..5ecde343b --- /dev/null +++ b/src/zenutil/zenserverprocess.cpp @@ -0,0 +1,677 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "zenutil/zenserverprocess.h" + +#include <zencore/except.h> +#include <zencore/filesystem.h> +#include <zencore/fmtutils.h> +#include <zencore/logging.h> +#include <zencore/session.h> +#include <zencore/string.h> +#include <zencore/thread.h> + +#include <atomic> + +#if ZEN_PLATFORM_WINDOWS +# include <atlbase.h> +# include <zencore/windows.h> +#else +# include <sys/mman.h> +#endif + +////////////////////////////////////////////////////////////////////////// + +namespace zen { + +namespace zenutil { +#if ZEN_PLATFORM_WINDOWS + class SecurityAttributes + { + public: + inline SECURITY_ATTRIBUTES* Attributes() { return &m_Attributes; } + + protected: + SECURITY_ATTRIBUTES m_Attributes{}; + SECURITY_DESCRIPTOR m_Sd{}; + }; + + // Security attributes which allows any user access + + class AnyUserSecurityAttributes : public SecurityAttributes + { + public: + AnyUserSecurityAttributes() + { + m_Attributes.nLength = sizeof m_Attributes; + m_Attributes.bInheritHandle = false; // Disable inheritance + + const BOOL Success = InitializeSecurityDescriptor(&m_Sd, SECURITY_DESCRIPTOR_REVISION); + + if (Success) + { + if (!SetSecurityDescriptorDacl(&m_Sd, TRUE, (PACL)NULL, FALSE)) + { + ThrowLastError("SetSecurityDescriptorDacl failed"); + } + + m_Attributes.lpSecurityDescriptor = &m_Sd; + } + } + }; +#endif // ZEN_PLATFORM_WINDOWS + +} // namespace zenutil + +////////////////////////////////////////////////////////////////////////// + +ZenServerState::ZenServerState() +{ +} + +ZenServerState::~ZenServerState() +{ + if (m_OurEntry) + { + // Clean up our entry now that we're leaving + + m_OurEntry->Reset(); + m_OurEntry = nullptr; + } + +#if ZEN_PLATFORM_WINDOWS + if (m_Data) + { + UnmapViewOfFile(m_Data); + } + + if (m_hMapFile) + { + CloseHandle(m_hMapFile); + } +#else + if (m_Data != nullptr) + { + munmap(m_Data, m_MaxEntryCount * sizeof(ZenServerEntry)); + } + + int Fd = int(intptr_t(m_hMapFile)); + close(Fd); +#endif + + m_Data = nullptr; +} + +void +ZenServerState::Initialize() +{ + size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry); + +#if ZEN_PLATFORM_WINDOWS + // TODO: there's a small chance of a race here, this logic could be tightened up with a mutex to + // ensure only a single process at a time creates the mapping + // TODO: the fallback to Local instead of Global has a flaw where if you start a non-elevated instance + // first then start an elevated instance second you'll have the first instance with a local + // mapping and the second instance with a global mapping. This kind of elevated/non-elevated + // shouldn't be common, but handling for it should be improved in the future. + + HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap"); + if (hMap == NULL) + { + hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap"); + } + + if (hMap == NULL) + { + // Security attributes to enable any user to access state + zenutil::AnyUserSecurityAttributes Attrs; + + hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file + Attrs.Attributes(), // allow anyone to access + PAGE_READWRITE, // read/write access + 0, // maximum object size (high-order DWORD) + DWORD(MapSize), // maximum object size (low-order DWORD) + L"Global\\ZenMap"); // name of mapping object + + if (hMap == NULL) + { + hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file + Attrs.Attributes(), // allow anyone to access + PAGE_READWRITE, // read/write access + 0, // maximum object size (high-order DWORD) + m_MaxEntryCount * sizeof(ZenServerEntry), // maximum object size (low-order DWORD) + L"Local\\ZenMap"); // name of mapping object + } + + if (hMap == NULL) + { + ThrowLastError("Could not open or create file mapping object for Zen server state"); + } + } + + void* pBuf = MapViewOfFile(hMap, // handle to map object + FILE_MAP_ALL_ACCESS, // read/write permission + 0, // offset high + 0, // offset low + DWORD(MapSize)); + + if (pBuf == NULL) + { + ThrowLastError("Could not map view of Zen server state"); + } +#else + int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, 0666); + if (Fd < 0) + { + ThrowLastError("Could not open a shared memory object"); + } + fchmod(Fd, 0666); + void* hMap = (void*)intptr_t(Fd); + + int Result = ftruncate(Fd, MapSize); + ZEN_UNUSED(Result); + + void* pBuf = mmap(nullptr, MapSize, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0); + if (pBuf == MAP_FAILED) + { + ThrowLastError("Could not map view of Zen server state"); + } +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<ZenServerEntry*>(pBuf); + m_IsReadOnly = false; +} + +bool +ZenServerState::InitializeReadOnly() +{ + size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry); + +#if ZEN_PLATFORM_WINDOWS + HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap"); + if (hMap == NULL) + { + hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap"); + } + + if (hMap == NULL) + { + return false; + } + + void* pBuf = MapViewOfFile(hMap, // handle to map object + FILE_MAP_READ, // read permission + 0, // offset high + 0, // offset low + MapSize); + + if (pBuf == NULL) + { + ThrowLastError("Could not map view of Zen server state"); + } +#else + int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666); + if (Fd < 0) + { + return false; + } + void* hMap = (void*)intptr_t(Fd); + + void* pBuf = mmap(nullptr, MapSize, PROT_READ, MAP_PRIVATE, Fd, 0); + if (pBuf == MAP_FAILED) + { + ThrowLastError("Could not map read-only view of Zen server state"); + } +#endif + + m_hMapFile = hMap; + m_Data = reinterpret_cast<ZenServerEntry*>(pBuf); + + return true; +} + +ZenServerState::ZenServerEntry* +ZenServerState::Lookup(int DesiredListenPort) +{ + for (int i = 0; i < m_MaxEntryCount; ++i) + { + if (m_Data[i].DesiredListenPort == DesiredListenPort) + { + return &m_Data[i]; + } + } + + return nullptr; +} + +ZenServerState::ZenServerEntry* +ZenServerState::Register(int DesiredListenPort) +{ + if (m_Data == nullptr) + { + return nullptr; + } + + // Allocate an entry + + int Pid = GetCurrentProcessId(); + + for (int i = 0; i < m_MaxEntryCount; ++i) + { + ZenServerEntry& Entry = m_Data[i]; + + if (Entry.DesiredListenPort.load(std::memory_order_relaxed) == 0) + { + uint16_t Expected = 0; + if (Entry.DesiredListenPort.compare_exchange_strong(Expected, uint16_t(DesiredListenPort))) + { + // Successfully allocated entry + + m_OurEntry = &Entry; + + Entry.Pid = Pid; + Entry.EffectiveListenPort = 0; + Entry.Flags = 0; + + const Oid SesId = GetSessionId(); + memcpy(Entry.SessionId, &SesId, sizeof SesId); + + return &Entry; + } + } + } + + return nullptr; +} + +void +ZenServerState::Sweep() +{ + if (m_Data == nullptr) + { + return; + } + + ZEN_ASSERT(m_IsReadOnly == false); + + for (int i = 0; i < m_MaxEntryCount; ++i) + { + ZenServerEntry& Entry = m_Data[i]; + + if (Entry.DesiredListenPort) + { + if (IsProcessRunning(Entry.Pid) == false) + { + ZEN_DEBUG("Sweep - pid {} not running, reclaiming entry (port {})", Entry.Pid, Entry.DesiredListenPort); + + Entry.Reset(); + } + } + } +} + +void +ZenServerState::Snapshot(std::function<void(const ZenServerEntry&)>&& Callback) +{ + if (m_Data == nullptr) + { + return; + } + + for (int i = 0; i < m_MaxEntryCount; ++i) + { + ZenServerEntry& Entry = m_Data[i]; + + if (Entry.DesiredListenPort) + { + Callback(Entry); + } + } +} + +void +ZenServerState::ZenServerEntry::Reset() +{ + Pid = 0; + DesiredListenPort = 0; + Flags = 0; + EffectiveListenPort = 0; +} + +void +ZenServerState::ZenServerEntry::SignalShutdownRequest() +{ + Flags |= uint16_t(FlagsEnum::kShutdownPlease); +} + +void +ZenServerState::ZenServerEntry::SignalReady() +{ + Flags |= uint16_t(FlagsEnum::kIsReady); +} + +bool +ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd) +{ + for (std::atomic<uint32_t>& PidEntry : SponsorPids) + { + if (PidEntry.load(std::memory_order_relaxed) == 0) + { + uint32_t Expected = 0; + if (PidEntry.compare_exchange_strong(Expected, PidToAdd)) + { + // Success! + return true; + } + } + else if (PidEntry.load(std::memory_order_relaxed) == PidToAdd) + { + // Success, the because pid is already in the list + return true; + } + } + + return false; +} + +////////////////////////////////////////////////////////////////////////// + +std::atomic<int> ZenServerTestCounter{0}; + +ZenServerEnvironment::ZenServerEnvironment() +{ +} + +ZenServerEnvironment::~ZenServerEnvironment() +{ +} + +void +ZenServerEnvironment::Initialize(std::filesystem::path ProgramBaseDir) +{ + m_ProgramBaseDir = ProgramBaseDir; + + ZEN_DEBUG("Program base dir is '{}'", ProgramBaseDir); + + m_IsInitialized = true; +} + +void +ZenServerEnvironment::InitializeForTest(std::filesystem::path ProgramBaseDir, + std::filesystem::path TestBaseDir, + std::string_view ServerClass) +{ + using namespace std::literals; + + m_ProgramBaseDir = ProgramBaseDir; + m_TestBaseDir = TestBaseDir; + + ZEN_INFO("Program base dir is '{}'", ProgramBaseDir); + ZEN_INFO("Cleaning test base dir '{}'", TestBaseDir); + DeleteDirectories(TestBaseDir.c_str()); + + m_IsTestInstance = true; + m_IsInitialized = true; + + if (ServerClass.empty()) + { +#if ZEN_WITH_HTTPSYS + m_ServerClass = "httpsys"sv; +#else + m_ServerClass = "asio"sv; +#endif + } + else + { + m_ServerClass = ServerClass; + } +} + +std::filesystem::path +ZenServerEnvironment::CreateNewTestDir() +{ + using namespace std::literals; + + ExtendableWideStringBuilder<256> TestDir; + TestDir << "test"sv << int64_t(++ZenServerTestCounter); + + std::filesystem::path TestPath = m_TestBaseDir / TestDir.c_str(); + + ZEN_INFO("Creating new test dir @ '{}'", TestPath); + + CreateDirectories(TestPath.c_str()); + + return TestPath; +} + +std::filesystem::path +ZenServerEnvironment::GetTestRootDir(std::string_view Path) +{ + std::filesystem::path Root = m_ProgramBaseDir.parent_path().parent_path(); + + std::filesystem::path Relative{Path}; + + return Root / Relative; +} + +////////////////////////////////////////////////////////////////////////// + +std::atomic<int> ChildIdCounter{0}; + +ZenServerInstance::ZenServerInstance(ZenServerEnvironment& TestEnvironment) : m_Env(TestEnvironment) +{ + ZEN_ASSERT(TestEnvironment.IsInitialized()); +} + +ZenServerInstance::~ZenServerInstance() +{ + Shutdown(); +} + +void +ZenServerInstance::SignalShutdown() +{ + m_ShutdownEvent.Set(); +} + +void +ZenServerInstance::Shutdown() +{ + if (m_Process.IsValid()) + { + if (m_Terminate) + { + ZEN_INFO("Terminating zenserver process"); + m_Process.Terminate(111); + m_Process.Reset(); + } + else + { + SignalShutdown(); + m_Process.Wait(); + m_Process.Reset(); + } + } +} + +void +ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerArgs) +{ + ZEN_ASSERT(!m_Process.IsValid()); // Only spawn once + + const int MyPid = zen::GetCurrentProcessId(); + const int ChildId = ++ChildIdCounter; + + ExtendableStringBuilder<32> ChildEventName; + ChildEventName << "Zen_Child_" << ChildId; + NamedEvent ChildEvent{ChildEventName}; + + CreateShutdownEvent(BasePort); + + ExtendableStringBuilder<32> LogId; + LogId << "Zen" << ChildId; + + ExtendableStringBuilder<512> CommandLine; + CommandLine << "zenserver" ZEN_EXE_SUFFIX_LITERAL; // see CreateProc() call for actual binary path + + const bool IsTest = m_Env.IsTestEnvironment(); + + if (IsTest) + { + if (!m_OwnerPid.has_value()) + { + m_OwnerPid = MyPid; + } + + CommandLine << " --test --log-id " << LogId; + } + + if (m_OwnerPid.has_value()) + { + CommandLine << " --owner-pid " << m_OwnerPid.value(); + } + + CommandLine << " --child-id " << ChildEventName; + + if (std::string_view ServerClass = m_Env.GetServerClass(); ServerClass.empty() == false) + { + CommandLine << " --http " << ServerClass; + } + + if (BasePort) + { + CommandLine << " --port " << BasePort; + m_BasePort = BasePort; + } + + if (!m_TestDir.empty()) + { + CommandLine << " --data-dir "; + PathToUtf8(m_TestDir.c_str(), CommandLine); + } + + if (!AdditionalServerArgs.empty()) + { + CommandLine << " " << AdditionalServerArgs; + } + + std::filesystem::path CurrentDirectory = std::filesystem::current_path(); + + ZEN_DEBUG("Spawning server '{}'", LogId); + + uint32_t CreationFlags = 0; + if (!IsTest) + { + CreationFlags |= CreateProcOptions::Flag_NewConsole; + } + + const std::filesystem::path BaseDir = m_Env.ProgramBaseDir(); + const std::filesystem::path Executable = BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL; + CreateProcOptions CreateOptions = { + .WorkingDirectory = &CurrentDirectory, + .Flags = CreationFlags, + }; + CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); +#if ZEN_PLATFORM_WINDOWS + if (!ChildPid && ::GetLastError() == ERROR_ELEVATION_REQUIRED) + { + ZEN_DEBUG("Regular spawn failed - spawning elevated server"); + CreateOptions.Flags |= CreateProcOptions::Flag_Elevated; + ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions); + } +#endif + + if (!ChildPid) + { + ThrowLastError("Server spawn failed"); + } + + ZEN_DEBUG("Server '{}' spawned OK", LogId); + + if (IsTest) + { + m_Process.Initialize(ChildPid); + } + + m_ReadyEvent = std::move(ChildEvent); +} + +void +ZenServerInstance::CreateShutdownEvent(int BasePort) +{ + ExtendableStringBuilder<32> ChildShutdownEventName; + ChildShutdownEventName << "Zen_" << BasePort; + ChildShutdownEventName << "_Shutdown"; + NamedEvent ChildShutdownEvent{ChildShutdownEventName}; + m_ShutdownEvent = std::move(ChildShutdownEvent); +} + +void +ZenServerInstance::AttachToRunningServer(int BasePort) +{ + ZenServerState State; + if (!State.InitializeReadOnly()) + { + // TODO: return success/error code instead? + throw std::runtime_error("No zen state found"); + } + + const ZenServerState::ZenServerEntry* Entry = nullptr; + + if (BasePort) + { + Entry = State.Lookup(BasePort); + } + else + { + State.Snapshot([&](const ZenServerState::ZenServerEntry& InEntry) { Entry = &InEntry; }); + } + + if (!Entry) + { + // TODO: return success/error code instead? + throw std::runtime_error("No server found"); + } + + m_Process.Initialize(Entry->Pid); + CreateShutdownEvent(Entry->EffectiveListenPort); +} + +void +ZenServerInstance::Detach() +{ + if (m_Process.IsValid()) + { + m_Process.Reset(); + m_ShutdownEvent.Close(); + } +} + +void +ZenServerInstance::WaitUntilReady() +{ + while (m_ReadyEvent.Wait(100) == false) + { + if (!m_Process.IsRunning() || !m_Process.IsValid()) + { + ZEN_INFO("Wait abandoned by invalid process (running={})", m_Process.IsRunning()); + return; + } + } +} + +bool +ZenServerInstance::WaitUntilReady(int Timeout) +{ + return m_ReadyEvent.Wait(Timeout); +} + +std::string +ZenServerInstance::GetBaseUri() const +{ + ZEN_ASSERT(m_BasePort); + + return fmt::format("http://localhost:{}", m_BasePort); +} + +} // namespace zen |