aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorStefan Boberg <[email protected]>2023-05-02 10:01:47 +0200
committerGitHub <[email protected]>2023-05-02 10:01:47 +0200
commit075d17f8ada47e990fe94606c3d21df409223465 (patch)
treee50549b766a2f3c354798a54ff73404217b4c9af /src
parentfix: bundle shouldn't append content zip to zen (diff)
downloadzen-075d17f8ada47e990fe94606c3d21df409223465.tar.xz
zen-075d17f8ada47e990fe94606c3d21df409223465.zip
moved source directories into `/src` (#264)
* moved source directories into `/src` * updated bundle.lua for new `src` path * moved some docs, icon * removed old test trees
Diffstat (limited to 'src')
-rw-r--r--src/UnrealEngine.icobin0 -> 65288 bytes
-rw-r--r--src/zen/chunk/chunk.cpp1216
-rw-r--r--src/zen/chunk/chunk.h25
-rw-r--r--src/zen/cmds/cache.cpp275
-rw-r--r--src/zen/cmds/cache.h68
-rw-r--r--src/zen/cmds/copy.cpp95
-rw-r--r--src/zen/cmds/copy.h28
-rw-r--r--src/zen/cmds/dedup.cpp302
-rw-r--r--src/zen/cmds/dedup.h28
-rw-r--r--src/zen/cmds/hash.cpp171
-rw-r--r--src/zen/cmds/hash.h27
-rw-r--r--src/zen/cmds/print.cpp193
-rw-r--r--src/zen/cmds/print.h41
-rw-r--r--src/zen/cmds/projectstore.cpp930
-rw-r--r--src/zen/cmds/projectstore.h180
-rw-r--r--src/zen/cmds/rpcreplay.cpp417
-rw-r--r--src/zen/cmds/rpcreplay.h65
-rw-r--r--src/zen/cmds/scrub.cpp154
-rw-r--r--src/zen/cmds/scrub.h58
-rw-r--r--src/zen/cmds/status.cpp41
-rw-r--r--src/zen/cmds/status.h22
-rw-r--r--src/zen/cmds/top.cpp89
-rw-r--r--src/zen/cmds/top.h35
-rw-r--r--src/zen/cmds/up.cpp108
-rw-r--r--src/zen/cmds/up.h36
-rw-r--r--src/zen/cmds/version.cpp79
-rw-r--r--src/zen/cmds/version.h24
-rw-r--r--src/zen/internalfile.cpp299
-rw-r--r--src/zen/internalfile.h62
-rw-r--r--src/zen/xmake.lua31
-rw-r--r--src/zen/zen.cpp421
-rw-r--r--src/zen/zen.h38
-rw-r--r--src/zen/zen.rc33
-rw-r--r--src/zencore-test/targetver.h10
-rw-r--r--src/zencore-test/xmake.lua8
-rw-r--r--src/zencore-test/zencore-test.cpp26
-rw-r--r--src/zencore/.gitignore1
-rw-r--r--src/zencore/base64.cpp107
-rw-r--r--src/zencore/blake3.cpp175
-rw-r--r--src/zencore/compactbinary.cpp2299
-rw-r--r--src/zencore/compactbinarybuilder.cpp1545
-rw-r--r--src/zencore/compactbinarypackage.cpp1350
-rw-r--r--src/zencore/compactbinaryvalidation.cpp664
-rw-r--r--src/zencore/compositebuffer.cpp446
-rw-r--r--src/zencore/compress.cpp1353
-rw-r--r--src/zencore/crc32.cpp545
-rw-r--r--src/zencore/crypto.cpp208
-rw-r--r--src/zencore/except.cpp93
-rw-r--r--src/zencore/filesystem.cpp1304
-rw-r--r--src/zencore/include/zencore/atomic.h74
-rw-r--r--src/zencore/include/zencore/base64.h17
-rw-r--r--src/zencore/include/zencore/blake3.h62
-rw-r--r--src/zencore/include/zencore/blockingqueue.h76
-rw-r--r--src/zencore/include/zencore/compactbinary.h1475
-rw-r--r--src/zencore/include/zencore/compactbinarybuilder.h661
-rw-r--r--src/zencore/include/zencore/compactbinarypackage.h341
-rw-r--r--src/zencore/include/zencore/compactbinaryvalidation.h197
-rw-r--r--src/zencore/include/zencore/compactbinaryvalue.h290
-rw-r--r--src/zencore/include/zencore/compositebuffer.h142
-rw-r--r--src/zencore/include/zencore/compress.h165
-rw-r--r--src/zencore/include/zencore/config.h.in16
-rw-r--r--src/zencore/include/zencore/crc32.h13
-rw-r--r--src/zencore/include/zencore/crypto.h77
-rw-r--r--src/zencore/include/zencore/endian.h113
-rw-r--r--src/zencore/include/zencore/enumflags.h61
-rw-r--r--src/zencore/include/zencore/except.h57
-rw-r--r--src/zencore/include/zencore/filesystem.h190
-rw-r--r--src/zencore/include/zencore/fmtutils.h52
-rw-r--r--src/zencore/include/zencore/intmath.h183
-rw-r--r--src/zencore/include/zencore/iobuffer.h423
-rw-r--r--src/zencore/include/zencore/iohash.h115
-rw-r--r--src/zencore/include/zencore/logging.h136
-rw-r--r--src/zencore/include/zencore/md5.h50
-rw-r--r--src/zencore/include/zencore/memory.h401
-rw-r--r--src/zencore/include/zencore/meta.h30
-rw-r--r--src/zencore/include/zencore/mpscqueue.h110
-rw-r--r--src/zencore/include/zencore/refcount.h186
-rw-r--r--src/zencore/include/zencore/scopeguard.h45
-rw-r--r--src/zencore/include/zencore/session.h14
-rw-r--r--src/zencore/include/zencore/sha1.h76
-rw-r--r--src/zencore/include/zencore/sharedbuffer.h167
-rw-r--r--src/zencore/include/zencore/stats.h295
-rw-r--r--src/zencore/include/zencore/stream.h90
-rw-r--r--src/zencore/include/zencore/string.h1115
-rw-r--r--src/zencore/include/zencore/testing.h67
-rw-r--r--src/zencore/include/zencore/testutils.h32
-rw-r--r--src/zencore/include/zencore/thread.h273
-rw-r--r--src/zencore/include/zencore/timer.h58
-rw-r--r--src/zencore/include/zencore/trace.h36
-rw-r--r--src/zencore/include/zencore/uid.h87
-rw-r--r--src/zencore/include/zencore/varint.h277
-rw-r--r--src/zencore/include/zencore/windows.h25
-rw-r--r--src/zencore/include/zencore/workthreadpool.h48
-rw-r--r--src/zencore/include/zencore/xxhash.h89
-rw-r--r--src/zencore/include/zencore/zencore.h383
-rw-r--r--src/zencore/intmath.cpp65
-rw-r--r--src/zencore/iobuffer.cpp653
-rw-r--r--src/zencore/iohash.cpp87
-rw-r--r--src/zencore/logging.cpp85
-rw-r--r--src/zencore/md5.cpp463
-rw-r--r--src/zencore/memory.cpp211
-rw-r--r--src/zencore/mpscqueue.cpp25
-rw-r--r--src/zencore/refcount.cpp65
-rw-r--r--src/zencore/session.cpp35
-rw-r--r--src/zencore/sha1.cpp443
-rw-r--r--src/zencore/sharedbuffer.cpp146
-rw-r--r--src/zencore/stats.cpp715
-rw-r--r--src/zencore/stream.cpp79
-rw-r--r--src/zencore/string.cpp1004
-rw-r--r--src/zencore/testing.cpp54
-rw-r--r--src/zencore/testutils.cpp42
-rw-r--r--src/zencore/thread.cpp1212
-rw-r--r--src/zencore/timer.cpp105
-rw-r--r--src/zencore/trace.cpp45
-rw-r--r--src/zencore/uid.cpp148
-rw-r--r--src/zencore/workthreadpool.cpp83
-rw-r--r--src/zencore/xmake.lua61
-rw-r--r--src/zencore/xxhash.cpp50
-rw-r--r--src/zencore/zencore.cpp175
-rw-r--r--src/zenhttp/httpasio.cpp1372
-rw-r--r--src/zenhttp/httpasio.h36
-rw-r--r--src/zenhttp/httpclient.cpp176
-rw-r--r--src/zenhttp/httpnull.cpp83
-rw-r--r--src/zenhttp/httpnull.h29
-rw-r--r--src/zenhttp/httpserver.cpp885
-rw-r--r--src/zenhttp/httpshared.cpp809
-rw-r--r--src/zenhttp/httpsys.cpp1674
-rw-r--r--src/zenhttp/httpsys.h90
-rw-r--r--src/zenhttp/include/zenhttp/httpclient.h47
-rw-r--r--src/zenhttp/include/zenhttp/httpcommon.h181
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h315
-rw-r--r--src/zenhttp/include/zenhttp/httpshared.h163
-rw-r--r--src/zenhttp/include/zenhttp/websocket.h256
-rw-r--r--src/zenhttp/include/zenhttp/zenhttp.h21
-rw-r--r--src/zenhttp/iothreadpool.cpp49
-rw-r--r--src/zenhttp/iothreadpool.h37
-rw-r--r--src/zenhttp/websocketasio.cpp1613
-rw-r--r--src/zenhttp/xmake.lua14
-rw-r--r--src/zenhttp/zenhttp.cpp22
-rw-r--r--src/zenserver-test/cachepolicy-tests.cpp153
-rw-r--r--src/zenserver-test/projectclient.cpp164
-rw-r--r--src/zenserver-test/projectclient.h32
-rw-r--r--src/zenserver-test/xmake.lua16
-rw-r--r--src/zenserver-test/zenserver-test.cpp3323
-rw-r--r--src/zenserver/admin/admin.cpp101
-rw-r--r--src/zenserver/admin/admin.h26
-rw-r--r--src/zenserver/auth/authmgr.cpp506
-rw-r--r--src/zenserver/auth/authmgr.h56
-rw-r--r--src/zenserver/auth/authservice.cpp91
-rw-r--r--src/zenserver/auth/authservice.h25
-rw-r--r--src/zenserver/auth/oidc.cpp127
-rw-r--r--src/zenserver/auth/oidc.h76
-rw-r--r--src/zenserver/cache/cachetracking.cpp376
-rw-r--r--src/zenserver/cache/cachetracking.h41
-rw-r--r--src/zenserver/cache/structuredcache.cpp3159
-rw-r--r--src/zenserver/cache/structuredcache.h187
-rw-r--r--src/zenserver/cache/structuredcachestore.cpp3648
-rw-r--r--src/zenserver/cache/structuredcachestore.h535
-rw-r--r--src/zenserver/cidstore.cpp124
-rw-r--r--src/zenserver/cidstore.h35
-rw-r--r--src/zenserver/compute/function.cpp629
-rw-r--r--src/zenserver/compute/function.h73
-rw-r--r--src/zenserver/config.cpp902
-rw-r--r--src/zenserver/config.h158
-rw-r--r--src/zenserver/diag/diagsvcs.cpp127
-rw-r--r--src/zenserver/diag/diagsvcs.h111
-rw-r--r--src/zenserver/diag/formatters.h71
-rw-r--r--src/zenserver/diag/logging.cpp467
-rw-r--r--src/zenserver/diag/logging.h10
-rw-r--r--src/zenserver/frontend/frontend.cpp128
-rw-r--r--src/zenserver/frontend/frontend.h25
-rw-r--r--src/zenserver/frontend/html/index.html59
-rw-r--r--src/zenserver/frontend/zipfs.cpp169
-rw-r--r--src/zenserver/frontend/zipfs.h26
-rw-r--r--src/zenserver/monitoring/httpstats.cpp62
-rw-r--r--src/zenserver/monitoring/httpstats.h38
-rw-r--r--src/zenserver/monitoring/httpstatus.cpp62
-rw-r--r--src/zenserver/monitoring/httpstatus.h38
-rw-r--r--src/zenserver/objectstore/objectstore.cpp232
-rw-r--r--src/zenserver/objectstore/objectstore.h48
-rw-r--r--src/zenserver/projectstore/fileremoteprojectstore.cpp235
-rw-r--r--src/zenserver/projectstore/fileremoteprojectstore.h19
-rw-r--r--src/zenserver/projectstore/jupiterremoteprojectstore.cpp244
-rw-r--r--src/zenserver/projectstore/jupiterremoteprojectstore.h26
-rw-r--r--src/zenserver/projectstore/projectstore.cpp4082
-rw-r--r--src/zenserver/projectstore/projectstore.h372
-rw-r--r--src/zenserver/projectstore/remoteprojectstore.cpp1036
-rw-r--r--src/zenserver/projectstore/remoteprojectstore.h111
-rw-r--r--src/zenserver/projectstore/zenremoteprojectstore.cpp341
-rw-r--r--src/zenserver/projectstore/zenremoteprojectstore.h18
-rw-r--r--src/zenserver/resource.h18
-rw-r--r--src/zenserver/targetver.h10
-rw-r--r--src/zenserver/testing/httptest.cpp207
-rw-r--r--src/zenserver/testing/httptest.h55
-rw-r--r--src/zenserver/upstream/hordecompute.cpp1457
-rw-r--r--src/zenserver/upstream/jupiter.cpp965
-rw-r--r--src/zenserver/upstream/jupiter.h217
-rw-r--r--src/zenserver/upstream/upstream.h8
-rw-r--r--src/zenserver/upstream/upstreamapply.cpp459
-rw-r--r--src/zenserver/upstream/upstreamapply.h192
-rw-r--r--src/zenserver/upstream/upstreamcache.cpp2112
-rw-r--r--src/zenserver/upstream/upstreamcache.h252
-rw-r--r--src/zenserver/upstream/upstreamservice.cpp56
-rw-r--r--src/zenserver/upstream/upstreamservice.h27
-rw-r--r--src/zenserver/upstream/zen.cpp326
-rw-r--r--src/zenserver/upstream/zen.h125
-rw-r--r--src/zenserver/windows/service.cpp646
-rw-r--r--src/zenserver/windows/service.h20
-rw-r--r--src/zenserver/xmake.lua60
-rw-r--r--src/zenserver/zenserver.cpp1261
-rw-r--r--src/zenserver/zenserver.rc105
-rw-r--r--src/zenstore-test/xmake.lua8
-rw-r--r--src/zenstore-test/zenstore-test.cpp32
-rw-r--r--src/zenstore/blockstore.cpp1312
-rw-r--r--src/zenstore/cas.cpp355
-rw-r--r--src/zenstore/cas.h67
-rw-r--r--src/zenstore/caslog.cpp236
-rw-r--r--src/zenstore/cidstore.cpp125
-rw-r--r--src/zenstore/compactcas.cpp1511
-rw-r--r--src/zenstore/compactcas.h95
-rw-r--r--src/zenstore/filecas.cpp1452
-rw-r--r--src/zenstore/filecas.h102
-rw-r--r--src/zenstore/gc.cpp1312
-rw-r--r--src/zenstore/hashkeyset.cpp60
-rw-r--r--src/zenstore/include/zenstore/blockstore.h175
-rw-r--r--src/zenstore/include/zenstore/caslog.h91
-rw-r--r--src/zenstore/include/zenstore/cidstore.h87
-rw-r--r--src/zenstore/include/zenstore/gc.h242
-rw-r--r--src/zenstore/include/zenstore/hashkeyset.h54
-rw-r--r--src/zenstore/include/zenstore/scrubcontext.h41
-rw-r--r--src/zenstore/include/zenstore/zenstore.h13
-rw-r--r--src/zenstore/xmake.lua9
-rw-r--r--src/zenstore/zenstore.cpp32
-rw-r--r--src/zentest-appstub/xmake.lua16
-rw-r--r--src/zentest-appstub/zentest-appstub.cpp34
-rw-r--r--src/zenutil/basicfile.cpp575
-rw-r--r--src/zenutil/cache/cachekey.cpp9
-rw-r--r--src/zenutil/cache/cachepolicy.cpp282
-rw-r--r--src/zenutil/cache/cacherequests.cpp1643
-rw-r--r--src/zenutil/cache/rpcrecording.cpp210
-rw-r--r--src/zenutil/include/zenutil/basicfile.h113
-rw-r--r--src/zenutil/include/zenutil/cache/cache.h6
-rw-r--r--src/zenutil/include/zenutil/cache/cachekey.h86
-rw-r--r--src/zenutil/include/zenutil/cache/cachepolicy.h227
-rw-r--r--src/zenutil/include/zenutil/cache/cacherequests.h279
-rw-r--r--src/zenutil/include/zenutil/cache/rpcrecording.h29
-rw-r--r--src/zenutil/include/zenutil/zenserverprocess.h141
-rw-r--r--src/zenutil/xmake.lua9
-rw-r--r--src/zenutil/zenserverprocess.cpp677
249 files changed, 81583 insertions, 0 deletions
diff --git a/src/UnrealEngine.ico b/src/UnrealEngine.ico
new file mode 100644
index 000000000..1cfa301a2
--- /dev/null
+++ b/src/UnrealEngine.ico
Binary files differ
diff --git a/src/zen/chunk/chunk.cpp b/src/zen/chunk/chunk.cpp
new file mode 100644
index 000000000..d3591f8ca
--- /dev/null
+++ b/src/zen/chunk/chunk.cpp
@@ -0,0 +1,1216 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "chunk.h"
+
+#if 0
+# include <gsl/gsl-lite.hpp>
+
+# include <zencore/filesystem.h>
+# include <zencore/iohash.h>
+# include <zencore/logging.h>
+# include <zencore/refcount.h>
+# include <zencore/scopeguard.h>
+# include <zencore/sha1.h>
+# include <zencore/string.h>
+# include <zencore/testing.h>
+# include <zencore/thread.h>
+# include <zencore/timer.h>
+# include <zenstore/gc.h>
+
+# include "../internalfile.h"
+
+# include <lz4.h>
+# include <zstd.h>
+
+# if ZEN_PLATFORM_WINDOWS
+# include <ppl.h>
+# include <ppltasks.h>
+# endif // ZEN_PLATFORM_WINDOWS
+
+# include <cmath>
+# include <filesystem>
+# include <random>
+# include <vector>
+
+//////////////////////////////////////////////////////////////////////////
+
+# if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
+namespace Concurrency {
+
+template<typename IterType, typename LambdaType>
+void
+parallel_for_each(IterType Cursor, IterType End, const LambdaType& Lambda)
+{
+ for (; Cursor < End; ++Cursor)
+ {
+ Lambda(*Cursor);
+ }
+}
+
+template<typename T>
+struct combinable
+{
+ T& local() { return Value; }
+
+ template<typename LambdaType>
+ void combine_each(const LambdaType& Lambda)
+ {
+ Lambda(Value);
+ }
+
+ T Value = {};
+};
+
+struct task_group
+{
+ template<class Function>
+ void run(const Function& Func)
+ {
+ Func();
+ }
+
+ void wait() {}
+};
+
+} // namespace Concurrency
+
+# endif // ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+static const uint32_t buzhashTable[] = {
+ 0x458be752, 0xc10748cc, 0xfbbcdbb8, 0x6ded5b68, 0xb10a82b5, 0x20d75648, 0xdfc5665f, 0xa8428801, 0x7ebf5191, 0x841135c7, 0x65cc53b3,
+ 0x280a597c, 0x16f60255, 0xc78cbc3e, 0x294415f5, 0xb938d494, 0xec85c4e6, 0xb7d33edc, 0xe549b544, 0xfdeda5aa, 0x882bf287, 0x3116737c,
+ 0x05569956, 0xe8cc1f68, 0x0806ac5e, 0x22a14443, 0x15297e10, 0x50d090e7, 0x4ba60f6f, 0xefd9f1a7, 0x5c5c885c, 0x82482f93, 0x9bfd7c64,
+ 0x0b3e7276, 0xf2688e77, 0x8fad8abc, 0xb0509568, 0xf1ada29f, 0xa53efdfe, 0xcb2b1d00, 0xf2a9e986, 0x6463432b, 0x95094051, 0x5a223ad2,
+ 0x9be8401b, 0x61e579cb, 0x1a556a14, 0x5840fdc2, 0x9261ddf6, 0xcde002bb, 0x52432bb0, 0xbf17373e, 0x7b7c222f, 0x2955ed16, 0x9f10ca59,
+ 0xe840c4c9, 0xccabd806, 0x14543f34, 0x1462417a, 0x0d4a1f9c, 0x087ed925, 0xd7f8f24c, 0x7338c425, 0xcf86c8f5, 0xb19165cd, 0x9891c393,
+ 0x325384ac, 0x0308459d, 0x86141d7e, 0xc922116a, 0xe2ffa6b6, 0x53f52aed, 0x2cd86197, 0xf5b9f498, 0xbf319c8f, 0xe0411fae, 0x977eb18c,
+ 0xd8770976, 0x9833466a, 0xc674df7f, 0x8c297d45, 0x8ca48d26, 0xc49ed8e2, 0x7344f874, 0x556f79c7, 0x6b25eaed, 0xa03e2b42, 0xf68f66a4,
+ 0x8e8b09a2, 0xf2e0e62a, 0x0d3a9806, 0x9729e493, 0x8c72b0fc, 0x160b94f6, 0x450e4d3d, 0x7a320e85, 0xbef8f0e1, 0x21d73653, 0x4e3d977a,
+ 0x1e7b3929, 0x1cc6c719, 0xbe478d53, 0x8d752809, 0xe6d8c2c6, 0x275f0892, 0xc8acc273, 0x4cc21580, 0xecc4a617, 0xf5f7be70, 0xe795248a,
+ 0x375a2fe9, 0x425570b6, 0x8898dcf8, 0xdc2d97c4, 0x0106114b, 0x364dc22f, 0x1e0cad1f, 0xbe63803c, 0x5f69fac2, 0x4d5afa6f, 0x1bc0dfb5,
+ 0xfb273589, 0x0ea47f7b, 0x3c1c2b50, 0x21b2a932, 0x6b1223fd, 0x2fe706a8, 0xf9bd6ce2, 0xa268e64e, 0xe987f486, 0x3eacf563, 0x1ca2018c,
+ 0x65e18228, 0x2207360a, 0x57cf1715, 0x34c37d2b, 0x1f8f3cde, 0x93b657cf, 0x31a019fd, 0xe69eb729, 0x8bca7b9b, 0x4c9d5bed, 0x277ebeaf,
+ 0xe0d8f8ae, 0xd150821c, 0x31381871, 0xafc3f1b0, 0x927db328, 0xe95effac, 0x305a47bd, 0x426ba35b, 0x1233af3f, 0x686a5b83, 0x50e072e5,
+ 0xd9d3bb2a, 0x8befc475, 0x487f0de6, 0xc88dff89, 0xbd664d5e, 0x971b5d18, 0x63b14847, 0xd7d3c1ce, 0x7f583cf3, 0x72cbcb09, 0xc0d0a81c,
+ 0x7fa3429b, 0xe9158a1b, 0x225ea19a, 0xd8ca9ea3, 0xc763b282, 0xbb0c6341, 0x020b8293, 0xd4cd299d, 0x58cfa7f8, 0x91b4ee53, 0x37e4d140,
+ 0x95ec764c, 0x30f76b06, 0x5ee68d24, 0x679c8661, 0xa41979c2, 0xf2b61284, 0x4fac1475, 0x0adb49f9, 0x19727a23, 0x15a7e374, 0xc43a18d5,
+ 0x3fb1aa73, 0x342fc615, 0x924c0793, 0xbee2d7f0, 0x8a279de9, 0x4aa2d70c, 0xe24dd37f, 0xbe862c0b, 0x177c22c2, 0x5388e5ee, 0xcd8a7510,
+ 0xf901b4fd, 0xdbc13dbc, 0x6c0bae5b, 0x64efe8c7, 0x48b02079, 0x80331a49, 0xca3d8ae6, 0xf3546190, 0xfed7108b, 0xc49b941b, 0x32baf4a9,
+ 0xeb833a4a, 0x88a3f1a5, 0x3a91ce0a, 0x3cc27da1, 0x7112e684, 0x4a3096b1, 0x3794574c, 0xa3c8b6f3, 0x1d213941, 0x6e0a2e00, 0x233479f1,
+ 0x0f4cd82f, 0x6093edd2, 0x5d7d209e, 0x464fe319, 0xd4dcac9e, 0x0db845cb, 0xfb5e4bc3, 0xe0256ce1, 0x09fb4ed1, 0x0914be1e, 0xa5bdb2c3,
+ 0xc6eb57bb, 0x30320350, 0x3f397e91, 0xa67791bc, 0x86bc0e2c, 0xefa0a7e2, 0xe9ff7543, 0xe733612c, 0xd185897b, 0x329e5388, 0x91dd236b,
+ 0x2ecb0d93, 0xf4d82a3d, 0x35b5c03f, 0xe4e606f0, 0x05b21843, 0x37b45964, 0x5eff22f4, 0x6027f4cc, 0x77178b3c, 0xae507131, 0x7bf7cabc,
+ 0xf9c18d66, 0x593ade65, 0xd95ddf11,
+};
+
+// ROL operation (compiler turns this into a ROL when optimizing)
+static inline uint32_t
+Rotate32(uint32_t Value, size_t RotateCount)
+{
+ RotateCount &= 31;
+
+ return ((Value) << (RotateCount)) | ((Value) >> (32 - RotateCount));
+}
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+class ZenChunker
+{
+public:
+ void SetChunkSize(size_t MinSize, size_t MaxSize, size_t AvgSize);
+ size_t ScanChunk(const void* DataBytes, size_t ByteCount);
+ void Reset();
+
+ // This controls which chunking approach is used - threshold or
+ // modulo based. Threshold is faster and generates similarly sized
+ // chunks
+ void SetUseThreshold(bool NewState) { m_useThreshold = NewState; }
+
+ inline size_t ChunkSizeMin() const { return m_chunkSizeMin; }
+ inline size_t ChunkSizeMax() const { return m_chunkSizeMax; }
+ inline size_t ChunkSizeAvg() const { return m_chunkSizeAvg; }
+ inline uint64_t BytesScanned() const { return m_bytesScanned; }
+
+ static constexpr size_t NoBoundaryFound = size_t(~0ull);
+
+private:
+ size_t m_chunkSizeMin = 0;
+ size_t m_chunkSizeMax = 0;
+ size_t m_chunkSizeAvg = 0;
+
+ uint32_t m_discriminator = 0; // Computed in SetChunkSize()
+ uint32_t m_threshold = 0; // Computed in SetChunkSize()
+
+ bool m_useThreshold = true;
+
+ static constexpr size_t kChunkSizeLimitMax = 64 * 1024 * 1024;
+ static constexpr size_t kChunkSizeLimitMin = 1024;
+
+ static constexpr size_t kDefaultAverageChunkSize = 64 * 1024;
+
+ static constexpr int kWindowSize = 48;
+ uint8_t m_window[kWindowSize];
+ uint32_t m_windowSize = 0;
+
+ uint32_t m_currentHash = 0;
+ uint32_t m_currentChunkSize = 0;
+
+ uint64_t m_bytesScanned = 0;
+
+ size_t InternalScanChunk(const void* DataBytes, size_t ByteCount);
+ void InternalReset();
+};
+
+void
+ZenChunker::Reset()
+{
+ InternalReset();
+
+ m_bytesScanned = 0;
+}
+
+void
+ZenChunker::InternalReset()
+{
+ m_currentHash = 0;
+ m_currentChunkSize = 0;
+ m_windowSize = 0;
+}
+
+void
+ZenChunker::SetChunkSize(size_t MinSize, size_t MaxSize, size_t AvgSize)
+{
+ if (m_windowSize)
+ return; // Already started
+
+ static_assert(kChunkSizeLimitMin > kWindowSize);
+
+ if (AvgSize)
+ {
+ // TODO: Validate AvgSize range
+ }
+ else
+ {
+ if (MinSize && MaxSize)
+ {
+ AvgSize = lrint(pow(2, (log2(MinSize) + log2(MaxSize)) / 2));
+ }
+ else if (MinSize)
+ {
+ AvgSize = MinSize * 4;
+ }
+ else if (MaxSize)
+ {
+ AvgSize = MaxSize / 4;
+ }
+ else
+ {
+ AvgSize = kDefaultAverageChunkSize;
+ }
+ }
+
+ if (MinSize)
+ {
+ // TODO: Validate MinSize range
+ }
+ else
+ {
+ MinSize = std::max(AvgSize / 4, kChunkSizeLimitMin);
+ }
+
+ if (MaxSize)
+ {
+ // TODO: Validate MaxSize range
+ }
+ else
+ {
+ MaxSize = std::min(AvgSize * 4, kChunkSizeLimitMax);
+ }
+
+ m_discriminator = gsl::narrow<uint32_t>(AvgSize - MinSize);
+
+ if (m_discriminator < MinSize)
+ {
+ m_discriminator = gsl::narrow<uint32_t>(MinSize);
+ }
+
+ if (m_discriminator > MaxSize)
+ {
+ m_discriminator = gsl::narrow<uint32_t>(MaxSize);
+ }
+
+ m_threshold = gsl::narrow<uint32_t>((uint64_t(std::numeric_limits<uint32_t>::max()) + 1) / m_discriminator);
+
+ m_chunkSizeMin = MinSize;
+ m_chunkSizeMax = MaxSize;
+ m_chunkSizeAvg = AvgSize;
+}
+
+size_t
+ZenChunker::ScanChunk(const void* DataBytesIn, size_t ByteCount)
+{
+ size_t Result = InternalScanChunk(DataBytesIn, ByteCount);
+
+ if (Result == NoBoundaryFound)
+ {
+ m_bytesScanned += ByteCount;
+ }
+ else
+ {
+ m_bytesScanned += Result;
+ }
+
+ return Result;
+}
+
+size_t
+ZenChunker::InternalScanChunk(const void* DataBytesIn, size_t ByteCount)
+{
+ size_t CurrentOffset = 0;
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(DataBytesIn);
+
+ // There's no point in updating the hash if we know we're not
+ // going to have a cut point, so just skip the data. This logic currently
+ // provides roughly a 20% speedup on my machine
+
+ const size_t NeedHashOffset = m_chunkSizeMin - kWindowSize;
+
+ if (m_currentChunkSize < NeedHashOffset)
+ {
+ const uint32_t SkipBytes = gsl::narrow<uint32_t>(std::min<uint64_t>(ByteCount, NeedHashOffset - m_currentChunkSize));
+
+ ByteCount -= SkipBytes;
+ m_currentChunkSize += SkipBytes;
+ CurrentOffset += SkipBytes;
+ CursorPtr += SkipBytes;
+
+ m_windowSize = 0;
+
+ if (ByteCount == 0)
+ {
+ return NoBoundaryFound;
+ }
+ }
+
+ // Fill window first
+
+ if (m_windowSize < kWindowSize)
+ {
+ const uint32_t FillBytes = uint32_t(std::min<size_t>(ByteCount, kWindowSize - m_windowSize));
+
+ memcpy(&m_window[m_windowSize], CursorPtr, FillBytes);
+
+ CursorPtr += FillBytes;
+
+ m_windowSize += FillBytes;
+ m_currentChunkSize += FillBytes;
+
+ CurrentOffset += FillBytes;
+ ByteCount -= FillBytes;
+
+ if (m_windowSize < kWindowSize)
+ {
+ return NoBoundaryFound;
+ }
+
+ // We have a full window, initialize hash
+
+ uint32_t CurrentHash = 0;
+
+ for (int i = 1; i < kWindowSize; ++i)
+ {
+ CurrentHash ^= detail::Rotate32(detail::buzhashTable[m_window[i - 1]], kWindowSize - i);
+ }
+
+ m_currentHash = CurrentHash ^ detail::buzhashTable[m_window[kWindowSize - 1]];
+ }
+
+ // Scan for boundaries (i.e points where the hash matches the value determined by
+ // the discriminator)
+
+ uint32_t CurrentHash = m_currentHash;
+ uint32_t CurrentChunkSize = m_currentChunkSize;
+
+ size_t Index = CurrentChunkSize % kWindowSize;
+
+ if (m_threshold && m_useThreshold)
+ {
+ // This is roughly 4x faster than the general modulo approach on my
+ // TR 3990X (~940MB/sec) and doesn't require any special parameters to
+ // achieve max performance
+
+ while (ByteCount)
+ {
+ const uint8_t NewByte = *CursorPtr;
+ const uint8_t OldByte = m_window[Index];
+
+ CurrentHash = detail::Rotate32(CurrentHash, 1) ^ detail::Rotate32(detail::buzhashTable[OldByte], m_windowSize) ^
+ detail::buzhashTable[NewByte];
+
+ CurrentChunkSize++;
+ CurrentOffset++;
+
+ if (CurrentChunkSize >= m_chunkSizeMin)
+ {
+ bool foundBreak;
+
+ if (CurrentChunkSize >= m_chunkSizeMax)
+ {
+ foundBreak = true;
+ }
+ else
+ {
+ foundBreak = CurrentHash <= m_threshold;
+ }
+
+ if (foundBreak)
+ {
+ // Boundary found!
+ InternalReset();
+
+ return CurrentOffset;
+ }
+ }
+
+ m_window[Index++] = *CursorPtr;
+
+ if (Index == kWindowSize)
+ {
+ Index = 0;
+ }
+
+ ++CursorPtr;
+ --ByteCount;
+ }
+ }
+ else if ((m_discriminator & (m_discriminator - 1)) == 0)
+ {
+ // This is quite a bit faster than the generic modulo path, but
+ // requires a very specific average chunk size to be used. If you
+ // pass in an even power-of-two divided by 0.75 as the average
+ // chunk size you'll hit this path
+
+ const uint32_t Mask = m_discriminator - 1;
+
+ while (ByteCount)
+ {
+ const uint8_t NewByte = *CursorPtr;
+ const uint8_t OldByte = m_window[Index];
+
+ CurrentHash = detail::Rotate32(CurrentHash, 1) ^ detail::Rotate32(detail::buzhashTable[OldByte], m_windowSize) ^
+ detail::buzhashTable[NewByte];
+
+ CurrentChunkSize++;
+ CurrentOffset++;
+
+ if (CurrentChunkSize >= m_chunkSizeMin)
+ {
+ bool foundBreak;
+
+ if (CurrentChunkSize >= m_chunkSizeMax)
+ {
+ foundBreak = true;
+ }
+ else
+ {
+ foundBreak = (CurrentHash & Mask) == Mask;
+ }
+
+ if (foundBreak)
+ {
+ // Boundary found!
+ InternalReset();
+
+ return CurrentOffset;
+ }
+ }
+
+ m_window[Index++] = *CursorPtr;
+
+ if (Index == kWindowSize)
+ {
+ Index = 0;
+ }
+
+ ++CursorPtr;
+ --ByteCount;
+ }
+ }
+ else
+ {
+ // This is the slowest path, which caps out around 250MB/sec for large sizes
+ // on my TR3900X
+
+ while (ByteCount)
+ {
+ const uint8_t NewByte = *CursorPtr;
+ const uint8_t OldByte = m_window[Index];
+
+ CurrentHash = detail::Rotate32(CurrentHash, 1) ^ detail::Rotate32(detail::buzhashTable[OldByte], m_windowSize) ^
+ detail::buzhashTable[NewByte];
+
+ CurrentChunkSize++;
+ CurrentOffset++;
+
+ if (CurrentChunkSize >= m_chunkSizeMin)
+ {
+ bool foundBreak;
+
+ if (CurrentChunkSize >= m_chunkSizeMax)
+ {
+ foundBreak = true;
+ }
+ else
+ {
+ foundBreak = (CurrentHash % m_discriminator) == (m_discriminator - 1);
+ }
+
+ if (foundBreak)
+ {
+ // Boundary found!
+ InternalReset();
+
+ return CurrentOffset;
+ }
+ }
+
+ m_window[Index++] = *CursorPtr;
+
+ if (Index == kWindowSize)
+ {
+ Index = 0;
+ }
+
+ ++CursorPtr;
+ --ByteCount;
+ }
+ }
+
+ m_currentChunkSize = CurrentChunkSize;
+ m_currentHash = CurrentHash;
+
+ return NoBoundaryFound;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+class DirectoryScanner
+{
+public:
+ struct FileEntry
+ {
+ std::filesystem::path Path;
+ uint64_t FileSize;
+ };
+
+ const std::vector<FileEntry>& Files() { return m_Files; }
+ std::vector<FileEntry>&& TakeFiles() { return std::move(m_Files); }
+ uint64_t FileBytes() const { return m_FileBytes; }
+
+ void Scan(std::filesystem::path RootPath)
+ {
+ for (const std::filesystem::directory_entry& Entry : std::filesystem::recursive_directory_iterator(RootPath))
+ {
+ if (Entry.is_regular_file())
+ {
+ m_Files.push_back({Entry.path(), Entry.file_size()});
+ m_FileBytes += Entry.file_size();
+ }
+ }
+ }
+
+private:
+ std::vector<FileEntry> m_Files;
+ uint64_t m_FileBytes = 0;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+class BaseChunker
+{
+public:
+ void SetCasStore(zen::CasStore* CasStore) { m_CasStore = CasStore; }
+
+ struct StatsBlock
+ {
+ uint64_t TotalBytes = 0;
+ uint64_t TotalChunks = 0;
+ uint64_t TotalCompressed = 0;
+ uint64_t UniqueBytes = 0;
+ uint64_t UniqueChunks = 0;
+ uint64_t UniqueCompressed = 0;
+ uint64_t DuplicateBytes = 0;
+ uint64_t NewCasChunks = 0;
+ uint64_t NewCasBytes = 0;
+
+ StatsBlock& operator+=(const StatsBlock& Rhs)
+ {
+ TotalBytes += Rhs.TotalBytes;
+ TotalChunks += Rhs.TotalChunks;
+ TotalCompressed += Rhs.TotalCompressed;
+ UniqueBytes += Rhs.UniqueBytes;
+ UniqueChunks += Rhs.UniqueChunks;
+ UniqueCompressed += Rhs.UniqueCompressed;
+ DuplicateBytes += Rhs.DuplicateBytes;
+ NewCasChunks += Rhs.NewCasChunks;
+ NewCasBytes += Rhs.NewCasBytes;
+ return *this;
+ }
+ };
+
+protected:
+ Concurrency::combinable<StatsBlock> m_StatsBlock;
+
+public:
+ StatsBlock SumStats()
+ {
+ StatsBlock _;
+ m_StatsBlock.combine_each([&](const StatsBlock& Block) { _ += Block; });
+ return _;
+ }
+
+protected:
+ struct HashSet
+ {
+ bool Add(const zen::IoHash& Hash)
+ {
+ const uint8_t ShardNo = Hash.Hash[19];
+
+ Bucket& Shard = m_Buckets[ShardNo];
+
+ zen::RwLock::ExclusiveLockScope _(Shard.HashLock);
+
+ auto rv = Shard.Hashes.insert(Hash);
+
+ return rv.second;
+ }
+
+ private:
+ struct alignas(64) Bucket
+ {
+ zen::RwLock HashLock;
+ std::unordered_set<zen::IoHash, zen::IoHash::Hasher> Hashes;
+# if ZEN_PLATFORM_WINDOWS
+# pragma warning(suppress : 4324) // Padding due to alignment
+# endif
+ };
+
+ Bucket m_Buckets[256];
+ };
+
+ zen::CasStore* m_CasStore = nullptr;
+};
+
+class FixedBlockSizeChunker : public BaseChunker
+{
+public:
+ FixedBlockSizeChunker(std::filesystem::path InRootPath) : m_RootPath(InRootPath) {}
+ ~FixedBlockSizeChunker() = default;
+
+ void SetChunkSize(uint64_t ChunkSize)
+ {
+ /* TODO: verify validity of chunk size */
+ m_ChunkSize = ChunkSize;
+ }
+ void SetUseCompression(bool UseCompression) { m_UseCompression = UseCompression; }
+ void SetPerformValidation(bool PerformValidation) { m_PerformValidation = PerformValidation; }
+
+ void InitCompression()
+ {
+ if (!m_CompressionBufferManager)
+ {
+ std::call_once(m_CompressionInitFlag, [&] {
+ // Wasteful, but should only be temporary
+ m_CompressionBufferManager.reset(new FileBufferManager(m_ChunkSize * 2, 128));
+ });
+ }
+ }
+
+ void ChunkFile(const DirectoryScanner::FileEntry& File)
+ {
+ InitCompression();
+
+ std::filesystem::path RelativePath{std::filesystem::relative(File.Path.generic_string(), m_RootPath)};
+
+ Concurrency::task_group ChunkProcessTasks;
+
+ ZEN_INFO("Chunking {} ({})", RelativePath.generic_string(), zen::NiceBytes(File.FileSize));
+
+ zen::RefPtr<InternalFile> Zfile = new InternalFile;
+ Zfile->OpenRead(File.Path);
+
+ size_t FileBytes = Zfile->GetFileSize();
+ uint64_t CurrentFileOffset = 0;
+
+ std::vector<zen::IoHash> BlockHashes{(FileBytes + m_ChunkSize - 1) / m_ChunkSize};
+
+ while (FileBytes)
+ {
+ zen::IoBuffer Buffer = m_BufferManager.AllocBuffer();
+
+ const size_t BytesToRead = std::min(FileBytes, Buffer.Size());
+
+ Zfile->Read((void*)Buffer.Data(), BytesToRead, CurrentFileOffset);
+
+ auto ProcessChunk = [this, Buffer, &BlockHashes, CurrentFileOffset, BytesToRead] {
+ StatsBlock& Stats = m_StatsBlock.local();
+ for (uint64_t Offset = 0; Offset < BytesToRead; Offset += m_ChunkSize)
+ {
+ const uint8_t* DataPointer = reinterpret_cast<const uint8_t*>(Buffer.Data()) + Offset;
+ const uint64_t DataSize = std::min(BytesToRead - Offset, m_ChunkSize);
+ const zen::IoHash Hash = zen::IoHash::HashBuffer(DataPointer, DataSize);
+
+ BlockHashes[(CurrentFileOffset + Offset) / m_ChunkSize] = Hash;
+
+ const bool IsNew = m_LocalHashSet.Add(Hash);
+
+ if (IsNew)
+ {
+ if (m_UseCompression)
+ {
+ if (true)
+ {
+ // Compress using ZSTD
+
+ // TODO: use CompressedBuffer format
+
+ const size_t CompressBufferSize = ZSTD_compressBound(DataSize);
+
+ zen::IoBuffer CompressedBuffer = m_CompressionBufferManager->AllocBuffer();
+ char* CompressBuffer = (char*)CompressedBuffer.Data();
+
+ ZEN_ASSERT(CompressedBuffer.Size() >= CompressBufferSize);
+
+ const size_t CompressedSize = ZSTD_compress(CompressBuffer,
+ CompressBufferSize,
+ (const char*)DataPointer,
+ DataSize,
+ ZSTD_CLEVEL_DEFAULT);
+
+ Stats.UniqueCompressed += CompressedSize;
+
+ if (m_CasStore)
+ {
+ const zen::IoHash CompressedHash = zen::IoHash::HashBuffer(CompressBuffer, CompressedSize);
+ zen::IoBuffer CompressedData = zen::IoBuffer(zen::IoBuffer::Wrap, CompressBuffer, CompressedSize);
+ zen::CasStore::InsertResult Result = m_CasStore->InsertChunk(CompressedData, CompressedHash);
+
+ if (Result.New)
+ {
+ Stats.NewCasChunks += 1;
+ Stats.NewCasBytes += CompressedSize;
+ }
+ }
+
+ m_CompressionBufferManager->ReturnBuffer(CompressedBuffer);
+ }
+ else
+ {
+ // Compress using LZ4
+ const int CompressBufferSize = LZ4_compressBound(gsl::narrow<int>(DataSize));
+
+ zen::IoBuffer CompressedBuffer = m_CompressionBufferManager->AllocBuffer();
+ char* CompressBuffer = (char*)CompressedBuffer.Data();
+
+ ZEN_ASSERT(CompressedBuffer.Size() >= size_t(CompressBufferSize));
+
+ const int CompressedSize = LZ4_compress_default((const char*)DataPointer,
+ CompressBuffer,
+ gsl::narrow<int>(DataSize),
+ CompressBufferSize);
+
+ Stats.UniqueCompressed += CompressedSize;
+
+ if (m_CasStore)
+ {
+ const zen::IoHash CompressedHash = zen::IoHash::HashBuffer(CompressBuffer, CompressedSize);
+ zen::IoBuffer CompressedData = zen::IoBuffer(zen::IoBuffer::Wrap, CompressBuffer, CompressedSize);
+ zen::CasStore::InsertResult Result = m_CasStore->InsertChunk(CompressedData, CompressedHash);
+
+ if (Result.New)
+ {
+ Stats.NewCasChunks += 1;
+ Stats.NewCasBytes += CompressedSize;
+ }
+ }
+
+ m_CompressionBufferManager->ReturnBuffer(CompressedBuffer);
+ }
+ }
+ else if (m_CasStore)
+ {
+ zen::CasStore::InsertResult Result = m_CasStore->InsertChunk(zen::IoBuffer(Buffer, Offset, DataSize), Hash);
+
+ if (Result.New)
+ {
+ Stats.NewCasChunks += 1;
+ Stats.NewCasBytes += DataSize;
+ }
+ }
+
+ Stats.UniqueBytes += DataSize;
+ Stats.UniqueChunks += 1;
+ }
+ else
+ {
+ // We've seen this chunk before
+ Stats.DuplicateBytes += DataSize;
+ }
+
+ Stats.TotalBytes += DataSize;
+ Stats.TotalChunks += 1;
+ }
+
+ m_BufferManager.ReturnBuffer(Buffer);
+ };
+
+ ChunkProcessTasks.run(ProcessChunk);
+
+ CurrentFileOffset += BytesToRead;
+ FileBytes -= BytesToRead;
+ }
+
+ ChunkProcessTasks.wait();
+
+ // Verify pass
+
+ if (!m_UseCompression && m_PerformValidation)
+ {
+ const uint8_t* FileData = reinterpret_cast<const uint8_t*>(Zfile->MemoryMapFile());
+ uint64_t Offset = 0;
+ const uint64_t BytesToRead = Zfile->GetFileSize();
+
+ for (zen::IoHash& Hash : BlockHashes)
+ {
+ const uint64_t DataSize = std::min(BytesToRead - Offset, m_ChunkSize);
+ const zen::IoHash CalcHash = zen::IoHash::HashBuffer(FileData + Offset, DataSize);
+
+ ZEN_ASSERT(CalcHash == Hash);
+
+ zen::IoBuffer FoundValue = m_CasStore->FindChunk(CalcHash);
+
+ ZEN_ASSERT(FoundValue);
+ ZEN_ASSERT(FoundValue.Size() == DataSize);
+
+ Offset += DataSize;
+ }
+ }
+ }
+
+private:
+ std::filesystem::path m_RootPath;
+ FileBufferManager m_BufferManager{128 * 1024, 128};
+ uint64_t m_ChunkSize = 64 * 1024;
+ HashSet m_LocalHashSet;
+ bool m_UseCompression = true;
+ bool m_PerformValidation = false;
+
+ std::once_flag m_CompressionInitFlag;
+ std::unique_ptr<FileBufferManager> m_CompressionBufferManager;
+};
+
+class VariableBlockSizeChunker : public BaseChunker
+{
+public:
+ VariableBlockSizeChunker(std::filesystem::path InRootPath) : m_RootPath(InRootPath) {}
+
+ void SetAverageChunkSize(uint64_t AverageChunkSize) { m_AverageChunkSize = AverageChunkSize; }
+ void SetUseCompression(bool UseCompression) { m_UseCompression = UseCompression; }
+
+ void ChunkFile(const DirectoryScanner::FileEntry& File)
+ {
+ std::filesystem::path RelativePath{std::filesystem::relative(File.Path.generic_string(), m_RootPath)};
+
+ ZEN_INFO("Chunking {} ({})", RelativePath.generic_string(), zen::NiceBytes(File.FileSize));
+
+ zen::RefPtr<InternalFile> Zfile = new InternalFile;
+ Zfile->OpenRead(File.Path);
+
+ // Could use IoBuffer here to help manage lifetimes of things
+ // across tasks / threads
+
+ ZenChunker Chunker;
+ Chunker.SetChunkSize(0, 0, m_AverageChunkSize);
+
+ const size_t DataSize = Zfile->GetFileSize();
+
+ std::vector<size_t> Boundaries;
+
+ uint64_t CurrentStreamPosition = 0;
+ uint64_t CurrentChunkSize = 0;
+ size_t RemainBytes = DataSize;
+
+ zen::IoHashStream IoHashStream;
+
+ while (RemainBytes != 0)
+ {
+ zen::IoBuffer Buffer = m_BufferManager.AllocBuffer();
+
+ size_t BytesToRead = std::min(RemainBytes, Buffer.Size());
+
+ uint8_t* DataPointer = (uint8_t*)Buffer.Data();
+
+ Zfile->Read(DataPointer, BytesToRead, CurrentStreamPosition);
+
+ StatsBlock& Stats = m_StatsBlock.local();
+
+ while (BytesToRead)
+ {
+ const size_t Boundary = Chunker.ScanChunk(DataPointer, BytesToRead);
+
+ if (Boundary == ZenChunker::NoBoundaryFound)
+ {
+ IoHashStream.Append(DataPointer, BytesToRead);
+ CurrentStreamPosition += BytesToRead;
+ CurrentChunkSize += BytesToRead;
+ RemainBytes -= BytesToRead;
+ break;
+ }
+
+ // Boundary found
+
+ IoHashStream.Append(DataPointer, Boundary);
+
+ const zen::IoHash Hash = IoHashStream.GetHash();
+ const bool IsNew = m_LocalHashSet.Add(Hash);
+
+ CurrentStreamPosition += Boundary;
+ CurrentChunkSize += Boundary;
+ Boundaries.push_back(CurrentStreamPosition);
+
+ if (IsNew)
+ {
+ Stats.UniqueBytes += CurrentChunkSize;
+ }
+ else
+ {
+ // We've seen this chunk before
+ Stats.DuplicateBytes += CurrentChunkSize;
+ }
+
+ DataPointer += Boundary;
+ RemainBytes -= Boundary;
+ BytesToRead -= Boundary;
+ CurrentChunkSize = 0;
+ IoHashStream.Reset();
+ }
+
+ m_BufferManager.ReturnBuffer(Buffer);
+
+# if 0
+ Active.AddCount(); // needs fixing
+
+ Concurrency::create_task([this, Zfile, CurrentPosition, DataPointer, &Active] {
+ const zen::IoHash Hash = zen::IoHash::HashBuffer(DataPointer, CurrentPosition);
+
+ const bool isNew = m_LocalHashSet.Add(Hash);
+
+ const int CompressBufferSize = LZ4_compressBound(gsl::narrow<int>(CurrentPosition));
+ char* CompressBuffer = (char*)_aligned_malloc(CompressBufferSize, 16);
+
+ const int CompressedSize =
+ LZ4_compress_default((const char*)DataPointer, CompressBuffer, gsl::narrow<int>(CurrentPosition), CompressBufferSize);
+
+ m_TotalCompressed.local() += CompressedSize;
+
+ if (isNew)
+ {
+ m_UniqueBytes.local() += CurrentPosition;
+ m_UniqueCompressed.local() += CompressedSize;
+
+ if (m_CasStore)
+ {
+ const zen::IoHash CompressedHash = zen::IoHash::HashBuffer(CompressBuffer, CompressedSize);
+ m_CasStore->InsertChunk(CompressBuffer, CompressedSize, CompressedHash);
+ }
+ }
+
+ Active.Signal(); // needs fixing
+
+ _aligned_free(CompressBuffer);
+ });
+# endif
+ }
+
+ StatsBlock& Stats = m_StatsBlock.local();
+ Stats.TotalBytes += DataSize;
+ Stats.TotalChunks += Boundaries.size() + 1;
+
+ // TODO: Wait for all compression tasks
+
+ auto ChunkCount = Boundaries.size() + 1;
+
+ ZEN_INFO("Split {} ({}) into {} chunks, avg size {}",
+ RelativePath.generic_string(),
+ zen::NiceBytes(File.FileSize),
+ ChunkCount,
+ File.FileSize / ChunkCount);
+ };
+
+private:
+ HashSet m_LocalHashSet;
+ std::filesystem::path m_RootPath;
+ uint64_t m_AverageChunkSize = 32 * 1024;
+ bool m_UseCompression = true;
+ FileBufferManager m_BufferManager{128 * 1024, 128};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+ChunkCommand::ChunkCommand()
+{
+ m_Options.add_options()("r,root", "Root directory for CAS pool", cxxopts::value(m_RootDirectory));
+ m_Options.add_options()("d,dir", "Directory to scan", cxxopts::value(m_ScanDirectory));
+ m_Options.add_options()("c,chunk-size", "Use fixed chunk size", cxxopts::value(m_ChunkSize));
+ m_Options.add_options()("a,average-chunk-size", "Use dynamic chunk size", cxxopts::value(m_AverageChunkSize));
+ m_Options.add_options()("compress", "Apply compression to chunks", cxxopts::value(m_UseCompression));
+}
+
+ChunkCommand::~ChunkCommand() = default;
+
+int
+ChunkCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ bool IsValid = m_ScanDirectory.length();
+
+ if (!IsValid)
+ throw cxxopts::OptionParseException("Chunk command requires a directory to scan");
+
+ if ((m_ChunkSize && m_AverageChunkSize) && (!m_ChunkSize && !m_AverageChunkSize))
+ throw cxxopts::OptionParseException("Either of --chunk-size or --average-chunk-size must be used");
+
+ std::unique_ptr<zen::CasStore> CasStore;
+
+ zen::GcManager Gc;
+
+ if (!m_RootDirectory.empty())
+ {
+ zen::CasStoreConfiguration Config;
+ Config.RootDirectory = m_RootDirectory;
+
+ CasStore = zen::CreateCasStore(Gc);
+ CasStore->Initialize(Config);
+ }
+
+ // Gather list of files to process
+
+ ZEN_INFO("Gathering files from {}", m_ScanDirectory);
+
+ std::filesystem::path RootPath{m_ScanDirectory};
+ DirectoryScanner Scanner;
+ Scanner.Scan(RootPath);
+
+ auto Files = Scanner.TakeFiles();
+ uint64_t FileBytes = Scanner.FileBytes();
+
+ std::sort(begin(Files), end(Files), [](const DirectoryScanner::FileEntry& Lhs, const DirectoryScanner::FileEntry& Rhs) {
+ return Lhs.FileSize < Rhs.FileSize;
+ });
+
+ ZEN_INFO("Gathered {} files, total size {}", Files.size(), zen::NiceBytes(FileBytes));
+
+ auto ReportSummary = [&](BaseChunker& Chunker, uint64_t ElapsedMs) {
+ const BaseChunker::StatsBlock& Stats = Chunker.SumStats();
+
+ const size_t TotalChunkCount = Stats.TotalChunks;
+ ZEN_INFO("Scanned {} files in {}, generated {} chunks", Files.size(), zen::NiceTimeSpanMs(ElapsedMs), TotalChunkCount);
+
+ const size_t TotalByteCount = Stats.TotalBytes;
+ const size_t TotalCompressedBytes = Stats.TotalCompressed;
+
+ ZEN_INFO("Total bytes {} ({}), compresses into {}",
+ zen::NiceBytes(TotalByteCount),
+ zen::NiceByteRate(TotalByteCount, ElapsedMs),
+ zen::NiceBytes(TotalCompressedBytes));
+
+ const size_t TotalUniqueBytes = Stats.UniqueBytes;
+ const size_t TotalUniqueCompressedBytes = Stats.UniqueCompressed;
+ const size_t TotalDuplicateBytes = Stats.DuplicateBytes;
+
+ ZEN_INFO("Chunksize average {}, unique bytes = {} (compressed {}), dup bytes = {}",
+ TotalByteCount / TotalChunkCount,
+ zen::NiceBytes(TotalUniqueBytes),
+ zen::NiceBytes(TotalUniqueCompressedBytes),
+ zen::NiceBytes(TotalDuplicateBytes));
+
+ ZEN_INFO("New to CAS: {} chunks, {}", Stats.NewCasChunks, zen::NiceBytes(Stats.NewCasBytes));
+ };
+
+ // Process them as quickly as possible
+
+ if (m_AverageChunkSize)
+ {
+ VariableBlockSizeChunker Chunker{RootPath};
+ Chunker.SetAverageChunkSize(m_AverageChunkSize);
+ Chunker.SetUseCompression(m_UseCompression);
+ Chunker.SetCasStore(CasStore.get());
+
+ zen::Stopwatch timer;
+
+# if 1
+ Concurrency::parallel_for_each(begin(Files), end(Files), [&Chunker](const auto& ThisFile) { Chunker.ChunkFile(ThisFile); });
+# else
+ for (const auto& ThisFile : Files)
+ {
+ Chunker.ChunkFile(ThisFile);
+ }
+# endif
+
+ uint64_t ElapsedMs = timer.GetElapsedTimeMs();
+
+ ReportSummary(Chunker, ElapsedMs);
+ }
+ else if (m_ChunkSize)
+ {
+ FixedBlockSizeChunker Chunker{RootPath};
+ Chunker.SetChunkSize(m_ChunkSize);
+ Chunker.SetUseCompression(m_UseCompression);
+ Chunker.SetCasStore(CasStore.get());
+
+ zen::Stopwatch timer;
+
+ Concurrency::parallel_for_each(begin(Files), end(Files), [&Chunker](const DirectoryScanner::FileEntry& ThisFile) {
+ try
+ {
+ Chunker.ChunkFile(ThisFile);
+ }
+ catch (std::exception& ex)
+ {
+ zen::ExtendableStringBuilder<256> Path8;
+ zen::PathToUtf8(ThisFile.Path, Path8);
+ ZEN_WARN("Caught exception while chunking '{}': {}", Path8, ex.what());
+ }
+ });
+
+ uint64_t ElapsedMs = timer.GetElapsedTimeMs();
+
+ ReportSummary(Chunker, ElapsedMs);
+ }
+ else
+ {
+ ZEN_ASSERT(false);
+ }
+
+ // TODO: implement snapshot enumeration and display
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+# if ZEN_WITH_TESTS
+TEST_CASE("chunking")
+{
+ using namespace zen;
+
+ auto test = [](bool UseThreshold, bool Random, int MinBlockSize, int MaxBlockSize) {
+ std::mt19937_64 mt;
+
+ std::vector<uint64_t> bytes;
+ bytes.resize(1 * 1024 * 1024);
+
+ if (Random == false)
+ {
+ // Generate a single block of randomness
+ for (auto& w : bytes)
+ {
+ w = mt();
+ }
+ }
+
+ for (int i = MinBlockSize; i <= MaxBlockSize; i <<= 1)
+ {
+ Stopwatch timer;
+
+ ZenChunker chunker;
+ chunker.SetUseThreshold(UseThreshold);
+ chunker.SetChunkSize(0, 0, i);
+ // chunker.SetChunkSize(i / 4, i * 4, 0);
+ // chunker.SetChunkSize(i / 8, i * 8, 0);
+ // chunker.SetChunkSize(i / 16, i * 16, 0);
+ // chunker.SetChunkSize(0, 0, size_t(i / 0.75)); // Hits the fast modulo path
+
+ std::vector<size_t> boundaries;
+
+ size_t CurrentPosition = 0;
+ int BoundaryCount = 0;
+
+ do
+ {
+ if (Random == true)
+ {
+ // Generate a new block of randomness for each pass
+ for (auto& w : bytes)
+ {
+ w = mt();
+ }
+ }
+
+ const uint8_t* Ptr = reinterpret_cast<const uint8_t*>(bytes.data());
+ size_t BytesRemain = bytes.size() * sizeof(uint64_t);
+
+ for (;;)
+ {
+ const size_t Boundary = chunker.ScanChunk(Ptr, BytesRemain);
+
+ if (Boundary == ZenChunker::NoBoundaryFound)
+ {
+ CurrentPosition += BytesRemain;
+ break;
+ }
+
+ // Boundary found
+
+ CurrentPosition += Boundary;
+
+ CHECK(CurrentPosition >= chunker.ChunkSizeMin());
+ CHECK(CurrentPosition <= chunker.ChunkSizeMax());
+
+ boundaries.push_back(CurrentPosition);
+
+ CurrentPosition = 0;
+ Ptr += Boundary;
+ BytesRemain -= Boundary;
+
+ ++BoundaryCount;
+ }
+ } while (BoundaryCount < 5000);
+
+ size_t BoundarySum = 0;
+
+ for (const auto& v : boundaries)
+ {
+ BoundarySum += v;
+ }
+
+ double Avg = double(BoundarySum) / BoundaryCount;
+ const uint64_t ElapsedTimeMs = timer.GetElapsedTimeMs();
+
+ ZEN_INFO("{:9} : Avg {:9} - {:2.5} ({:6}, {})",
+ i,
+ Avg,
+ double(i / Avg),
+ NiceTimeSpanMs(ElapsedTimeMs),
+ NiceByteRate(chunker.BytesScanned(), ElapsedTimeMs));
+ }
+ };
+
+ const bool Random = false;
+
+ SUBCASE("threshold method") { test(/* UseThreshold */ true, /* Random */ Random, 2048, 1 * 1024 * 1024); }
+
+ SUBCASE("mod method") { test(/* UseThreshold */ false, /* Random */ Random, 2048, 1 * 1024 * 1024); }
+}
+# endif
+#endif
diff --git a/src/zen/chunk/chunk.h b/src/zen/chunk/chunk.h
new file mode 100644
index 000000000..e796f4147
--- /dev/null
+++ b/src/zen/chunk/chunk.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+#include <zencore/zencore.h>
+#include "../zen.h"
+
+#if 0
+class ChunkCommand : public ZenCmdBase
+{
+public:
+ ChunkCommand();
+ ~ChunkCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"chunk", "Do a chunking pass"};
+ std::string m_RootDirectory;
+ std::string m_ScanDirectory;
+ size_t m_ChunkSize = 0;
+ size_t m_AverageChunkSize = 0;
+ bool m_UseCompression = true;
+};
+#endif // 0
diff --git a/src/zen/cmds/cache.cpp b/src/zen/cmds/cache.cpp
new file mode 100644
index 000000000..495662d2f
--- /dev/null
+++ b/src/zen/cmds/cache.cpp
@@ -0,0 +1,275 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "cache.h"
+
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zenhttp/httpcommon.h>
+#include <zenutil/zenserverprocess.h>
+
+#include <memory>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+DropCommand::DropCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "n", "namespace", "Namespace name", cxxopts::value(m_NamespaceName), "<namespacename>");
+ m_Options.add_option("", "b", "bucket", "Bucket name", cxxopts::value(m_BucketName), "<bucketname>");
+ m_Options.parse_positional({"namespace", "bucket"});
+}
+
+DropCommand::~DropCommand() = default;
+
+int
+DropCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ if (m_NamespaceName.empty())
+ {
+ throw cxxopts::OptionParseException("Drop command requires a namespace");
+ }
+
+ cpr::Session Session;
+ if (m_BucketName.empty())
+ {
+ ZEN_CONSOLE("Dropping cache namespace '{}' from '{}'", m_NamespaceName, m_HostName);
+ Session.SetUrl({fmt::format("{}/z$/{}", m_HostName, m_NamespaceName)});
+ }
+ else
+ {
+ ZEN_CONSOLE("Dropping cache bucket '{}/{}' from '{}'", m_NamespaceName, m_BucketName, m_HostName);
+ Session.SetUrl({fmt::format("{}/z$/{}/{}", m_HostName, m_NamespaceName, m_BucketName)});
+ }
+
+ cpr::Response Result = Session.Delete();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("OK: drop succeeded");
+
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Drop failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Drop failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+CacheInfoCommand::CacheInfoCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "n", "namespace", "Namespace name", cxxopts::value(m_NamespaceName), "<namespacename>");
+ m_Options.add_option("", "b", "bucket", "Bucket name", cxxopts::value(m_BucketName), "<bucketname>");
+ m_Options.parse_positional({"namespace", "bucket"});
+}
+
+CacheInfoCommand::~CacheInfoCommand() = default;
+
+int
+CacheInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+ if (m_HostName.empty())
+ {
+ ZEN_CONSOLE("Info on cache from '{}'", m_HostName);
+ Session.SetUrl({fmt::format("{}/z$", m_HostName)});
+ }
+ else if (m_BucketName.empty())
+ {
+ ZEN_CONSOLE("Info on cache namespace '{}' from '{}'", m_NamespaceName, m_HostName);
+ Session.SetUrl({fmt::format("{}/z$/{}", m_HostName, m_NamespaceName)});
+ }
+ else
+ {
+ ZEN_CONSOLE("Info on cache bucket '{}/{}' from '{}'", m_NamespaceName, m_BucketName, m_HostName);
+ Session.SetUrl({fmt::format("{}/z$/{}/{}", m_HostName, m_NamespaceName, m_BucketName)});
+ }
+
+ cpr::Response Result = Session.Get();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("{}", Result.text);
+
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Info failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+CacheStatsCommand::CacheStatsCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+}
+
+CacheStatsCommand::~CacheStatsCommand() = default;
+
+int
+CacheStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetUrl({fmt::format("{}/stats/z$", m_HostName)});
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+
+ cpr::Response Result = Session.Get();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("{}", Result.text);
+
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Info failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+CacheDetailsCommand::CacheDetailsCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "c", "csv", "Info on csv format", cxxopts::value(m_CSV), "<csv>");
+ m_Options.add_option("", "d", "details", "Get detailed information about records", cxxopts::value(m_Details), "<details>");
+ m_Options.add_option("",
+ "a",
+ "attachmentdetails",
+ "Get detailed information about attachments",
+ cxxopts::value(m_AttachmentDetails),
+ "<attachmentdetails>");
+ m_Options.add_option("", "n", "namespace", "Namespace name to get info for", cxxopts::value(m_Namespace), "<namespace>");
+ m_Options.add_option("", "b", "bucket", "Filter on bucket name", cxxopts::value(m_Bucket), "<bucket>");
+ m_Options.add_option("", "v", "valuekey", "Filter on value key hash string", cxxopts::value(m_ValueKey), "<valuekey>");
+}
+
+CacheDetailsCommand::~CacheDetailsCommand() = default;
+
+int
+CacheDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ cpr::Parameters Parameters;
+ if (m_Details)
+ {
+ Parameters.Add({"details", "true"});
+ }
+ if (m_AttachmentDetails)
+ {
+ Parameters.Add({"attachmentdetails", "true"});
+ }
+ if (m_CSV)
+ {
+ Parameters.Add({"csv", "true"});
+ }
+ else
+ {
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+ }
+
+ if (!m_ValueKey.empty())
+ {
+ if (m_Namespace.empty() || m_Bucket.empty())
+ {
+ ZEN_ERROR("Provide namespace and bucket name");
+ ZEN_CONSOLE("{}", m_Options.help({""}).c_str());
+ return 1;
+ }
+ Session.SetUrl({fmt::format("{}/z$/details$/{}/{}/{}", m_HostName, m_Namespace, m_Bucket, m_ValueKey)});
+ }
+ else if (!m_Bucket.empty())
+ {
+ if (m_Namespace.empty())
+ {
+ ZEN_ERROR("Provide namespace name");
+ ZEN_CONSOLE("{}", m_Options.help({""}).c_str());
+ return 1;
+ }
+ Session.SetUrl({fmt::format("{}/z$/details$/{}/{}", m_HostName, m_Namespace, m_Bucket)});
+ }
+ else if (!m_Namespace.empty())
+ {
+ Session.SetUrl({fmt::format("{}/z$/details$/{}", m_HostName, m_Namespace)});
+ }
+ else
+ {
+ Session.SetUrl({fmt::format("{}/z$/details$", m_HostName)});
+ }
+ Session.SetParameters(Parameters);
+
+ cpr::Response Result = Session.Get();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("{}", Result.text);
+
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Info failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
diff --git a/src/zen/cmds/cache.h b/src/zen/cmds/cache.h
new file mode 100644
index 000000000..1f368bdec
--- /dev/null
+++ b/src/zen/cmds/cache.h
@@ -0,0 +1,68 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+class DropCommand : public ZenCmdBase
+{
+public:
+ DropCommand();
+ ~DropCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"drop", "Drop cache namespace or bucket"};
+ std::string m_HostName;
+ std::string m_NamespaceName;
+ std::string m_BucketName;
+};
+
+class CacheInfoCommand : public ZenCmdBase
+{
+public:
+ CacheInfoCommand();
+ ~CacheInfoCommand();
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"cache-info", "Info on cache, namespace or bucket"};
+ std::string m_HostName;
+ std::string m_NamespaceName;
+ std::string m_BucketName;
+};
+
+class CacheStatsCommand : public ZenCmdBase
+{
+public:
+ CacheStatsCommand();
+ ~CacheStatsCommand();
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"cache-stats", "Stats info on cache"};
+ std::string m_HostName;
+};
+
+class CacheDetailsCommand : public ZenCmdBase
+{
+public:
+ CacheDetailsCommand();
+ ~CacheDetailsCommand();
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"cache-details", "Detailed info on cache"};
+ std::string m_HostName;
+ bool m_CSV;
+ bool m_Details;
+ bool m_AttachmentDetails;
+ std::string m_Namespace;
+ std::string m_Bucket;
+ std::string m_ValueKey;
+};
diff --git a/src/zen/cmds/copy.cpp b/src/zen/cmds/copy.cpp
new file mode 100644
index 000000000..6f6c078d4
--- /dev/null
+++ b/src/zen/cmds/copy.cpp
@@ -0,0 +1,95 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "copy.h"
+
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zencore/timer.h>
+
+namespace zen {
+
+CopyCommand::CopyCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_options()("no-clone", "Do not perform block clone", cxxopts::value(m_NoClone)->default_value("false"));
+ m_Options.add_option("", "s", "source", "Copy source", cxxopts::value(m_CopySource), "<file/directory>");
+ m_Options.add_option("", "t", "target", "Copy target", cxxopts::value(m_CopyTarget), "<file/directory>");
+ m_Options.add_option("", "", "positional", "Positional arguments", cxxopts::value(m_Positional), "");
+ m_Options.parse_positional({"source", "target", "positional"});
+}
+
+CopyCommand::~CopyCommand() = default;
+
+int
+CopyCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ZenCmdBase::ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ // Validate arguments
+
+ if (m_CopySource.empty())
+ throw std::runtime_error("No source specified");
+
+ if (m_CopyTarget.empty())
+ throw std::runtime_error("No target specified");
+
+ std::filesystem::path FromPath;
+ std::filesystem::path ToPath;
+
+ FromPath = m_CopySource;
+ ToPath = m_CopyTarget;
+
+ const bool IsFileCopy = std::filesystem::is_regular_file(m_CopySource);
+ const bool IsDirCopy = std::filesystem::is_directory(m_CopySource);
+
+ if (!IsFileCopy && !IsDirCopy)
+ {
+ throw std::runtime_error("Invalid source specification (neither directory nor file)");
+ }
+
+ if (IsFileCopy && IsDirCopy)
+ {
+ throw std::runtime_error("Invalid source specification (both directory AND file!?)");
+ }
+
+ if (IsDirCopy)
+ {
+ if (std::filesystem::exists(ToPath))
+ {
+ const bool IsTargetDir = std::filesystem::is_directory(ToPath);
+ if (!IsTargetDir)
+ {
+ if (std::filesystem::is_regular_file(ToPath))
+ {
+ throw std::runtime_error("Attempted copy of directory into file");
+ }
+ }
+ }
+ else
+ {
+ std::filesystem::create_directories(ToPath);
+ }
+ }
+ else
+ {
+ // Single file copy
+
+ zen::Stopwatch Timer;
+
+ zen::CopyFileOptions CopyOptions;
+ CopyOptions.EnableClone = !m_NoClone;
+ zen::CopyFile(FromPath, ToPath, CopyOptions);
+
+ ZEN_CONSOLE("Copy completed in {}", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ }
+
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/copy.h b/src/zen/cmds/copy.h
new file mode 100644
index 000000000..5527ae9b8
--- /dev/null
+++ b/src/zen/cmds/copy.h
@@ -0,0 +1,28 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+/** Copy files, possibly using block cloning
+ */
+class CopyCommand : public ZenCmdBase
+{
+public:
+ CopyCommand();
+ ~CopyCommand();
+
+ virtual cxxopts::Options& Options() override { return m_Options; }
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+
+private:
+ cxxopts::Options m_Options{"copy", "Copy files"};
+ std::vector<std::string> m_Positional;
+ std::string m_CopySource;
+ std::string m_CopyTarget;
+ bool m_NoClone = false;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/dedup.cpp b/src/zen/cmds/dedup.cpp
new file mode 100644
index 000000000..b48fb8c2d
--- /dev/null
+++ b/src/zen/cmds/dedup.cpp
@@ -0,0 +1,302 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "dedup.h"
+
+#include <zencore/blake3.h>
+#include <zencore/filesystem.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <ppl.h>
+#endif
+
+#include <list>
+
+namespace zen {
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
+namespace Concurrency {
+
+ template<typename T0, typename T1>
+ inline void parallel_invoke(T0 const& t0, T1 const& t1)
+ {
+ t0();
+ t1();
+ }
+
+} // namespace Concurrency
+
+#endif // ZEN_PLATFORM_LINUX/MAC
+
+////////////////////////////////////////////////////////////////////////////////
+
+DedupCommand::DedupCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_options()("size", "Configure size threshold for dedup", cxxopts::value(m_SizeThreshold)->default_value("131072"));
+ m_Options.add_option("", "s", "source", "Copy source", cxxopts::value(m_DedupSource), "<file/directory>");
+ m_Options.add_option("", "t", "target", "Copy target", cxxopts::value(m_DedupTarget), "<file/directory>");
+ m_Options.add_option("", "", "positional", "Positional arguments", cxxopts::value(m_Positional), "");
+ m_Options.parse_positional({"source", "target", "positional"});
+}
+
+DedupCommand::~DedupCommand() = default;
+
+int
+DedupCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ // Validate arguments
+
+ const bool SourceGood = zen::SupportsBlockRefCounting(m_DedupSource);
+ const bool TargetGood = zen::SupportsBlockRefCounting(m_DedupTarget);
+
+ if (!SourceGood)
+ {
+ ZEN_ERROR("Source directory '{}' does not support deduplication", m_DedupSource);
+
+ return 0;
+ }
+
+ if (!TargetGood)
+ {
+ ZEN_ERROR("Target directory '{}' does not support deduplication", m_DedupTarget);
+
+ return 0;
+ }
+
+ ZEN_CONSOLE("Performing dedup operation between {} and {}, size threshold {}",
+ m_DedupSource,
+ m_DedupTarget,
+ zen::NiceBytes(m_SizeThreshold));
+
+ using DirEntryList_t = std::list<std::filesystem::directory_entry>;
+
+ zen::RwLock MapLock;
+ std::unordered_map<size_t, DirEntryList_t> FileSizeMap;
+ size_t CandidateCount = 0;
+
+ auto AddToList = [&](const std::filesystem::directory_entry& Entry) {
+ if (Entry.is_regular_file())
+ {
+ uintmax_t FileSize = Entry.file_size();
+ if (FileSize > m_SizeThreshold)
+ {
+ zen::RwLock::ExclusiveLockScope _(MapLock);
+ FileSizeMap[FileSize].push_back(Entry);
+ ++CandidateCount;
+ }
+ }
+ };
+
+ std::filesystem::recursive_directory_iterator DirEnd;
+
+ ZEN_CONSOLE("Gathering file info from source: '{}'", m_DedupSource);
+ ZEN_CONSOLE("Gathering file info from target: '{}'", m_DedupTarget);
+
+ {
+ zen::Stopwatch Timer;
+
+ Concurrency::parallel_invoke(
+ [&] {
+ for (std::filesystem::recursive_directory_iterator DirIt1(m_DedupSource); DirIt1 != DirEnd; ++DirIt1)
+ {
+ AddToList(*DirIt1);
+ }
+ },
+ [&] {
+ for (std::filesystem::recursive_directory_iterator DirIt2(m_DedupTarget); DirIt2 != DirEnd; ++DirIt2)
+ {
+ AddToList(*DirIt2);
+ }
+ });
+
+ ZEN_CONSOLE("Gathered {} candidates across {} size buckets. Elapsed: {}",
+ CandidateCount,
+ FileSizeMap.size(),
+ zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ }
+
+ ZEN_CONSOLE("Sorting buckets by size");
+
+ zen::Stopwatch Timer;
+
+ uint64_t DupeBytes = 0;
+
+ struct SizeList
+ {
+ size_t Size;
+ DirEntryList_t* DirEntries;
+ };
+
+ std::vector<SizeList> SizeLists{FileSizeMap.size()};
+
+ {
+ int i = 0;
+
+ for (auto& kv : FileSizeMap)
+ {
+ ZEN_ASSERT(kv.first >= m_SizeThreshold);
+ SizeLists[i].Size = kv.first;
+ SizeLists[i].DirEntries = &kv.second;
+ ++i;
+ }
+ }
+
+ std::sort(begin(SizeLists), end(SizeLists), [](const SizeList& Lhs, const SizeList& Rhs) { return Lhs.Size > Rhs.Size; });
+
+ ZEN_CONSOLE("Bucket summary:");
+
+ std::vector<size_t> BucketId;
+ std::vector<size_t> BucketOffsets;
+ std::vector<size_t> BucketSizes;
+ std::vector<size_t> BucketFileCounts;
+
+ size_t TotalFileSizes = 0;
+ size_t TotalFileCount = 0;
+
+ {
+ size_t CurrentPow2 = 0;
+ size_t BucketSize = 0;
+ size_t BucketFileCount = 0;
+ bool FirstBucket = true;
+
+ for (size_t i = 0; i < SizeLists.size(); ++i)
+ {
+ const size_t ThisSize = SizeLists[i].Size;
+ const size_t Pow2 = zen::NextPow2(ThisSize);
+
+ if (CurrentPow2 != Pow2)
+ {
+ CurrentPow2 = Pow2;
+
+ if (!FirstBucket)
+ {
+ BucketSizes.push_back(BucketSize);
+ BucketFileCounts.push_back(BucketFileCount);
+ }
+
+ BucketId.push_back(Pow2);
+ BucketOffsets.push_back(i);
+
+ FirstBucket = false;
+ BucketSize = 0;
+ BucketFileCount = 0;
+ }
+
+ BucketSize += ThisSize;
+ TotalFileSizes += ThisSize;
+ BucketFileCount += SizeLists[i].DirEntries->size();
+ TotalFileCount += SizeLists[i].DirEntries->size();
+ }
+
+ if (!FirstBucket)
+ {
+ BucketSizes.push_back(BucketSize);
+ BucketFileCounts.push_back(BucketFileCount);
+ }
+
+ ZEN_ASSERT(BucketOffsets.size() == BucketSizes.size());
+ ZEN_ASSERT(BucketOffsets.size() == BucketFileCounts.size());
+ }
+
+ for (size_t i = 0; i < BucketOffsets.size(); ++i)
+ {
+ ZEN_CONSOLE(" Bucket {} : {}, {} candidates", zen::NiceBytes(BucketId[i]), zen::NiceBytes(BucketSizes[i]), BucketFileCounts[i]);
+ }
+
+ ZEN_CONSOLE("Total : {}, {} candidates", zen::NiceBytes(TotalFileSizes), TotalFileCount);
+
+ std::string CurrentNice;
+
+ for (SizeList& Size : SizeLists)
+ {
+ std::string CurNice{zen::NiceBytes(zen::NextPow2(Size.Size))};
+
+ if (CurNice != CurrentNice)
+ {
+ CurrentNice = CurNice;
+ ZEN_CONSOLE("Now scanning bucket: {}", CurrentNice);
+ }
+
+ std::unordered_map<zen::BLAKE3, const std::filesystem::directory_entry*, zen::BLAKE3::Hasher> DedupMap;
+
+ for (const auto& Entry : *Size.DirEntries)
+ {
+ zen::BLAKE3 Hash;
+
+ if constexpr (true)
+ {
+ zen::BLAKE3Stream b3s;
+
+ zen::ScanFile(Entry.path(), 64 * 1024, [&](const void* Data, size_t Size) { b3s.Append(Data, Size); });
+
+ Hash = b3s.GetHash();
+ }
+ else
+ {
+ zen::FileContents Contents = zen::ReadFile(Entry.path());
+
+ zen::BLAKE3Stream b3s;
+
+ for (zen::IoBuffer& Buffer : Contents.Data)
+ {
+ b3s.Append(Buffer.Data(), Buffer.Size());
+ }
+ Hash = b3s.GetHash();
+ }
+
+ if (const std::filesystem::directory_entry* Dupe = DedupMap[Hash])
+ {
+ std::string FileA = PathToUtf8(Dupe->path());
+ std::string FileB = PathToUtf8(Entry.path());
+
+ size_t MinLen = std::min(FileA.size(), FileB.size());
+ auto Its = std::mismatch(FileB.rbegin(), FileB.rbegin() + MinLen, FileA.rbegin());
+
+ if (Its.first != FileB.rbegin())
+ {
+ if (Its.first[-1] == '\\' || Its.first[-1] == '/')
+ --Its.first;
+
+ FileB = std::string(FileB.begin(), Its.first.base()) + "...";
+ }
+
+ ZEN_INFO("{} {} <-> {}", zen::NiceBytes(Entry.file_size()).c_str(), FileA.c_str(), FileB.c_str());
+
+ zen::CopyFileOptions Options;
+ Options.EnableClone = true;
+ Options.MustClone = true;
+
+ zen::CopyFile(Dupe->path(), Entry.path(), Options);
+
+ DupeBytes += Entry.file_size();
+ }
+ else
+ {
+ DedupMap[Hash] = &Entry;
+ }
+ }
+
+ Size.DirEntries->clear();
+ }
+
+ ZEN_CONSOLE("Elapsed: {} Deduped: {}", zen::NiceTimeSpanMs(Timer.GetElapsedTimeMs()), zen::NiceBytes(DupeBytes));
+
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/dedup.h b/src/zen/cmds/dedup.h
new file mode 100644
index 000000000..6318704f5
--- /dev/null
+++ b/src/zen/cmds/dedup.h
@@ -0,0 +1,28 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+/** Deduplicate files in a tree using block cloning
+ */
+class DedupCommand : public ZenCmdBase
+{
+public:
+ DedupCommand();
+ ~DedupCommand();
+
+ virtual cxxopts::Options& Options() override { return m_Options; }
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+
+private:
+ cxxopts::Options m_Options{"dedup", "Deduplicate files"};
+ std::vector<std::string> m_Positional;
+ std::string m_DedupSource;
+ std::string m_DedupTarget;
+ size_t m_SizeThreshold = 1024 * 1024;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/hash.cpp b/src/zen/cmds/hash.cpp
new file mode 100644
index 000000000..7987d7738
--- /dev/null
+++ b/src/zen/cmds/hash.cpp
@@ -0,0 +1,171 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hash.h"
+
+#include <zencore/blake3.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zencore/timer.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <ppl.h>
+#endif
+
+namespace zen {
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+
+namespace Concurrency {
+
+ template<typename IterType, typename LambdaType>
+ void parallel_for_each(IterType Cursor, IterType End, const LambdaType& Lambda)
+ {
+ for (; Cursor < End; ++Cursor)
+ {
+ Lambda(*Cursor);
+ }
+ }
+
+ template<typename T>
+ struct combinable
+ {
+ combinable<T>& local() { return *this; }
+
+ void operator+=(T Rhs) { Value += Rhs; }
+
+ template<typename LambdaType>
+ void combine_each(const LambdaType& Lambda)
+ {
+ Lambda(Value);
+ }
+
+ T Value = 0;
+ };
+
+} // namespace Concurrency
+
+#endif // ZEN_PLATFORM_LINUX|MAC
+
+////////////////////////////////////////////////////////////////////////////////
+
+HashCommand::HashCommand()
+{
+ m_Options.add_options()("d,dir", "Directory to scan", cxxopts::value<std::string>(m_ScanDirectory))(
+ "o,output",
+ "Output file",
+ cxxopts::value<std::string>(m_OutputFile));
+}
+
+HashCommand::~HashCommand() = default;
+
+int
+HashCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ bool valid = m_ScanDirectory.length();
+
+ if (!valid)
+ throw cxxopts::OptionParseException("Hash command requires a directory to scan");
+
+ // Gather list of files to process
+
+ ZEN_CONSOLE("Gathering files from {}", m_ScanDirectory);
+
+ struct FileEntry
+ {
+ std::filesystem::path FilePath;
+ zen::BLAKE3 FileHash;
+ };
+
+ std::vector<FileEntry> FileList;
+ uint64_t FileBytes = 0;
+
+ std::filesystem::path ScanDirectoryPath{m_ScanDirectory};
+
+ for (const std::filesystem::directory_entry& Entry : std::filesystem::recursive_directory_iterator(ScanDirectoryPath))
+ {
+ if (Entry.is_regular_file())
+ {
+ FileList.push_back({Entry.path()});
+ FileBytes += Entry.file_size();
+ }
+ }
+
+ ZEN_CONSOLE("Gathered {} files, total size {}", FileList.size(), zen::NiceBytes(FileBytes));
+
+ Concurrency::combinable<uint64_t> TotalBytes;
+
+ auto hashFile = [&](FileEntry& File) {
+ InternalFile InputFile;
+ InputFile.OpenRead(File.FilePath);
+ const uint8_t* DataPointer = (const uint8_t*)InputFile.MemoryMapFile();
+ const size_t DataSize = InputFile.GetFileSize();
+
+ File.FileHash = zen::BLAKE3::HashMemory(DataPointer, DataSize);
+
+ TotalBytes.local() += DataSize;
+ };
+
+ // Process them as quickly as possible
+
+ zen::Stopwatch Timer;
+
+#if 1
+ Concurrency::parallel_for_each(begin(FileList), end(FileList), [&](auto& file) { hashFile(file); });
+#else
+ for (const auto& file : FileList)
+ {
+ hashFile(file);
+ }
+#endif
+
+ size_t TotalByteCount = 0;
+
+ TotalBytes.combine_each([&](size_t Total) { TotalByteCount += Total; });
+
+ const uint64_t ElapsedMs = Timer.GetElapsedTimeMs();
+ ZEN_CONSOLE("Scanned {} files in {}", FileList.size(), zen::NiceTimeSpanMs(ElapsedMs));
+ ZEN_CONSOLE("Total bytes {} ({})", zen::NiceBytes(TotalByteCount), zen::NiceByteRate(TotalByteCount, ElapsedMs));
+
+ InternalFile Output;
+
+ if (m_OutputFile.empty())
+ {
+ // TEMPORARY -- should properly open stdout
+ Output.OpenWrite("CONOUT$", false);
+ }
+ else
+ {
+ Output.OpenWrite(m_OutputFile, true);
+ }
+
+ zen::ExtendableStringBuilder<256> Line;
+
+ uint64_t CurrentOffset = 0;
+
+ for (const auto& File : FileList)
+ {
+ Line.Append(File.FilePath.generic_u8string().c_str());
+ Line.Append(',');
+ File.FileHash.ToHexString(Line);
+ Line.Append('\n');
+
+ Output.Write(Line.Data(), Line.Size(), CurrentOffset);
+ CurrentOffset += Line.Size();
+
+ Line.Reset();
+ }
+
+ // TODO: implement snapshot enumeration and display
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/hash.h b/src/zen/cmds/hash.h
new file mode 100644
index 000000000..e5ee071e9
--- /dev/null
+++ b/src/zen/cmds/hash.h
@@ -0,0 +1,27 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../internalfile.h"
+#include "../zen.h"
+
+namespace zen {
+
+/** Generate hash list file
+ */
+class HashCommand : public ZenCmdBase
+{
+public:
+ HashCommand();
+ ~HashCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"hash", "Hash files"};
+ std::string m_ScanDirectory;
+ std::string m_OutputFile;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/print.cpp b/src/zen/cmds/print.cpp
new file mode 100644
index 000000000..67191605c
--- /dev/null
+++ b/src/zen/cmds/print.cpp
@@ -0,0 +1,193 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "print.h"
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zenhttp/httpshared.h>
+
+using namespace std::literals;
+
+namespace zen {
+
+static void
+PrintCbObject(CbObject Object)
+{
+ zen::StringBuilder<1024> ObjStr;
+ zen::CompactBinaryToJson(Object, ObjStr);
+ ZEN_CONSOLE("{}", ObjStr);
+}
+
+static void
+PrintCbObject(IoBuffer Data)
+{
+ zen::CbObject Object{SharedBuffer(Data)};
+
+ PrintCbObject(Object);
+}
+
+PrintCommand::PrintCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "s", "source", "Object payload file", cxxopts::value(m_Filename), "<file name>");
+ m_Options.parse_positional({"source"});
+}
+
+PrintCommand::~PrintCommand() = default;
+
+int
+PrintCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ // Validate arguments
+
+ if (m_Filename.empty())
+ throw std::runtime_error("No file specified");
+
+ zen::FileContents Fc;
+
+ if (m_Filename == "-")
+ {
+ Fc = zen::ReadStdIn();
+ }
+ else
+ {
+ Fc = zen::ReadFile(m_Filename);
+ }
+
+ if (Fc.ErrorCode)
+ {
+ ZEN_ERROR("Failed to read file '{}': {}", m_Filename, Fc.ErrorCode.message());
+
+ return 1;
+ }
+
+ IoBuffer Data = Fc.Flatten();
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(Data, RawHash, RawSize))
+ {
+ ZEN_CONSOLE("Compressed binary: size {}, raw size {}, hash: {}", Data.GetSize(), RawSize, RawHash);
+ }
+ else if (IsPackageMessage(Data))
+ {
+ CbPackage Package = ParsePackageMessage(Data);
+
+ CbObject Object = Package.GetObject();
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ ZEN_CONSOLE("Package - {} attachments, object hash {}", Package.GetAttachments().size(), Package.GetObjectHash());
+ ZEN_CONSOLE("");
+
+ int AttachmentIndex = 1;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ std::string AttachmentSize = "n/a";
+ const char* AttachmentType = "unknown";
+
+ if (Attachment.IsCompressedBinary())
+ {
+ AttachmentType = "Compressed";
+ AttachmentSize = fmt::format("{} ({} uncompressed)",
+ Attachment.AsCompressedBinary().GetCompressedSize(),
+ Attachment.AsCompressedBinary().DecodeRawSize());
+ }
+ else if (Attachment.IsBinary())
+ {
+ AttachmentType = "Binary";
+ AttachmentSize = fmt::format("{}", Attachment.AsBinary().GetSize());
+ }
+ else if (Attachment.IsObject())
+ {
+ AttachmentType = "Object";
+ AttachmentSize = fmt::format("{}", Attachment.AsObject().GetSize());
+ }
+ else if (Attachment.IsNull())
+ {
+ AttachmentType = "null";
+ }
+
+ ZEN_CONSOLE("Attachment #{} : {}, {}, size {}", AttachmentIndex, Attachment.GetHash(), AttachmentType, AttachmentSize);
+
+ ++AttachmentIndex;
+ }
+
+ ZEN_CONSOLE("---8<---");
+
+ PrintCbObject(Object);
+ }
+ else if (CbValidateError Result = ValidateCompactBinary(Data, CbValidateMode::All); Result == CbValidateError::None)
+ {
+ PrintCbObject(Data);
+ }
+ else
+ {
+ ZEN_ERROR("Data in file '{}' does not appear to be compact binary (validation error {:#x})", m_Filename, uint32_t(Result));
+
+ return 1;
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+PrintPackageCommand::PrintPackageCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "s", "source", "Package payload file", cxxopts::value(m_Filename), "<file name>");
+ m_Options.parse_positional({"source"});
+}
+
+PrintPackageCommand::~PrintPackageCommand()
+{
+}
+
+int
+PrintPackageCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ // Validate arguments
+
+ if (m_Filename.empty())
+ throw std::runtime_error("No file specified");
+
+ zen::FileContents Fc = zen::ReadFile(m_Filename);
+ IoBuffer Data = Fc.Flatten();
+ zen::CbPackage Package;
+
+ bool Ok = Package.TryLoad(Data) || zen::legacy::TryLoadCbPackage(Package, Data, &UniqueBuffer::Alloc);
+
+ if (Ok)
+ {
+ zen::StringBuilder<1024> ObjStr;
+ zen::CompactBinaryToJson(Package.GetObject(), ObjStr);
+ ZEN_CONSOLE("{}", ObjStr);
+ }
+ else
+ {
+ ZEN_ERROR("error: malformed package?");
+ }
+
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/print.h b/src/zen/cmds/print.h
new file mode 100644
index 000000000..09d91830a
--- /dev/null
+++ b/src/zen/cmds/print.h
@@ -0,0 +1,41 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+/** Print Compact Binary
+ */
+class PrintCommand : public ZenCmdBase
+{
+public:
+ PrintCommand();
+ ~PrintCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"print", "Print compact binary object"};
+ std::string m_Filename;
+};
+
+/** Print Compact Binary Package
+ */
+class PrintPackageCommand : public ZenCmdBase
+{
+public:
+ PrintPackageCommand();
+ ~PrintPackageCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"printpkg", "Print compact binary package"};
+ std::string m_Filename;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/projectstore.cpp b/src/zen/cmds/projectstore.cpp
new file mode 100644
index 000000000..fe0dd713e
--- /dev/null
+++ b/src/zen/cmds/projectstore.cpp
@@ -0,0 +1,930 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "projectstore.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/logging.h>
+#include <zencore/stream.h>
+#include <zenhttp/httpcommon.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace {
+
+using namespace std::literals;
+
+const std::string DefaultCloudAccessTokenEnvVariableName(
+#if ZEN_PLATFORM_WINDOWS
+ "UE-CloudDataCacheAccessToken"sv
+#endif
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ "UE_CloudDataCacheAccessToken"sv
+#endif
+);
+
+} // namespace
+
+///////////////////////////////////////
+
+DropProjectCommand::DropProjectCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>");
+ m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>");
+ m_Options.parse_positional({"project", "oplog"});
+}
+
+DropProjectCommand::~DropProjectCommand() = default;
+
+int
+DropProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ if (m_ProjectName.empty())
+ {
+ throw cxxopts::OptionParseException("Drop command requires a project");
+ }
+
+ cpr::Session Session;
+ if (m_OplogName.empty())
+ {
+ ZEN_CONSOLE("Dropping project '{}' from '{}'", m_ProjectName, m_HostName);
+ Session.SetUrl({fmt::format("{}/prj/{}", m_HostName, m_ProjectName)});
+ }
+ else
+ {
+ ZEN_CONSOLE("Dropping oplog '{}/{}' from '{}'", m_ProjectName, m_OplogName, m_HostName);
+ Session.SetUrl({fmt::format("{}/prj/{}/oplog/{}", m_HostName, m_ProjectName, m_OplogName)});
+ }
+
+ cpr::Response Result = Session.Delete();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("OK: drop succeeded");
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Drop failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Drop failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+///////////////////////////////////////
+
+ProjectInfoCommand::ProjectInfoCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>");
+ m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>");
+ m_Options.parse_positional({"project", "oplog"});
+}
+
+ProjectInfoCommand::~ProjectInfoCommand() = default;
+
+int
+ProjectInfoCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+ if (m_ProjectName.empty())
+ {
+ ZEN_CONSOLE("Info from '{}'", m_HostName);
+ Session.SetUrl({fmt::format("{}/prj", m_HostName)});
+ }
+ else if (m_OplogName.empty())
+ {
+ ZEN_CONSOLE("Info on project '{}' from '{}'", m_ProjectName, m_HostName);
+ Session.SetUrl({fmt::format("{}/prj/{}", m_HostName, m_ProjectName)});
+ }
+ else
+ {
+ ZEN_CONSOLE("Info on oplog '{}/{}' from '{}'", m_ProjectName, m_OplogName, m_HostName);
+ Session.SetUrl({fmt::format("{}/prj/{}/oplog/{}", m_HostName, m_ProjectName, m_OplogName)});
+ }
+
+ cpr::Response Result = Session.Get();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("{}", Result.text);
+
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Info failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+///////////////////////////////////////
+
+CreateProjectCommand::CreateProjectCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectId), "<projectid>");
+ m_Options.add_option("", "", "rootdir", "Absolute path to root directory", cxxopts::value(m_RootDir), "<root>");
+ m_Options.add_option("", "", "enginedir", "Absolute path to engine root directory", cxxopts::value(m_EngineRootDir), "<engineroot>");
+ m_Options.add_option("", "", "projectdir", "Absolute path to project directory", cxxopts::value(m_ProjectRootDir), "<projectroot>");
+ m_Options.add_option("", "", "projectfile", "Absolute path to .uproject file", cxxopts::value(m_ProjectFile), "<projectfile>");
+ m_Options.parse_positional({"project", "rootdir", "enginedir", "projectdir", "projectfile"});
+}
+
+CreateProjectCommand::~CreateProjectCommand() = default;
+
+int
+CreateProjectCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ using namespace std::literals;
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+
+ if (m_ProjectId.empty())
+ {
+ ZEN_ERROR("Project name must be given");
+ return 1;
+ }
+
+ Session.SetUrl({fmt::format("{}/prj/{}", m_HostName, m_ProjectId)});
+ cpr::Response Response = Session.Get();
+ if (zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("Project already exists.\n{}", Response.text);
+ return 1;
+ }
+
+ if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound))
+ {
+ zen::CbObjectWriter Project;
+ Project.AddString("id"sv, m_ProjectId);
+ Project.AddString("root"sv, m_RootDir);
+ Project.AddString("engine"sv, m_EngineRootDir);
+ Project.AddString("project"sv, m_ProjectRootDir);
+ Project.AddString("projectfile"sv, m_ProjectFile);
+ zen::IoBuffer ProjectPayload = Project.Save().GetBuffer().AsIoBuffer();
+ Session.SetBody(cpr::Body{(const char*)ProjectPayload.GetData(), ProjectPayload.GetSize()});
+ Session.SetHeader(cpr::Header{{"Accept", "text"}});
+ Response = Session.Post();
+ }
+
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+}
+
+///////////////////////////////////////
+
+CreateOplogCommand::CreateOplogCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectId), "<projectid>");
+ m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogId), "<oplogid>");
+ m_Options.add_option("", "", "gcpath", "Absolute path to oplog lifetime marker file", cxxopts::value(m_GcPath), "<path>");
+ m_Options.parse_positional({"project", "oplog", "gcpath"});
+}
+
+CreateOplogCommand::~CreateOplogCommand() = default;
+
+int
+CreateOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ using namespace std::literals;
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+
+ if (m_ProjectId.empty())
+ {
+ ZEN_ERROR("Project name must be given");
+ return 1;
+ }
+
+ if (m_OplogId.empty())
+ {
+ ZEN_ERROR("Oplog name must be given");
+ return 1;
+ }
+
+ Session.SetUrl({fmt::format("{}/prj/{}/oplog/{}", m_HostName, m_ProjectId, m_OplogId)});
+ cpr::Response Response = Session.Get();
+ if (zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("Oplog already exists.\n{}", Response.text);
+ return 1;
+ }
+
+ if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound))
+ {
+ Session.SetHeader(cpr::Header{{"Accept", "text"}});
+ if (!m_GcPath.empty())
+ {
+ zen::CbObjectWriter Oplog;
+ Oplog.AddString("gcpath"sv, m_GcPath);
+ zen::IoBuffer OplogPayload = Oplog.Save().GetBuffer().AsIoBuffer();
+ Session.SetBody(cpr::Body{(const char*)OplogPayload.GetData(), OplogPayload.GetSize()});
+ Session.SetHeader(cpr::Header{{"Accept", "text"}, {"Content-Type", std::string(ToString(zen::HttpContentType::kCbObject))}});
+ }
+
+ Response = Session.Post();
+ }
+
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+
+ return MapHttpToCommandReturnCode(Response);
+}
+
+///////////////////////////////////////
+
+ExportOplogCommand::ExportOplogCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>");
+ m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>");
+ m_Options.add_option("", "", "maxblocksize", "Max size for bundled attachments", cxxopts::value(m_MaxBlockSize), "<blocksize>");
+ m_Options.add_option("",
+ "",
+ "maxchunkembedsize",
+ "Max size for attachment to be bundled",
+ cxxopts::value(m_MaxChunkEmbedSize),
+ "<chunksize>");
+ m_Options.add_option("", "f", "force", "Force export of all attachments", cxxopts::value(m_Force), "<force>");
+ m_Options.add_option("",
+ "",
+ "disableblocks",
+ "Disable block creation and save all attachments individually (applies to file and cloud target)",
+ cxxopts::value(m_DisableBlocks),
+ "<disable>");
+
+ m_Options.add_option("", "", "cloud", "Cloud Storage URL", cxxopts::value(m_CloudUrl), "<url>");
+ m_Options.add_option("cloud", "", "namespace", "Cloud Storage namespace", cxxopts::value(m_CloudNamespace), "<namespace>");
+ m_Options.add_option("cloud", "", "bucket", "Cloud Storage bucket", cxxopts::value(m_CloudBucket), "<bucket>");
+ m_Options.add_option("cloud", "", "key", "Cloud Storage key", cxxopts::value(m_CloudKey), "<key>");
+ m_Options
+ .add_option("cloud", "", "openid-provider", "Cloud Storage openid provider", cxxopts::value(m_CloudOpenIdProvider), "<provider>");
+ m_Options.add_option("cloud", "", "access-token", "Cloud Storage access token", cxxopts::value(m_CloudAccessToken), "<accesstoken>");
+ m_Options.add_option("cloud",
+ "",
+ "access-token-env",
+ "Name of environment variable that holds the cloud Storage access token",
+ cxxopts::value(m_CloudAccessTokenEnv)->default_value(DefaultCloudAccessTokenEnvVariableName),
+ "<envvariable>");
+ m_Options.add_option("cloud",
+ "",
+ "disabletempblocks",
+ "Disable temp block creation and upload blocks without waiting for oplog container to be uploaded",
+ cxxopts::value(m_CloudDisableTempBlocks),
+ "<disable>");
+
+ m_Options.add_option("", "", "zen", "Zen service upload address", cxxopts::value(m_ZenUrl), "<url>");
+ m_Options.add_option("zen", "", "target-project", "Zen target project name", cxxopts::value(m_ZenProjectName), "<targetprojectid>");
+ m_Options.add_option("zen", "", "target-oplog", "Zen target oplog name", cxxopts::value(m_ZenOplogName), "<targetoplogid>");
+ m_Options.add_option("zen", "", "clean", "Delete existing target Zen oplog", cxxopts::value(m_ZenClean), "<clean>");
+
+ m_Options.add_option("", "", "file", "Local folder path", cxxopts::value(m_FileDirectoryPath), "<path>");
+ m_Options.add_option("file", "", "name", "Local file name", cxxopts::value(m_FileName), "<filename>");
+ m_Options.add_option("file",
+ "",
+ "forcetempblocks",
+ "Force creation of temp attachment blocks",
+ cxxopts::value(m_FileForceEnableTempBlocks),
+ "<forcetempblocks>");
+
+ m_Options.parse_positional({"project", "oplog"});
+}
+
+ExportOplogCommand::~ExportOplogCommand() = default;
+
+int
+ExportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ using namespace std::literals;
+
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ if (m_ProjectName.empty())
+ {
+ ZEN_ERROR("Project name must be given");
+ return 1;
+ }
+
+ if (m_OplogName.empty())
+ {
+ ZEN_ERROR("Oplog name must be given");
+ return 1;
+ }
+
+ size_t TargetCount = 0;
+ TargetCount += m_CloudUrl.empty() ? 0 : 1;
+ TargetCount += m_ZenUrl.empty() ? 0 : 1;
+ TargetCount += m_FileDirectoryPath.empty() ? 0 : 1;
+ if (TargetCount != 1)
+ {
+ ZEN_ERROR("Provide one target only");
+ ZEN_CONSOLE("{}", m_Options.help({""}).c_str());
+ return 1;
+ }
+
+ cpr::Session Session;
+
+ if (!m_CloudUrl.empty())
+ {
+ if (m_CloudNamespace.empty() || m_CloudBucket.empty())
+ {
+ ZEN_ERROR("Options for cloud target are missing");
+ ZEN_CONSOLE("{}", m_Options.help({"cloud"}).c_str());
+ return 1;
+ }
+ if (m_CloudKey.empty())
+ {
+ std::string KeyString = fmt::format("{}/{}/{}/{}", m_ProjectName, m_OplogName, m_CloudNamespace, m_CloudBucket);
+ zen::IoHash Key = zen::IoHash::HashBuffer(KeyString.data(), KeyString.size());
+ m_CloudKey = Key.ToHexString();
+ ZEN_WARN("Using auto generated cloud key '{}'", m_CloudKey);
+ }
+ }
+
+ if (!m_ZenUrl.empty())
+ {
+ if (m_ZenProjectName.empty())
+ {
+ m_ZenProjectName = m_ProjectName;
+ ZEN_WARN("Using default zen target project id '{}'", m_ZenProjectName);
+ }
+ if (m_ZenOplogName.empty())
+ {
+ m_ZenOplogName = m_OplogName;
+ ZEN_WARN("Using default zen target oplog id '{}'", m_ZenOplogName);
+ }
+
+ std::string TargetUrlBase = fmt::format("{}/prj", m_ZenUrl);
+ if (TargetUrlBase.find("://") == std::string::npos)
+ {
+ // Assume https URL
+ TargetUrlBase = fmt::format("http://{}", TargetUrlBase);
+ }
+
+ Session.SetUrl({fmt::format("{}/{}/oplog/{}", TargetUrlBase, m_ZenProjectName, m_ZenOplogName)});
+ cpr::Response Response = Session.Get();
+ if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound))
+ {
+ ZEN_WARN("Automatically creating oplog '{}/{}'", m_ZenProjectName, m_ZenOplogName)
+ Response = Session.Post();
+ if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ }
+ else if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ else if (m_ZenClean)
+ {
+ ZEN_WARN("Cleaning oplog '{}/{}'", m_ZenProjectName, m_ZenOplogName)
+ Response = Session.Delete();
+ if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ Response = Session.Post();
+ if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ }
+ }
+
+ if (!m_FileDirectoryPath.empty())
+ {
+ if (m_FileName.empty())
+ {
+ m_FileName = m_OplogName;
+ ZEN_WARN("Using default file name '{}'", m_FileName);
+ }
+ }
+
+ const std::string SourceUrlBase = fmt::format("{}/prj", m_HostName);
+ std::string TargetDescription;
+ Session.SetUrl({fmt::format("{}/{}/oplog/{}/rpc", SourceUrlBase, m_ProjectName, m_OplogName)});
+ Session.SetHeader({{"Content-Type", std::string(zen::MapContentTypeToString(zen::HttpContentType::kCbObject))}});
+ zen::CbObjectWriter Writer;
+ Writer.AddString("method"sv, "export"sv);
+ Writer.BeginObject("params"sv);
+ {
+ if (m_MaxBlockSize != 0)
+ {
+ Writer.AddInteger("maxblocksize"sv, m_MaxBlockSize);
+ }
+ if (m_MaxChunkEmbedSize != 0)
+ {
+ Writer.AddInteger("maxchunkembedsize"sv, m_MaxChunkEmbedSize);
+ }
+ if (m_Force)
+ {
+ Writer.AddBool("force"sv, true);
+ }
+ if (!m_FileDirectoryPath.empty())
+ {
+ Writer.BeginObject("file"sv);
+ {
+ Writer.AddString("file"sv, m_FileDirectoryPath);
+ Writer.AddString("name"sv, m_FileName);
+ if (m_DisableBlocks)
+ {
+ Writer.AddBool("disableblocks"sv, true);
+ }
+ if (m_FileForceEnableTempBlocks)
+ {
+ Writer.AddBool("enabletempblocks"sv, true);
+ }
+ }
+ Writer.EndObject(); // "file"
+ TargetDescription = fmt::format("[file] '{}/{}'", m_FileDirectoryPath, m_FileName);
+ }
+ if (!m_CloudUrl.empty())
+ {
+ Writer.BeginObject("cloud"sv);
+ {
+ Writer.AddString("url"sv, m_CloudUrl);
+ Writer.AddString("namespace"sv, m_CloudNamespace);
+ Writer.AddString("bucket"sv, m_CloudBucket);
+ Writer.AddString("key"sv, m_CloudKey);
+ if (!m_CloudOpenIdProvider.empty())
+ {
+ Writer.AddString("openid-provider"sv, m_CloudOpenIdProvider);
+ }
+ if (!m_CloudAccessToken.empty())
+ {
+ Writer.AddString("access-token"sv, m_CloudAccessToken);
+ }
+ if (!m_CloudAccessTokenEnv.empty())
+ {
+ Writer.AddString("access-token-env"sv, m_CloudAccessTokenEnv);
+ }
+ if (m_DisableBlocks)
+ {
+ Writer.AddBool("disableblocks"sv, true);
+ }
+ if (m_CloudDisableTempBlocks)
+ {
+ Writer.AddBool("disabletempblocks"sv, true);
+ }
+ }
+ Writer.EndObject(); // "cloud"
+ TargetDescription = fmt::format("[cloud] '{}/{}/{}/{}'", m_CloudUrl, m_CloudNamespace, m_CloudBucket, m_CloudKey);
+ }
+ if (!m_ZenUrl.empty())
+ {
+ Writer.BeginObject("zen"sv);
+ {
+ Writer.AddString("url"sv, m_ZenUrl);
+ Writer.AddString("project"sv, m_ZenProjectName);
+ Writer.AddString("oplog"sv, m_ZenOplogName);
+ }
+ Writer.EndObject(); // "zen"
+
+ TargetDescription = fmt::format("[zen] '{}/{}/{}'", m_ZenUrl, m_ZenProjectName, m_ZenOplogName);
+ }
+ }
+ Writer.EndObject(); // "params"
+
+ zen::BinaryWriter MemOut;
+ Writer.Save(MemOut);
+ Session.SetBody(cpr::Body{(const char*)MemOut.GetData(), MemOut.GetSize()});
+
+ ZEN_CONSOLE("Saving oplog '{}/{}' from '{}' to {}", m_ProjectName, m_OplogName, m_HostName, TargetDescription);
+ cpr::Response Response = Session.Post();
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+}
+
+////////////////////////////
+
+ImportOplogCommand::ImportOplogCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "project", "Project name", cxxopts::value(m_ProjectName), "<projectid>");
+ m_Options.add_option("", "o", "oplog", "Oplog name", cxxopts::value(m_OplogName), "<oplogid>");
+ m_Options.add_option("", "", "maxblocksize", "Max size for bundled attachments", cxxopts::value(m_MaxBlockSize), "<blocksize>");
+ m_Options.add_option("",
+ "",
+ "maxchunkembedsize",
+ "Max size for attachment to be bundled",
+ cxxopts::value(m_MaxChunkEmbedSize),
+ "<chunksize>");
+ m_Options.add_option("", "f", "force", "Force import of all attachments", cxxopts::value(m_Force), "<force>");
+
+ m_Options.add_option("", "", "cloud", "Cloud Storage URL", cxxopts::value(m_CloudUrl), "<url>");
+ m_Options.add_option("cloud", "", "namespace", "Cloud Storage namespace", cxxopts::value(m_CloudNamespace), "<namespace>");
+ m_Options.add_option("cloud", "", "bucket", "Cloud Storage bucket", cxxopts::value(m_CloudBucket), "<bucket>");
+ m_Options.add_option("cloud", "", "key", "Cloud Storage key", cxxopts::value(m_CloudKey), "<key>");
+ m_Options
+ .add_option("cloud", "", "openid-provider", "Cloud Storage openid provider", cxxopts::value(m_CloudOpenIdProvider), "<provider>");
+ m_Options.add_option("cloud", "", "access-token", "Cloud Storage access token", cxxopts::value(m_CloudAccessToken), "<accesstoken>");
+ m_Options.add_option("cloud",
+ "",
+ "access-token-env",
+ "Name of environment variable that holds the cloud Storage access token",
+ cxxopts::value(m_CloudAccessTokenEnv)->default_value(DefaultCloudAccessTokenEnvVariableName),
+ "<envvariable>");
+
+ m_Options.add_option("", "", "zen", "Zen service upload address", cxxopts::value(m_ZenUrl), "<url>");
+ m_Options.add_option("zen", "", "source-project", "Zen source project name", cxxopts::value(m_ZenProjectName), "<sourceprojectid>");
+ m_Options.add_option("zen", "", "source-oplog", "Zen source oplog name", cxxopts::value(m_ZenOplogName), "<sourceoplogid>");
+ m_Options.add_option("zen", "", "clean", "Delete existing target Zen oplog", cxxopts::value(m_ZenClean), "<clean>");
+
+ m_Options.add_option("", "", "file", "Local folder path", cxxopts::value(m_FileDirectoryPath), "<path>");
+ m_Options.add_option("file", "", "name", "Local file name", cxxopts::value(m_FileName), "<filename>");
+
+ m_Options.parse_positional({"project", "oplog"});
+}
+
+ImportOplogCommand::~ImportOplogCommand() = default;
+
+int
+ImportOplogCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ using namespace std::literals;
+
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ if (m_ProjectName.empty())
+ {
+ ZEN_ERROR("Project name must be given");
+ return 1;
+ }
+
+ if (m_OplogName.empty())
+ {
+ ZEN_ERROR("Oplog name must be given");
+ return 1;
+ }
+
+ size_t TargetCount = 0;
+ TargetCount += m_CloudUrl.empty() ? 0 : 1;
+ TargetCount += m_ZenUrl.empty() ? 0 : 1;
+ TargetCount += m_FileDirectoryPath.empty() ? 0 : 1;
+ if (TargetCount != 1)
+ {
+ ZEN_ERROR("Provide one source only");
+ ZEN_CONSOLE("{}", m_Options.help({""}).c_str());
+ return 1;
+ }
+
+ cpr::Session Session;
+
+ if (!m_CloudUrl.empty())
+ {
+ if (m_CloudNamespace.empty() || m_CloudBucket.empty())
+ {
+ ZEN_ERROR("Options for cloud source are missing");
+ ZEN_CONSOLE("{}", m_Options.help({"cloud"}).c_str());
+ return 1;
+ }
+ if (m_CloudKey.empty())
+ {
+ std::string KeyString = fmt::format("{}/{}/{}/{}", m_ProjectName, m_OplogName, m_CloudNamespace, m_CloudBucket);
+ zen::IoHash Key = zen::IoHash::HashBuffer(KeyString.data(), KeyString.size());
+ m_CloudKey = Key.ToHexString();
+ ZEN_WARN("Using auto generated cloud key '{}'", m_CloudKey);
+ }
+ }
+
+ if (!m_ZenUrl.empty())
+ {
+ if (m_ZenProjectName.empty())
+ {
+ m_ZenProjectName = m_ProjectName;
+ ZEN_WARN("Using default zen target project id '{}'", m_ZenProjectName);
+ }
+ if (m_ZenOplogName.empty())
+ {
+ m_ZenOplogName = m_OplogName;
+ ZEN_WARN("Using default zen target oplog id '{}'", m_ZenOplogName);
+ }
+ }
+
+ if (!m_FileDirectoryPath.empty())
+ {
+ if (m_FileName.empty())
+ {
+ m_FileName = m_OplogName;
+ ZEN_WARN("Using auto generated file name '{}'", m_FileName);
+ }
+ }
+
+ const std::string TargetUrlBase = fmt::format("{}/prj", m_HostName);
+ Session.SetUrl({fmt::format("{}/{}/oplog/{}", TargetUrlBase, m_ProjectName, m_OplogName)});
+ cpr::Response Response = Session.Get();
+ if (Response.status_code == static_cast<long>(zen::HttpResponseCode::NotFound))
+ {
+ ZEN_WARN("Automatically creating oplog '{}/{}'", m_ProjectName, m_OplogName)
+ Response = Session.Post();
+ if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ }
+ else if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ else if (m_ZenClean)
+ {
+ ZEN_WARN("Cleaning oplog '{}/{}'", m_ProjectName, m_OplogName)
+ Response = Session.Delete();
+ if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ Response = Session.Post();
+ if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+ }
+
+ std::string SourceDescription;
+ Session.SetUrl(fmt::format("{}/{}/oplog/{}/rpc", TargetUrlBase, m_ProjectName, m_OplogName));
+ Session.SetHeader({{"Content-Type", std::string(zen::MapContentTypeToString(zen::HttpContentType::kCbObject))}});
+
+ zen::CbObjectWriter Writer;
+ Writer.AddString("method"sv, "import"sv);
+ Writer.BeginObject("params"sv);
+ {
+ if (m_Force)
+ {
+ Writer.AddBool("force"sv, true);
+ }
+ if (!m_FileDirectoryPath.empty())
+ {
+ Writer.BeginObject("file"sv);
+ {
+ Writer.AddString("file"sv, m_FileDirectoryPath);
+ Writer.AddString("name"sv, m_FileName);
+ }
+ Writer.EndObject(); // "file"
+ SourceDescription = fmt::format("[file] '{}/{}'", m_FileDirectoryPath, m_FileName);
+ }
+ if (!m_CloudUrl.empty())
+ {
+ Writer.BeginObject("cloud"sv);
+ {
+ Writer.AddString("url"sv, m_CloudUrl);
+ Writer.AddString("namespace"sv, m_CloudNamespace);
+ Writer.AddString("bucket"sv, m_CloudBucket);
+ Writer.AddString("key"sv, m_CloudKey);
+ if (!m_CloudOpenIdProvider.empty())
+ {
+ Writer.AddString("openid-provider"sv, m_CloudOpenIdProvider);
+ }
+ if (!m_CloudAccessToken.empty())
+ {
+ Writer.AddString("access-token"sv, m_CloudAccessToken);
+ }
+ if (!m_CloudAccessTokenEnv.empty())
+ {
+ Writer.AddString("access-token-env"sv, m_CloudAccessTokenEnv);
+ }
+ }
+ Writer.EndObject(); // "cloud"
+ SourceDescription = fmt::format("[cloud] '{}/{}/{}/{}'", m_CloudUrl, m_CloudNamespace, m_CloudBucket, m_CloudKey);
+ }
+ if (!m_ZenUrl.empty())
+ {
+ Writer.BeginObject("zen"sv);
+ {
+ Writer.AddString("url"sv, m_ZenUrl);
+ Writer.AddString("project"sv, m_ZenProjectName);
+ Writer.AddString("oplog"sv, m_ZenOplogName);
+ }
+ Writer.EndObject(); // "zen"
+ SourceDescription = fmt::format("[zen] '{}'", m_ZenUrl);
+ }
+ }
+ Writer.EndObject(); // "params"
+
+ zen::BinaryWriter MemOut;
+ Writer.Save(MemOut);
+ Session.SetBody(cpr::Body{(const char*)MemOut.GetData(), MemOut.GetSize()});
+
+ ZEN_CONSOLE("Loading oplog '{}/{}' from '{}' to {}", m_ProjectName, m_OplogName, SourceDescription, m_HostName);
+ Response = Session.Post();
+
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+}
+
+ProjectStatsCommand::ProjectStatsCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+}
+
+ProjectStatsCommand::~ProjectStatsCommand() = default;
+
+int
+ProjectStatsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetUrl({fmt::format("{}/stats/prj", m_HostName)});
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+
+ cpr::Response Result = Session.Get();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("{}", Result.text);
+
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Info failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+ProjectDetailsCommand::ProjectDetailsCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "c", "csv", "Output in CSV format (default is JSon)", cxxopts::value(m_CSV), "<csv>");
+ m_Options.add_option("", "d", "details", "Detailed info on opslog", cxxopts::value(m_Details), "<details>");
+ m_Options.add_option("", "o", "opdetails", "Details info on oplog body", cxxopts::value(m_OpDetails), "<opdetails>");
+ m_Options.add_option("", "p", "project", "Project name to get info from", cxxopts::value(m_ProjectName), "<projectid>");
+ m_Options.add_option("", "l", "oplog", "Oplog name to get info from", cxxopts::value(m_OplogName), "<oplogid>");
+ m_Options.add_option("", "i", "opid", "Oid of a specific op info for", cxxopts::value(m_OpId), "<opid>");
+ m_Options.add_option("",
+ "a",
+ "attachmentdetails",
+ "Get detailed information about attachments",
+ cxxopts::value(m_AttachmentDetails),
+ "<attachmentdetails>");
+}
+
+ProjectDetailsCommand::~ProjectDetailsCommand() = default;
+
+int
+ProjectDetailsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ cpr::Parameters Parameters;
+ if (m_OpDetails)
+ {
+ Parameters.Add({"opdetails", "true"});
+ }
+ if (m_Details)
+ {
+ Parameters.Add({"details", "true"});
+ }
+ if (m_AttachmentDetails)
+ {
+ Parameters.Add({"attachmentdetails", "true"});
+ }
+ if (m_CSV)
+ {
+ Parameters.Add({"csv", "true"});
+ }
+ else
+ {
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+ }
+
+ if (!m_OpId.empty())
+ {
+ if (m_ProjectName.empty() || m_OplogName.empty())
+ {
+ ZEN_ERROR("Provide project and oplog name");
+ ZEN_CONSOLE("{}", m_Options.help({""}).c_str());
+ return 1;
+ }
+ Session.SetUrl({fmt::format("{}/prj/details$/{}/{}/{}", m_HostName, m_ProjectName, m_OplogName, m_OpId)});
+ }
+ else if (!m_OplogName.empty())
+ {
+ if (m_ProjectName.empty())
+ {
+ ZEN_ERROR("Provide project name");
+ ZEN_CONSOLE("{}", m_Options.help({""}).c_str());
+ return 1;
+ }
+ Session.SetUrl({fmt::format("{}/prj/details$/{}/{}", m_HostName, m_ProjectName, m_OplogName)});
+ }
+ else if (!m_ProjectName.empty())
+ {
+ Session.SetUrl({fmt::format("{}/prj/details$/{}", m_HostName, m_ProjectName)});
+ }
+ else
+ {
+ Session.SetUrl({fmt::format("{}/prj/details$", m_HostName)});
+ }
+ Session.SetParameters(Parameters);
+
+ cpr::Response Result = Session.Get();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("{}", Result.text);
+
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("Info failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("Info failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
diff --git a/src/zen/cmds/projectstore.h b/src/zen/cmds/projectstore.h
new file mode 100644
index 000000000..10927a546
--- /dev/null
+++ b/src/zen/cmds/projectstore.h
@@ -0,0 +1,180 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+class DropProjectCommand : public ZenCmdBase
+{
+public:
+ DropProjectCommand();
+ ~DropProjectCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"project-drop", "Drop project or project oplog"};
+ std::string m_HostName;
+ std::string m_ProjectName;
+ std::string m_OplogName;
+};
+
+class ProjectInfoCommand : public ZenCmdBase
+{
+public:
+ ProjectInfoCommand();
+ ~ProjectInfoCommand();
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"project-info", "Info on project or project oplog"};
+ std::string m_HostName;
+ std::string m_ProjectName;
+ std::string m_OplogName;
+};
+
+class CreateProjectCommand : public ZenCmdBase
+{
+public:
+ CreateProjectCommand();
+ ~CreateProjectCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"project-create", "Create project"};
+ std::string m_HostName;
+ std::string m_ProjectId;
+ std::string m_RootDir;
+ std::string m_EngineRootDir;
+ std::string m_ProjectRootDir;
+ std::string m_ProjectFile;
+};
+
+class CreateOplogCommand : public ZenCmdBase
+{
+public:
+ CreateOplogCommand();
+ ~CreateOplogCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"oplog-create", "Create oplog"};
+ std::string m_HostName;
+ std::string m_ProjectId;
+ std::string m_OplogId;
+ std::string m_GcPath;
+};
+
+class ExportOplogCommand : public ZenCmdBase
+{
+public:
+ ExportOplogCommand();
+ ~ExportOplogCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"oplog-export",
+ "Export project store oplog to cloud (--cloud), file system (--file) or other Zen instance (--zen)"};
+ std::string m_HostName;
+ std::string m_ProjectName;
+ std::string m_OplogName;
+ uint64_t m_MaxBlockSize = 0;
+ uint64_t m_MaxChunkEmbedSize = 0;
+ bool m_Force = false;
+ bool m_DisableBlocks = false;
+
+ std::string m_CloudUrl;
+ std::string m_CloudNamespace;
+ std::string m_CloudBucket;
+ std::string m_CloudKey;
+ std::string m_CloudOpenIdProvider;
+ std::string m_CloudAccessToken;
+ std::string m_CloudAccessTokenEnv;
+ bool m_CloudDisableTempBlocks = false;
+
+ std::string m_ZenUrl;
+ std::string m_ZenProjectName;
+ std::string m_ZenOplogName;
+ bool m_ZenClean;
+
+ std::string m_FileDirectoryPath;
+ std::string m_FileName;
+ bool m_FileForceEnableTempBlocks = false;
+};
+
+class ImportOplogCommand : public ZenCmdBase
+{
+public:
+ ImportOplogCommand();
+ ~ImportOplogCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"oplog-import",
+ "Import project store oplog from cloud (--cloud), file system (--file) or other Zen instance (--zen)"};
+ std::string m_HostName;
+ std::string m_ProjectName;
+ std::string m_OplogName;
+ size_t m_MaxBlockSize = 0;
+ size_t m_MaxChunkEmbedSize = 0;
+ bool m_Force = false;
+
+ std::string m_CloudUrl;
+ std::string m_CloudNamespace;
+ std::string m_CloudBucket;
+ std::string m_CloudKey;
+ std::string m_CloudOpenIdProvider;
+ std::string m_CloudAccessToken;
+ std::string m_CloudAccessTokenEnv;
+
+ std::string m_ZenUrl;
+ std::string m_ZenProjectName;
+ std::string m_ZenOplogName;
+ bool m_ZenClean;
+
+ std::string m_FileDirectoryPath;
+ std::string m_FileName;
+};
+
+class ProjectStatsCommand : public ZenCmdBase
+{
+public:
+ ProjectStatsCommand();
+ ~ProjectStatsCommand();
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"project-stats", "Stats info on project store"};
+ std::string m_HostName;
+};
+
+class ProjectDetailsCommand : public ZenCmdBase
+{
+public:
+ ProjectDetailsCommand();
+ ~ProjectDetailsCommand();
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"project-details", "Detail info on project store"};
+ std::string m_HostName;
+ bool m_Details;
+ bool m_OpDetails;
+ bool m_AttachmentDetails;
+ bool m_CSV;
+ std::string m_ProjectName;
+ std::string m_OplogName;
+ std::string m_OpId;
+};
diff --git a/src/zen/cmds/rpcreplay.cpp b/src/zen/cmds/rpcreplay.cpp
new file mode 100644
index 000000000..9bc4b2c7b
--- /dev/null
+++ b/src/zen/cmds/rpcreplay.cpp
@@ -0,0 +1,417 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "rpcreplay.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/workthreadpool.h>
+#include <zenhttp/httpcommon.h>
+#include <zenhttp/httpshared.h>
+#include <zenutil/cache/rpcrecording.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+#include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <memory>
+
+namespace zen {
+
+using namespace std::literals;
+
+RpcStartRecordingCommand::RpcStartRecordingCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>");
+
+ m_Options.parse_positional("path");
+}
+
+RpcStartRecordingCommand::~RpcStartRecordingCommand() = default;
+
+int
+RpcStartRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ if (m_RecordingPath.empty())
+ {
+ throw cxxopts::OptionParseException("Rpc start recording command requires a path");
+ }
+
+ cpr::Session Session;
+ Session.SetUrl(fmt::format("{}/z$/exec$/start-recording"sv, m_HostName));
+ Session.SetParameters({{"path", m_RecordingPath}});
+ cpr::Response Response = Session.Post();
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+}
+
+////////////////////////////////////////////////////
+
+RpcStopRecordingCommand::RpcStopRecordingCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+}
+
+RpcStopRecordingCommand::~RpcStopRecordingCommand() = default;
+
+int
+RpcStopRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetUrl(fmt::format("{}/z$/exec$/stop-recording"sv, m_HostName));
+ cpr::Response Response = Session.Post();
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+}
+
+////////////////////////////////////////////////////
+
+RpcReplayCommand::RpcReplayCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>");
+ m_Options.add_option("",
+ "w",
+ "numthreads",
+ "Number of worker threads per process",
+ cxxopts::value(m_ThreadCount)->default_value(fmt::format("{}", std::thread::hardware_concurrency())),
+ "<count>");
+ m_Options.add_option("", "", "onhost", "Replay on host, bypassing http/network layer", cxxopts::value(m_OnHost), "<onhost>");
+ m_Options.add_option("",
+ "",
+ "showmethodstats",
+ "Show statistics of which RPC methods are used",
+ cxxopts::value(m_ShowMethodStats),
+ "<showmethodstats>");
+ m_Options.add_option("",
+ "",
+ "offset",
+ "Offset into request recording to start replay",
+ cxxopts::value(m_Offset)->default_value("0"),
+ "<offset>");
+ m_Options.add_option("",
+ "",
+ "stride",
+ "Stride for request recording when replaying requests",
+ cxxopts::value(m_Stride)->default_value("1"),
+ "<stride>");
+ m_Options.add_option("", "", "numproc", "Number of worker processes", cxxopts::value(m_ProcessCount)->default_value("1"), "<count>");
+ m_Options.add_option("",
+ "",
+ "forceallowlocalrefs",
+ "Force enable local refs in requests",
+ cxxopts::value(m_ForceAllowLocalRefs),
+ "<enable>");
+ m_Options
+ .add_option("", "", "disablelocalrefs", "Force disable local refs in requests", cxxopts::value(m_DisableLocalRefs), "<enable>");
+ m_Options.add_option("",
+ "",
+ "forceallowlocalhandlerefs",
+ "Force enable local refs as handles in requests",
+ cxxopts::value(m_ForceAllowLocalHandleRef),
+ "<enable>");
+ m_Options.add_option("",
+ "",
+ "disablelocalhandlerefs",
+ "Force disable local refs as handles in requests",
+ cxxopts::value(m_DisableLocalHandleRefs),
+ "<enable>");
+ m_Options.add_option("",
+ "",
+ "forceallowpartiallocalrefs",
+ "Force enable local refs for all sizes",
+ cxxopts::value(m_ForceAllowPartialLocalRefs),
+ "<enable>");
+ m_Options.add_option("",
+ "",
+ "disablepartiallocalrefs",
+ "Force disable local refs for all sizes",
+ cxxopts::value(m_DisablePartialLocalRefs),
+ "<enable>");
+
+ m_Options.parse_positional("path");
+}
+
+RpcReplayCommand::~RpcReplayCommand() = default;
+
+int
+RpcReplayCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ if (m_RecordingPath.empty())
+ {
+ throw cxxopts::OptionParseException("Rpc replay command requires a path");
+ }
+
+ if (m_OnHost)
+ {
+ cpr::Session Session;
+ Session.SetUrl(fmt::format("{}/z$/exec$/replay-recording"sv, m_HostName));
+ Session.SetParameters({{"path", m_RecordingPath}, {"thread-count", fmt::format("{}", m_ThreadCount)}});
+ cpr::Response Response = Session.Post();
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ return MapHttpToCommandReturnCode(Response);
+ }
+
+ std::unique_ptr<cache::IRpcRequestReplayer> Replayer = cache::MakeDiskRequestReplayer(m_RecordingPath, true);
+ uint64_t EntryCount = Replayer->GetRequestCount();
+
+ std::atomic_uint64_t EntryOffset = m_Offset;
+ std::atomic_uint64_t BytesSent = 0;
+ std::atomic_uint64_t BytesReceived = 0;
+
+ Stopwatch Timer;
+
+ if (m_ProcessCount > 1)
+ {
+ std::vector<std::unique_ptr<ProcessHandle>> WorkerProcesses;
+ WorkerProcesses.resize(m_ProcessCount);
+
+ ProcessMonitor Monitor;
+ for (int ProcessIndex = 0; ProcessIndex < m_ProcessCount; ++ProcessIndex)
+ {
+ std::string CommandLine =
+ fmt::format("{} rpc-record-replay --hosturl {} --path \"{}\" --offset {} --stride {} --numthreads {} --numproc {}"sv,
+ argv[0],
+ m_HostName,
+ m_RecordingPath,
+ m_Stride == 1 ? 0 : m_Offset + ProcessIndex,
+ m_Stride,
+ m_ThreadCount,
+ 1);
+ CreateProcResult Result(CreateProc(std::filesystem::path(std::string(argv[0])), CommandLine));
+ WorkerProcesses[ProcessIndex] = std::make_unique<ProcessHandle>();
+ WorkerProcesses[ProcessIndex]->Initialize(Result);
+ Monitor.AddPid(WorkerProcesses[ProcessIndex]->Pid());
+ }
+ while (Monitor.IsRunning())
+ {
+ ZEN_CONSOLE("Waiting for worker processes...");
+ Sleep(1000);
+ }
+ return 0;
+ }
+ else
+ {
+ std::map<std::string, size_t> MethodTypes;
+ RwLock MethodTypesLock;
+
+ WorkerThreadPool WorkerPool(m_ThreadCount);
+
+ Latch WorkLatch(m_ThreadCount);
+ for (int WorkerIndex = 0; WorkerIndex < m_ThreadCount; ++WorkerIndex)
+ {
+ WorkerPool.ScheduleWork(
+ [this, &WorkLatch, EntryCount, &EntryOffset, &Replayer, &BytesSent, &BytesReceived, &MethodTypes, &MethodTypesLock]() {
+ auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); });
+
+ cpr::Session Session;
+ Session.SetUrl(fmt::format("{}/z$/$rpc"sv, m_HostName));
+
+ uint64_t EntryIndex = EntryOffset.fetch_add(m_Stride);
+ while (EntryIndex < EntryCount)
+ {
+ IoBuffer Payload;
+ std::pair<ZenContentType, ZenContentType> Types = Replayer->GetRequest(EntryIndex, Payload);
+ ZenContentType RequestContentType = Types.first;
+ ZenContentType AcceptContentType = Types.second;
+
+ CbPackage RequestPackage;
+ CbObject Request;
+ switch (RequestContentType)
+ {
+ case ZenContentType::kCbPackage:
+ {
+ if (ParsePackageMessageWithLegacyFallback(Payload, RequestPackage))
+ {
+ Request = RequestPackage.GetObject();
+ }
+ }
+ break;
+ case ZenContentType::kCbObject:
+ {
+ Request = LoadCompactBinaryObject(Payload);
+ }
+ break;
+ }
+
+ RpcAcceptOptions OriginalAcceptOptions = static_cast<RpcAcceptOptions>(Request["AcceptFlags"sv].AsUInt16(0u));
+ int OriginalProcessPid = Request["Pid"sv].AsInt32(0);
+
+ int AdjustedPid = 0;
+ RpcAcceptOptions AdjustedAcceptOptions = RpcAcceptOptions::kNone;
+ if (!m_DisableLocalRefs)
+ {
+ if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowLocalReferences) || m_ForceAllowLocalRefs)
+ {
+ AdjustedAcceptOptions |= RpcAcceptOptions::kAllowLocalReferences;
+ if (!m_DisablePartialLocalRefs)
+ {
+ if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowPartialLocalReferences) ||
+ m_ForceAllowPartialLocalRefs)
+ {
+ AdjustedAcceptOptions |= RpcAcceptOptions::kAllowPartialLocalReferences;
+ }
+ }
+ if (!m_DisableLocalHandleRefs)
+ {
+ if (OriginalProcessPid != 0 || m_ForceAllowLocalHandleRef)
+ {
+ AdjustedPid = GetCurrentProcessId();
+ }
+ }
+ }
+ }
+
+ if (m_ShowMethodStats)
+ {
+ std::string MethodName = std::string(Request["Method"sv].AsString());
+ RwLock::ExclusiveLockScope __(MethodTypesLock);
+ if (auto It = MethodTypes.find(MethodName); It != MethodTypes.end())
+ {
+ It->second++;
+ }
+ else
+ {
+ MethodTypes[MethodName] = 1;
+ }
+ }
+
+ if (OriginalAcceptOptions != AdjustedAcceptOptions || OriginalProcessPid != AdjustedPid)
+ {
+ CbObjectWriter RequestCopyWriter;
+ for (const CbFieldView& Field : Request)
+ {
+ if (!Field.HasName())
+ {
+ RequestCopyWriter.AddField(Field);
+ continue;
+ }
+ std::string_view FieldName = Field.GetName();
+ if (FieldName == "Pid"sv)
+ {
+ continue;
+ }
+ if (FieldName == "AcceptFlags"sv)
+ {
+ continue;
+ }
+ RequestCopyWriter.AddField(FieldName, Field);
+ }
+ if (AdjustedPid != 0)
+ {
+ RequestCopyWriter.AddInteger("Pid"sv, AdjustedPid);
+ }
+ if (AdjustedAcceptOptions != RpcAcceptOptions::kNone)
+ {
+ RequestCopyWriter.AddInteger("AcceptFlags"sv, static_cast<uint16_t>(AdjustedAcceptOptions));
+ }
+
+ if (RequestContentType == ZenContentType::kCbPackage)
+ {
+ RequestPackage.SetObject(RequestCopyWriter.Save());
+ std::vector<IoBuffer> Buffers = FormatPackageMessage(RequestPackage);
+ std::vector<SharedBuffer> SharedBuffers(Buffers.begin(), Buffers.end());
+ Payload = CompositeBuffer(std::move(SharedBuffers)).Flatten().AsIoBuffer();
+ }
+ else
+ {
+ RequestCopyWriter.Finalize();
+ Payload = IoBuffer(RequestCopyWriter.GetSaveSize());
+ RequestCopyWriter.Save(Payload.GetMutableView());
+ }
+ }
+
+ Session.SetHeader({{"Content-Type", std::string(MapContentTypeToString(RequestContentType))},
+ {"Accept", std::string(MapContentTypeToString(AcceptContentType))}});
+ uint64_t Offset = 0;
+ auto ReadCallback = [&Payload, &Offset](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, Payload.GetSize() - Offset);
+ IoBuffer PayloadRange = IoBuffer(Payload, Offset, size);
+ MutableMemoryView Data(buffer, size);
+ Data.CopyFrom(PayloadRange.GetView());
+ Offset += size;
+ return true;
+ };
+ Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
+ cpr::Response Response = Session.Post();
+ BytesSent.fetch_add(Payload.GetSize());
+ if (Response.error || !(IsHttpSuccessCode(Response.status_code) ||
+ Response.status_code == gsl::narrow<long>(HttpResponseCode::NotFound)))
+ {
+ ZEN_CONSOLE("{}", FormatHttpResponse(Response));
+ break;
+ }
+ BytesReceived.fetch_add(Response.downloaded_bytes);
+ EntryIndex = EntryOffset.fetch_add(m_Stride);
+ }
+ });
+ }
+
+ while (!WorkLatch.Wait(1000))
+ {
+ ZEN_CONSOLE("Processing {} requests, {} remaining (sent {}, recevied {})...",
+ (EntryCount - m_Offset) / m_Stride,
+ (EntryCount - EntryOffset.load()) / m_Stride,
+ NiceBytes(BytesSent.load()),
+ NiceBytes(BytesReceived.load()));
+ }
+ if (m_ShowMethodStats)
+ {
+ for (const auto& It : MethodTypes)
+ {
+ ZEN_CONSOLE("{}: {}", It.first, It.second);
+ }
+ }
+ }
+
+ const uint64_t RequestsSent = (EntryOffset.load() - m_Offset) / m_Stride;
+ const uint64_t ElapsedMS = Timer.GetElapsedTimeMs();
+ const double ElapsedS = ElapsedMS / 1000.500;
+ const uint64_t Sent = BytesSent.load();
+ const uint64_t Received = BytesReceived.load();
+ const uint64_t RequestsPerS = static_cast<uint64_t>(RequestsSent / ElapsedS);
+ const uint64_t SentPerS = static_cast<uint64_t>(Sent / ElapsedS);
+ const uint64_t ReceivedPerS = static_cast<uint64_t>(Received / ElapsedS);
+
+ ZEN_CONSOLE("Requests sent {} ({}/s), payloads sent {}B ({}B/s), payloads received {}B ({}B/s) in {}",
+ RequestsSent,
+ RequestsPerS,
+ NiceBytes(Sent),
+ NiceBytes(SentPerS),
+ NiceBytes(Received),
+ NiceBytes(ReceivedPerS),
+ NiceTimeSpanMs(ElapsedMS));
+
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/rpcreplay.h b/src/zen/cmds/rpcreplay.h
new file mode 100644
index 000000000..742e5ec5b
--- /dev/null
+++ b/src/zen/cmds/rpcreplay.h
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+class RpcStartRecordingCommand : public ZenCmdBase
+{
+public:
+ RpcStartRecordingCommand();
+ ~RpcStartRecordingCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"rpc-record-start", "Starts recording of cache rpc requests on a host"};
+ std::string m_HostName;
+ std::string m_RecordingPath;
+};
+
+class RpcStopRecordingCommand : public ZenCmdBase
+{
+public:
+ RpcStopRecordingCommand();
+ ~RpcStopRecordingCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"rpc-record-stop", "Stops recording of cache rpc requests on a host"};
+ std::string m_HostName;
+};
+
+class RpcReplayCommand : public ZenCmdBase
+{
+public:
+ RpcReplayCommand();
+ ~RpcReplayCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"rpc-record-replay", "Replays a previously recorded session of cache rpc requests to a target host"};
+ std::string m_HostName;
+ std::string m_RecordingPath;
+ bool m_OnHost = false;
+ bool m_ShowMethodStats = false;
+ int m_ProcessCount;
+ int m_ThreadCount;
+ uint64_t m_Offset;
+ uint64_t m_Stride;
+ bool m_ForceAllowLocalRefs;
+ bool m_DisableLocalRefs;
+ bool m_ForceAllowLocalHandleRef;
+ bool m_DisableLocalHandleRefs;
+ bool m_ForceAllowPartialLocalRefs;
+ bool m_DisablePartialLocalRefs;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/scrub.cpp b/src/zen/cmds/scrub.cpp
new file mode 100644
index 000000000..27ff5e0ac
--- /dev/null
+++ b/src/zen/cmds/scrub.cpp
@@ -0,0 +1,154 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "scrub.h"
+#include <zencore/logging.h>
+#include <zenhttp/httpcommon.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+using namespace std::literals;
+
+namespace zen {
+
+ScrubCommand::ScrubCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+}
+
+ScrubCommand::~ScrubCommand() = default;
+
+int
+ScrubCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+GcCommand::GcCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+ m_Options.add_option("",
+ "s",
+ "smallobjects",
+ "Collect small objects",
+ cxxopts::value(m_SmallObjects)->default_value("false"),
+ "<smallobjects>");
+ m_Options.add_option("",
+ "m",
+ "maxcacheduration",
+ "Max cache lifetime (in seconds)",
+ cxxopts::value(m_MaxCacheDuration)->default_value("0"),
+ "<maxcacheduration>");
+ m_Options.add_option("",
+ "d",
+ "disksizesoftlimit",
+ "Max disk usage size (in bytes)",
+ cxxopts::value(m_DiskSizeSoftLimit)->default_value("0"),
+ "<disksizesoftlimit>");
+}
+
+GcCommand::~GcCommand()
+{
+}
+
+int
+GcCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Parameters Params;
+ if (m_SmallObjects)
+ {
+ Params.Add({"smallobjects", "true"});
+ }
+ if (m_MaxCacheDuration != 0)
+ {
+ Params.Add({"maxcacheduration", fmt::format("{}", m_MaxCacheDuration)});
+ }
+ if (m_DiskSizeSoftLimit != 0)
+ {
+ Params.Add({"disksizesoftlimit", fmt::format("{}", m_DiskSizeSoftLimit)});
+ }
+
+ cpr::Session Session;
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+ Session.SetUrl({fmt::format("{}/admin/gc", m_HostName)});
+ Session.SetParameters(Params);
+
+ cpr::Response Result = Session.Post();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("OK: {}", Result.text);
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("GC start failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("GC start failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+GcStatusCommand::GcStatusCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value("http://localhost:1337"), "<hosturl>");
+}
+
+GcStatusCommand::~GcStatusCommand()
+{
+}
+
+int
+GcStatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ cpr::Session Session;
+ Session.SetHeader(cpr::Header{{"Accept", "application/json"}});
+ Session.SetUrl({fmt::format("{}/admin/gc", m_HostName)});
+
+ cpr::Response Result = Session.Get();
+
+ if (zen::IsHttpSuccessCode(Result.status_code))
+ {
+ ZEN_CONSOLE("OK: {}", Result.text);
+ return 0;
+ }
+
+ if (Result.status_code)
+ {
+ ZEN_ERROR("GC status failed: {}: {} ({})", Result.status_code, Result.reason, Result.text);
+ }
+ else
+ {
+ ZEN_ERROR("GC status failed: {}", Result.error.message);
+ }
+
+ return 1;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/scrub.h b/src/zen/cmds/scrub.h
new file mode 100644
index 000000000..ee8b4fdbb
--- /dev/null
+++ b/src/zen/cmds/scrub.h
@@ -0,0 +1,58 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+/** Scrub storage
+ */
+class ScrubCommand : public ZenCmdBase
+{
+public:
+ ScrubCommand();
+ ~ScrubCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"scrub", "Scrub zen storage"};
+ std::string m_HostName;
+};
+
+/** Garbage collect storage
+ */
+class GcCommand : public ZenCmdBase
+{
+public:
+ GcCommand();
+ ~GcCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"gc", "Garbage collect zen storage"};
+ std::string m_HostName;
+ bool m_SmallObjects{false};
+ uint64_t m_MaxCacheDuration{0};
+ uint64_t m_DiskSizeSoftLimit{0};
+};
+
+class GcStatusCommand : public ZenCmdBase
+{
+public:
+ GcStatusCommand();
+ ~GcStatusCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"gc-status", "Garbage collect zen storage status check"};
+ std::string m_HostName;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/status.cpp b/src/zen/cmds/status.cpp
new file mode 100644
index 000000000..23c27f9f9
--- /dev/null
+++ b/src/zen/cmds/status.cpp
@@ -0,0 +1,41 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "status.h"
+
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+#include <zenutil/zenserverprocess.h>
+
+namespace zen {
+
+StatusCommand::StatusCommand()
+{
+}
+
+StatusCommand::~StatusCommand() = default;
+
+int
+StatusCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ ZenServerState State;
+ if (!State.InitializeReadOnly())
+ {
+ ZEN_CONSOLE("no Zen state found");
+
+ return 0;
+ }
+
+ ZEN_CONSOLE("{:>5} {:>6} {:>24}", "port", "pid", "session");
+ State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) {
+ StringBuilder<25> SessionStringBuilder;
+ Entry.GetSessionId().ToString(SessionStringBuilder);
+ ZEN_CONSOLE("{:>5} {:>6} {:>24}", Entry.EffectiveListenPort, Entry.Pid, SessionStringBuilder.ToString());
+ });
+
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/status.h b/src/zen/cmds/status.h
new file mode 100644
index 000000000..98f72e651
--- /dev/null
+++ b/src/zen/cmds/status.h
@@ -0,0 +1,22 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+class StatusCommand : public ZenCmdBase
+{
+public:
+ StatusCommand();
+ ~StatusCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"status", "Show zen status"};
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/top.cpp b/src/zen/cmds/top.cpp
new file mode 100644
index 000000000..4fe8c9cdf
--- /dev/null
+++ b/src/zen/cmds/top.cpp
@@ -0,0 +1,89 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "top.h"
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/uid.h>
+#include <zenutil/zenserverprocess.h>
+
+#include <memory>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+TopCommand::TopCommand()
+{
+}
+
+TopCommand::~TopCommand() = default;
+
+int
+TopCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ ZenServerState State;
+ if (!State.InitializeReadOnly())
+ {
+ ZEN_CONSOLE("no Zen state found");
+
+ return 0;
+ }
+
+ int n = 0;
+ const int HeaderPeriod = 20;
+
+ for (;;)
+ {
+ if ((n++ % HeaderPeriod) == 0)
+ {
+ ZEN_CONSOLE("{:>5} {:>6} {:>24}", "port", "pid", "session");
+ }
+
+ State.Snapshot([&](const ZenServerState::ZenServerEntry& Entry) {
+ StringBuilder<25> SessionStringBuilder;
+ Entry.GetSessionId().ToString(SessionStringBuilder);
+ ZEN_CONSOLE("{:>5} {:>6} {:>24}", Entry.EffectiveListenPort, Entry.Pid, SessionStringBuilder.ToString());
+ });
+
+ zen::Sleep(1000);
+
+ if (!State.IsReadOnly())
+ {
+ State.Sweep();
+ }
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+PsCommand::PsCommand()
+{
+}
+
+PsCommand::~PsCommand() = default;
+
+int
+PsCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ ZenServerState State;
+ if (!State.InitializeReadOnly())
+ {
+ ZEN_CONSOLE("no Zen state found");
+
+ return 0;
+ }
+
+ State.Snapshot(
+ [&](const ZenServerState::ZenServerEntry& Entry) { ZEN_CONSOLE("Port {} : pid {}", Entry.EffectiveListenPort, Entry.Pid); });
+
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/top.h b/src/zen/cmds/top.h
new file mode 100644
index 000000000..83410587b
--- /dev/null
+++ b/src/zen/cmds/top.h
@@ -0,0 +1,35 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+class TopCommand : public ZenCmdBase
+{
+public:
+ TopCommand();
+ ~TopCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"top", "Show dev UI"};
+};
+
+class PsCommand : public ZenCmdBase
+{
+public:
+ PsCommand();
+ ~PsCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"ps", "Enumerate running Zen server instances"};
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/up.cpp b/src/zen/cmds/up.cpp
new file mode 100644
index 000000000..69bcbe829
--- /dev/null
+++ b/src/zen/cmds/up.cpp
@@ -0,0 +1,108 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "up.h"
+
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zenutil/zenserverprocess.h>
+
+#include <memory>
+
+namespace zen {
+
+UpCommand::UpCommand()
+{
+}
+
+UpCommand::~UpCommand() = default;
+
+int
+UpCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions, argc, argv);
+
+ std::filesystem::path ExePath = zen::GetRunningExecutablePath();
+
+ ZenServerEnvironment ServerEnvironment;
+ ServerEnvironment.Initialize(ExePath.parent_path());
+ ZenServerInstance Server(ServerEnvironment);
+ Server.SpawnServer();
+
+ int Timeout = 10000;
+
+ if (!Server.WaitUntilReady(Timeout))
+ {
+ ZEN_ERROR("zen server launch failed (timed out)");
+ }
+ else
+ {
+ ZEN_CONSOLE("zen server up");
+ }
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+DownCommand::DownCommand()
+{
+ m_Options.add_option("", "p", "port", "Host port", cxxopts::value(m_Port)->default_value("1337"), "<hostport>");
+}
+
+DownCommand::~DownCommand() = default;
+
+int
+DownCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+ // Discover executing instances
+
+ ZenServerState Instance;
+ Instance.Initialize();
+ ZenServerState::ZenServerEntry* Entry = Instance.Lookup(m_Port);
+
+ if (!Entry)
+ {
+ ZEN_WARN("no zen server to bring down");
+
+ return 0;
+ }
+
+ try
+ {
+ std::filesystem::path ExePath = zen::GetRunningExecutablePath();
+
+ ZenServerEnvironment ServerEnvironment;
+ ServerEnvironment.Initialize(ExePath.parent_path());
+ ZenServerInstance Server(ServerEnvironment);
+ Server.AttachToRunningServer(m_Port);
+
+ ZEN_CONSOLE("attached to server on port {}, requesting shutdown", m_Port);
+
+ Server.Shutdown();
+
+ ZEN_CONSOLE("shutdown complete");
+
+ return 0;
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_DEBUG("Exception caught when requesting shutdown: {}", Ex.what());
+ }
+
+ // Since we cannot obtain a handle to the process we are unable to block on the process
+ // handle to determine when the server has shut down. Thus we signal that we would like
+ // a shutdown via the shutdown flag and then
+
+ ZEN_CONSOLE("requesting shutdown of server on port {}", m_Port);
+ Entry->SignalShutdownRequest();
+
+ return 0;
+}
+
+} // namespace zen
diff --git a/src/zen/cmds/up.h b/src/zen/cmds/up.h
new file mode 100644
index 000000000..5af05541a
--- /dev/null
+++ b/src/zen/cmds/up.h
@@ -0,0 +1,36 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+class UpCommand : public ZenCmdBase
+{
+public:
+ UpCommand();
+ ~UpCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"up", "Bring up zen service"};
+};
+
+class DownCommand : public ZenCmdBase
+{
+public:
+ DownCommand();
+ ~DownCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"down", "Bring down zen service"};
+ uint16_t m_Port;
+};
+
+} // namespace zen
diff --git a/src/zen/cmds/version.cpp b/src/zen/cmds/version.cpp
new file mode 100644
index 000000000..ba83b527d
--- /dev/null
+++ b/src/zen/cmds/version.cpp
@@ -0,0 +1,79 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "version.h"
+
+#include <zencore/config.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zenhttp/httpcommon.h>
+#include <zenutil/zenserverprocess.h>
+
+#include <memory>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+VersionCommand::VersionCommand()
+{
+ m_Options.add_options()("h,help", "Print help");
+ m_Options.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName), "[hosturl]");
+ m_Options.add_option("", "d", "detailed", "Detailed Version", cxxopts::value(m_DetailedVersion), "[detailedversion]");
+ m_Options.parse_positional({"hosturl"});
+}
+
+VersionCommand::~VersionCommand() = default;
+
+int
+VersionCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
+{
+ ZEN_UNUSED(GlobalOptions);
+ if (!ParseOptions(argc, argv))
+ {
+ return 0;
+ }
+
+ std::string Version;
+
+ if (m_HostName.empty())
+ {
+ if (m_DetailedVersion)
+ {
+ Version = ZEN_CFG_VERSION_BUILD_STRING_FULL;
+ }
+ else
+ {
+ Version = ZEN_CFG_VERSION;
+ }
+ }
+ else
+ {
+ const std::string UrlBase = fmt::format("{}/health", m_HostName);
+ cpr::Session Session;
+ std::string VersionRequest = fmt::format("{}/version{}", UrlBase, m_DetailedVersion ? "?detailed=true" : "");
+ Session.SetUrl(VersionRequest);
+ cpr::Response Response = Session.Get();
+ if (!zen::IsHttpSuccessCode(Response.status_code))
+ {
+ if (Response.status_code)
+ {
+ ZEN_ERROR("{} failed: {}: {} ({})", VersionRequest, Response.status_code, Response.reason, Response.text);
+ }
+ else
+ {
+ ZEN_ERROR("{} failed: {}", VersionRequest, Response.error.message);
+ }
+
+ return 1;
+ }
+ Version = Response.text;
+ }
+
+ zen::ConsoleLog().info("{}", Version);
+
+ return 0;
+}
+} // namespace zen
diff --git a/src/zen/cmds/version.h b/src/zen/cmds/version.h
new file mode 100644
index 000000000..0e37e91a0
--- /dev/null
+++ b/src/zen/cmds/version.h
@@ -0,0 +1,24 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "../zen.h"
+
+namespace zen {
+
+class VersionCommand : public ZenCmdBase
+{
+public:
+ VersionCommand();
+ ~VersionCommand();
+
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override;
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"version", "Get zen service version"};
+ std::string m_HostName;
+ bool m_DetailedVersion;
+};
+
+} // namespace zen
diff --git a/src/zen/internalfile.cpp b/src/zen/internalfile.cpp
new file mode 100644
index 000000000..2ade86e29
--- /dev/null
+++ b/src/zen/internalfile.cpp
@@ -0,0 +1,299 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "internalfile.h"
+
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+# include <fcntl.h>
+# include <sys/file.h>
+# include <sys/mman.h>
+# include <sys/stat.h>
+#endif
+
+#include <gsl/gsl-lite.hpp>
+
+#define ZEN_USE_SLIST ZEN_PLATFORM_WINDOWS
+
+#if ZEN_USE_SLIST == 0
+struct FileBufferManager::Impl
+{
+ zen::RwLock m_Lock;
+ std::list<zen::IoBuffer> m_FreeBuffers;
+
+ uint64_t m_BufferSize;
+ uint64_t m_MaxBufferCount;
+
+ Impl(uint64_t BufferSize, uint64_t MaxBuffers) : m_BufferSize(BufferSize), m_MaxBufferCount(MaxBuffers) {}
+
+ zen::IoBuffer AllocBuffer()
+ {
+ zen::RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (m_FreeBuffers.empty())
+ {
+ return zen::IoBuffer{m_BufferSize, 64 * 1024};
+ }
+ else
+ {
+ zen::IoBuffer Buffer = std::move(m_FreeBuffers.front());
+ m_FreeBuffers.pop_front();
+ return Buffer;
+ }
+ }
+
+ void ReturnBuffer(zen::IoBuffer Buffer)
+ {
+ zen::RwLock::ExclusiveLockScope _(m_Lock);
+
+ m_FreeBuffers.push_front(std::move(Buffer));
+ }
+};
+#else
+struct FileBufferManager::Impl
+{
+ struct BufferItem
+ {
+ SLIST_ENTRY ItemEntry;
+ zen::IoBuffer Buffer;
+ };
+
+ SLIST_HEADER m_FreeList;
+ uint64_t m_BufferSize;
+ uint64_t m_MaxBufferCount;
+
+ Impl(uint64_t BufferSize, uint64_t MaxBuffers) : m_BufferSize(BufferSize), m_MaxBufferCount(MaxBuffers)
+ {
+ InitializeSListHead(&m_FreeList);
+ }
+
+ ~Impl()
+ {
+ while (SLIST_ENTRY* Entry = InterlockedPopEntrySList(&m_FreeList))
+ {
+ BufferItem* Item = reinterpret_cast<BufferItem*>(Entry);
+ delete Item;
+ }
+ }
+
+ zen::IoBuffer AllocBuffer()
+ {
+ SLIST_ENTRY* Entry = InterlockedPopEntrySList(&m_FreeList);
+
+ if (Entry == nullptr)
+ {
+ return zen::IoBuffer{m_BufferSize, 64 * 1024};
+ }
+ else
+ {
+ BufferItem* Item = reinterpret_cast<BufferItem*>(Entry);
+ zen::IoBuffer Buffer = std::move(Item->Buffer);
+ delete Item; // Todo: could keep this around in another list
+
+ return Buffer;
+ }
+ }
+
+ void ReturnBuffer(zen::IoBuffer Buffer)
+ {
+ BufferItem* Item = new BufferItem{nullptr, std::move(Buffer)};
+
+ InterlockedPushEntrySList(&m_FreeList, &Item->ItemEntry);
+ }
+};
+#endif
+
+FileBufferManager::FileBufferManager(uint64_t BufferSize, uint64_t MaxBuffers)
+{
+ m_Impl = new Impl{BufferSize, MaxBuffers};
+}
+
+FileBufferManager::~FileBufferManager()
+{
+ delete m_Impl;
+}
+
+zen::IoBuffer
+FileBufferManager::AllocBuffer()
+{
+ return m_Impl->AllocBuffer();
+}
+
+void
+FileBufferManager::ReturnBuffer(zen::IoBuffer Buffer)
+{
+ return m_Impl->ReturnBuffer(Buffer);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+InternalFile::InternalFile()
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+: m_File(nullptr)
+, m_Mmap(nullptr)
+#endif
+{
+}
+
+InternalFile::~InternalFile()
+{
+ if (m_Memory)
+ zen::Memory::Free(m_Memory);
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ if (m_Mmap)
+ munmap(m_Mmap, GetFileSize());
+ if (m_File)
+ close(int(intptr_t(m_File)));
+#endif
+}
+
+size_t
+InternalFile::GetFileSize()
+{
+#if ZEN_PLATFORM_WINDOWS
+ ULONGLONG sz;
+ m_File.GetSize(sz);
+ return size_t(sz);
+#else
+ int Fd = int(intptr_t(m_File));
+ static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
+ struct stat Stat;
+ fstat(Fd, &Stat);
+ return size_t(Stat.st_size);
+#endif
+}
+
+void
+InternalFile::OpenWrite(std::filesystem::path FileName, bool IsCreate)
+{
+ bool Success = false;
+
+#if ZEN_PLATFORM_WINDOWS
+ const DWORD dwCreationDisposition = IsCreate ? CREATE_ALWAYS : OPEN_EXISTING;
+
+ HRESULT hRes = m_File.Create(FileName.c_str(), GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, dwCreationDisposition);
+ Success = SUCCEEDED(hRes);
+#else
+ int OpenFlags = O_RDWR | O_CLOEXEC;
+ OpenFlags |= IsCreate ? O_CREAT | O_TRUNC : 0;
+
+ int Fd = open(FileName.c_str(), OpenFlags, 0666);
+ if (Fd >= 0)
+ {
+ if (IsCreate)
+ {
+ fchmod(Fd, 0666);
+ }
+ Success = true;
+ m_File = (void*)(intptr_t(Fd));
+ }
+#endif // ZEN_PLATFORM_WINDOWS
+
+ if (Success)
+ {
+ zen::ThrowLastError(fmt::format("Failed to open file for writing: '{}'", FileName));
+ }
+}
+
+void
+InternalFile::OpenRead(std::filesystem::path FileName)
+{
+ bool Success = false;
+
+#if ZEN_PLATFORM_WINDOWS
+ const DWORD dwCreationDisposition = OPEN_EXISTING;
+
+ HRESULT hRes = m_File.Create(FileName.c_str(), GENERIC_READ, FILE_SHARE_READ, dwCreationDisposition);
+ Success = SUCCEEDED(hRes);
+#else
+ int Fd = open(FileName.c_str(), O_RDONLY);
+ if (Fd >= 0)
+ {
+ Success = true;
+ m_File = (void*)(intptr_t(Fd));
+ }
+#endif
+
+ if (Success)
+ {
+ zen::ThrowLastError(fmt::format("Failed to open file for reading: '{}'", FileName));
+ }
+}
+
+const void*
+InternalFile::MemoryMapFile()
+{
+ auto FileSize = GetFileSize();
+
+ if (FileSize <= 100 * 1024 * 1024)
+ {
+ m_Memory = zen::Memory::Alloc(FileSize, 64);
+ Read(m_Memory, FileSize, 0);
+
+ return m_Memory;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ m_Mmap.MapFile(m_File);
+ return m_Mmap.GetData();
+#else
+ int Fd = int(intptr_t(m_File));
+ m_Mmap = mmap(nullptr, FileSize, PROT_READ, MAP_PRIVATE, Fd, 0);
+ return m_Mmap;
+#endif
+}
+
+void
+InternalFile::Read(void* Data, uint64_t Size, uint64_t Offset)
+{
+ bool Success;
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED ovl{};
+
+ ovl.Offset = DWORD(Offset & 0xffff'ffffu);
+ ovl.OffsetHigh = DWORD(Offset >> 32);
+
+ HRESULT hRes = m_File.Read(Data, gsl::narrow<DWORD>(Size), &ovl);
+ Success = SUCCEEDED(hRes);
+#else
+ int Fd = int(intptr_t(m_File));
+ int BytesRead = pread(Fd, Data, Size, Offset);
+ Success = (BytesRead > 0);
+#endif
+
+ if (Success)
+ {
+ zen::ThrowLastError(fmt::format("Failed to read from file '{}'", "")); // zen::PathFromHandle(m_File)));
+ }
+}
+
+void
+InternalFile::Write(const void* Data, uint64_t Size, uint64_t Offset)
+{
+ bool Success;
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ Ovl.Offset = DWORD(Offset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(Offset >> 32);
+
+ HRESULT hRes = m_File.Write(Data, gsl::narrow<DWORD>(Size), &Ovl);
+ Success = SUCCEEDED(hRes);
+#else
+ int Fd = int(intptr_t(m_File));
+ int BytesWritten = pwrite(Fd, Data, Size, Offset);
+ Success = (BytesWritten > 0);
+#endif
+
+ if (Success)
+ {
+ zen::ThrowLastError(fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_File)));
+ }
+}
diff --git a/src/zen/internalfile.h b/src/zen/internalfile.h
new file mode 100644
index 000000000..8acb600ff
--- /dev/null
+++ b/src/zen/internalfile.h
@@ -0,0 +1,62 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <zencore/iobuffer.h>
+#include <zencore/refcount.h>
+#include <zencore/thread.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+# include <atlfile.h>
+#endif
+
+#include <filesystem>
+#include <list>
+
+//////////////////////////////////////////////////////////////////////////
+
+class FileBufferManager : public zen::RefCounted
+{
+public:
+ FileBufferManager(uint64_t BufferSize, uint64_t MaxBufferCount);
+ ~FileBufferManager();
+
+ zen::IoBuffer AllocBuffer();
+ void ReturnBuffer(zen::IoBuffer Buffer);
+
+private:
+ struct Impl;
+
+ Impl* m_Impl;
+};
+
+class InternalFile : public zen::RefCounted
+{
+public:
+ InternalFile();
+ ~InternalFile();
+
+ void OpenRead(std::filesystem::path FileName);
+ void Read(void* Data, uint64_t Size, uint64_t Offset);
+
+ void OpenWrite(std::filesystem::path FileName, bool isCreate);
+ void Write(const void* Data, uint64_t Size, uint64_t Offset);
+
+ const void* MemoryMapFile();
+ size_t GetFileSize();
+
+private:
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ using CAtlFile = void*;
+ using CAtlFileMappingBase = void*;
+#endif
+ CAtlFile m_File;
+ CAtlFileMappingBase m_Mmap;
+ void* m_Memory = nullptr;
+};
diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua
new file mode 100644
index 000000000..b83999efc
--- /dev/null
+++ b/src/zen/xmake.lua
@@ -0,0 +1,31 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("zen")
+ set_kind("binary")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_files("zen.cpp", {unity_ignored = true })
+ add_deps("zencore", "zenhttp", "zenutil")
+ add_includedirs(".")
+ set_symbols("debug")
+
+ if is_mode("release") then
+ set_optimize("fastest")
+ end
+
+ if is_plat("windows") then
+ add_files("zen.rc")
+ add_ldflags("/subsystem:console,5.02")
+ add_ldflags("/LTCG")
+ add_ldflags("crypt32.lib", "wldap32.lib", "Ws2_32.lib")
+ end
+
+ if is_plat("macosx") then
+ add_ldflags("-framework CoreFoundation")
+ add_ldflags("-framework Security")
+ add_ldflags("-framework SystemConfiguration")
+ add_syslinks("bsm")
+ end
+
+ add_packages("vcpkg::zstd")
+ add_packages("vcpkg::cxxopts", "vcpkg::mimalloc")
diff --git a/src/zen/zen.cpp b/src/zen/zen.cpp
new file mode 100644
index 000000000..9754f4434
--- /dev/null
+++ b/src/zen/zen.cpp
@@ -0,0 +1,421 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+// Zen command line client utility
+//
+
+#include "zen.h"
+
+#include "chunk/chunk.h"
+#include "cmds/cache.h"
+#include "cmds/copy.h"
+#include "cmds/dedup.h"
+#include "cmds/hash.h"
+#include "cmds/print.h"
+#include "cmds/projectstore.h"
+#include "cmds/rpcreplay.h"
+#include "cmds/scrub.h"
+#include "cmds/status.h"
+#include "cmds/top.h"
+#include "cmds/up.h"
+#include "cmds/version.h"
+
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+
+#include <zenhttp/httpcommon.h>
+
+#if ZEN_WITH_TESTS
+# define ZEN_TEST_WITH_RUNNER 1
+# include <zencore/testing.h>
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_USE_MIMALLOC
+# include <mimalloc-new-delete.h>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+ZenCmdBase::ParseOptions(int argc, char** argv)
+{
+ cxxopts::Options& CmdOptions = Options();
+ cxxopts::ParseResult Result = CmdOptions.parse(argc, argv);
+
+ if (Result.count("help"))
+ {
+ printf("%s\n", CmdOptions.help().c_str());
+ return false;
+ }
+
+ if (!Result.unmatched().empty())
+ {
+ zen::ExtendableStringBuilder<64> StringBuilder;
+ for (bool First = true; const auto& Param : Result.unmatched())
+ {
+ if (!First)
+ {
+ StringBuilder.Append(", ");
+ }
+ StringBuilder.Append('"');
+ StringBuilder.Append(Param);
+ StringBuilder.Append('"');
+ First = false;
+ }
+
+ throw cxxopts::OptionParseException(fmt::format("Invalid arguments: {}", StringBuilder.ToView()));
+ }
+
+ return true;
+}
+
+std::string
+ZenCmdBase::FormatHttpResponse(const cpr::Response& Response)
+{
+ if (Response.error.code != cpr::ErrorCode::OK)
+ {
+ if (Response.error.message.empty())
+ {
+ return fmt::format("Request '{}' failed, error code {}", Response.url.str(), static_cast<int>(Response.error.code));
+ }
+ return fmt::format("Request '{}' failed. Reason: '{}' ({})",
+ Response.url.str(),
+ Response.error.message,
+ static_cast<int>(Response.error.code));
+ }
+
+ std::string Content;
+ if (auto It = Response.header.find("Content-Type"); It != Response.header.end())
+ {
+ zen::HttpContentType ContentType = zen::ParseContentType(It->second);
+ if (ContentType == zen::HttpContentType::kText)
+ {
+ Content = fmt::format("'{}'", Response.text);
+ }
+ else if (ContentType == zen::HttpContentType::kJSON)
+ {
+ Content = fmt::format("\n{}", Response.text);
+ }
+ else if (!Response.text.empty())
+ {
+ Content = fmt::format("[{}]", MapContentTypeToString(ContentType));
+ }
+ }
+
+ std::string_view ResponseString = zen::ReasonStringForHttpResultCode(
+ Response.status_code == static_cast<long>(zen::HttpResponseCode::NoContent) ? static_cast<long>(zen::HttpResponseCode::OK)
+ : Response.status_code);
+ if (Content.empty())
+ {
+ return std::string(ResponseString);
+ }
+
+ return fmt::format("{}: {}", ResponseString, Content);
+}
+
+int
+ZenCmdBase::MapHttpToCommandReturnCode(const cpr::Response& Response)
+{
+ if (zen::IsHttpSuccessCode(Response.status_code))
+ {
+ return 0;
+ }
+ if (Response.error.code != cpr::ErrorCode::OK)
+ {
+ return static_cast<int>(Response.error.code);
+ }
+ return 1;
+}
+
+#if ZEN_WITH_TESTS
+
+class RunTestsCommand : public ZenCmdBase
+{
+public:
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) override
+ {
+ ZEN_UNUSED(GlobalOptions);
+
+ // Set output mode to handle virtual terminal sequences
+# if ZEN_PLATFORM_WINDOWS
+ HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE);
+ if (hOut == INVALID_HANDLE_VALUE)
+ return GetLastError();
+
+ DWORD dwMode = 0;
+ if (!GetConsoleMode(hOut, &dwMode))
+ return GetLastError();
+
+ dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING;
+ if (!SetConsoleMode(hOut, dwMode))
+ return GetLastError();
+# endif // ZEN_PLATFORM_WINDOWS
+
+ return ZEN_RUN_TESTS(argc, argv);
+ }
+
+ virtual cxxopts::Options& Options() override { return m_Options; }
+
+private:
+ cxxopts::Options m_Options{"runtests", "Run tests"};
+};
+
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// TODO: should make this Unicode-aware so we can pass anything in on the
+// command line.
+
+int
+main(int argc, char** argv)
+{
+ using namespace zen;
+
+#if ZEN_USE_MIMALLOC
+ mi_version();
+#endif
+
+ zen::logging::InitializeLogging();
+ zen::MaximizeOpenFileCount();
+
+ //////////////////////////////////////////////////////////////////////////
+
+ auto _ = zen::MakeGuard([] { spdlog::shutdown(); });
+
+ CacheInfoCommand CacheInfoCmd;
+ CopyCommand CopyCmd;
+ CreateOplogCommand CreateOplogCmd;
+ CreateProjectCommand CreateProjectCmd;
+ DedupCommand DedupCmd;
+ DownCommand DownCmd;
+ DropCommand DropCmd;
+ DropProjectCommand ProjectDropCmd;
+ ExportOplogCommand ExportOplogCmd;
+ GcCommand GcCmd;
+ GcStatusCommand GcStatusCmd;
+ HashCommand HashCmd;
+ ImportOplogCommand ImportOplogCmd;
+ PrintCommand PrintCmd;
+ PrintPackageCommand PrintPkgCmd;
+ ProjectInfoCommand ProjectInfoCmd;
+ PsCommand PsCmd;
+ RpcReplayCommand RpcReplayCmd;
+ RpcStartRecordingCommand RpcStartRecordingCmd;
+ RpcStopRecordingCommand RpcStopRecordingCmd;
+ StatusCommand StatusCmd;
+ TopCommand TopCmd;
+ UpCommand UpCmd;
+ VersionCommand VersionCmd;
+ CacheStatsCommand CacheStatsCmd;
+ CacheDetailsCommand CacheDetailsCmd;
+ ProjectStatsCommand ProjectStatsCmd;
+ ProjectDetailsCommand ProjectDetailsCmd;
+#if ZEN_WITH_TESTS
+ RunTestsCommand RunTestsCmd;
+#endif
+
+ const struct CommandInfo
+ {
+ const char* CmdName;
+ ZenCmdBase* Cmd;
+ const char* CmdSummary;
+ } Commands[] = {
+ // clang-format off
+// {"chunk", &ChunkCmd, "Perform chunking"},
+ {"cache-info", &CacheInfoCmd, "Info on cache, namespace or bucket"},
+ {"copy", &CopyCmd, "Copy file(s)"},
+ {"dedup", &DedupCmd, "Dedup files"},
+ {"down", &DownCmd, "Bring zen server down"},
+ {"drop", &DropCmd, "Drop cache namespace or bucket"},
+ {"gc-status", &GcStatusCmd, "Garbage collect zen storage status check"},
+ {"gc", &GcCmd, "Garbage collect zen storage"},
+ {"hash", &HashCmd, "Compute file hashes"},
+ {"oplog-create", &CreateOplogCmd, "Create a project oplog"},
+ {"oplog-export", &ExportOplogCmd, "Export project store oplog"},
+ {"oplog-import", &ImportOplogCmd, "Import project store oplog"},
+ {"print", &PrintCmd, "Print compact binary object"},
+ {"printpackage", &PrintPkgCmd, "Print compact binary package"},
+ {"project-create", &CreateProjectCmd, "Create a project"},
+ {"project-drop", &ProjectDropCmd, "Drop project or project oplog"},
+ {"project-info", &ProjectInfoCmd, "Info on project or project oplog"},
+ {"ps", &PsCmd, "Enumerate running zen server instances"},
+ {"rpc-record-replay", &RpcReplayCmd, "Stops recording of cache rpc requests on a host"},
+ {"rpc-record-start", &RpcStartRecordingCmd, "Replays a previously recorded session of rpc requests"},
+ {"rpc-record-stop", &RpcStopRecordingCmd, "Starts recording of cache rpc requests on a host"},
+ {"status", &StatusCmd, "Show zen status"},
+ {"top", &TopCmd, "Monitor zen server activity"},
+ {"up", &UpCmd, "Bring zen server up"},
+ {"version", &VersionCmd, "Get zen server version"},
+ {"cache-stats", &CacheStatsCmd, "Stats on cache"},
+ {"cache-details", &CacheDetailsCmd, "Details on cache"},
+ {"project-stats", &ProjectStatsCmd, "Stats on project store"},
+ {"project-details", &ProjectDetailsCmd, "Details on project store"},
+#if ZEN_WITH_TESTS
+ {"runtests", &RunTestsCmd, "Run zen tests"},
+#endif
+ // clang-format on
+ };
+
+ // Build set containing available commands
+
+ std::unordered_set<std::string> CommandSet;
+
+ for (const auto& Cmd : Commands)
+ CommandSet.insert(Cmd.CmdName);
+
+ // Split command line into options, commands and any pass-through arguments
+
+ std::string Passthrough;
+ std::vector<std::string> PassthroughV;
+
+ for (int i = 1; i < argc; ++i)
+ {
+ if (strcmp(argv[i], "--") == 0)
+ {
+ bool IsFirst = true;
+ zen::ExtendableStringBuilder<256> Line;
+
+ for (int j = i + 1; j < argc; ++j)
+ {
+ if (!IsFirst)
+ {
+ Line.AppendAscii(" ");
+ }
+
+ std::string_view ThisArg(argv[j]);
+ PassthroughV.push_back(std::string(ThisArg));
+
+ const bool NeedsQuotes = (ThisArg.find(' ') != std::string_view::npos);
+
+ if (NeedsQuotes)
+ {
+ Line.AppendAscii("\"");
+ }
+
+ Line.Append(ThisArg);
+
+ if (NeedsQuotes)
+ {
+ Line.AppendAscii("\"");
+ }
+
+ IsFirst = false;
+ }
+
+ Passthrough = Line.c_str();
+
+ // This will "truncate" the arg vector and terminate the loop
+ argc = i - 1;
+ }
+ }
+
+ // Split command line into global vs command options. We do this by simply
+ // scanning argv for a string we recognise as a command and split it there
+
+ std::vector<char*> CommandArgVec;
+ CommandArgVec.push_back(argv[0]);
+
+ for (int i = 1; i < argc; ++i)
+ {
+ if (CommandSet.find(argv[i]) != CommandSet.end())
+ {
+ int commandArgCount = /* exec name */ 1 + argc - (i + 1);
+ CommandArgVec.resize(commandArgCount);
+ std::copy(argv + i + 1, argv + argc, CommandArgVec.begin() + 1);
+
+ argc = i + 1;
+
+ break;
+ }
+ }
+
+ // Parse global CLI arguments
+
+ ZenCliOptions GlobalOptions;
+
+ GlobalOptions.PassthroughArgs = Passthrough;
+ GlobalOptions.PassthroughV = PassthroughV;
+
+ std::string SubCommand = "<None>";
+
+ cxxopts::Options Options("zen", "Zen management tool");
+
+ Options.add_options()("d, debug", "Enable debugging", cxxopts::value<bool>(GlobalOptions.IsDebug));
+ Options.add_options()("v, verbose", "Enable verbose logging", cxxopts::value<bool>(GlobalOptions.IsVerbose));
+ Options.add_options()("help", "Show command line help");
+ Options.add_options()("c, command", "Sub command", cxxopts::value<std::string>(SubCommand));
+
+ Options.parse_positional({"command"});
+
+ const bool IsNullInvoke = (argc == 1); // If no arguments are passed we want to print usage information
+
+ try
+ {
+ auto ParseResult = Options.parse(argc, argv);
+
+ if (ParseResult.count("help") || IsNullInvoke == 1)
+ {
+ std::string Help = Options.help();
+
+ printf("%s\n", Help.c_str());
+
+ printf("available commands:\n");
+
+ for (const auto& CmdInfo : Commands)
+ {
+ printf(" %-20s %s\n", CmdInfo.CmdName, CmdInfo.CmdSummary);
+ }
+
+ exit(0);
+ }
+
+ if (GlobalOptions.IsDebug)
+ {
+ spdlog::set_level(spdlog::level::debug);
+ }
+
+ for (const CommandInfo& CmdInfo : Commands)
+ {
+ if (StrCaseCompare(SubCommand.c_str(), CmdInfo.CmdName) == 0)
+ {
+ cxxopts::Options& VerbOptions = CmdInfo.Cmd->Options();
+ try
+ {
+ return CmdInfo.Cmd->Run(GlobalOptions, (int)CommandArgVec.size(), CommandArgVec.data());
+ }
+ catch (cxxopts::OptionParseException& Ex)
+ {
+ std::string help = VerbOptions.help();
+
+ printf("Error parsing arguments for command '%s': %s\n\n%s", SubCommand.c_str(), Ex.what(), help.c_str());
+
+ exit(11);
+ }
+ }
+ }
+
+ printf("Unknown command specified: '%s', exiting\n", SubCommand.c_str());
+ }
+ catch (cxxopts::OptionParseException& Ex)
+ {
+ std::string HelpMessage = Options.help();
+
+ printf("Error parsing program arguments: %s\n\n%s", Ex.what(), HelpMessage.c_str());
+
+ return 9;
+ }
+ catch (std::exception& Ex)
+ {
+ printf("Exception caught from 'main': %s\n", Ex.what());
+
+ return 10;
+ }
+
+ return 0;
+}
diff --git a/src/zen/zen.h b/src/zen/zen.h
new file mode 100644
index 000000000..b55e7a16c
--- /dev/null
+++ b/src/zen/zen.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cxxopts.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+namespace cpr {
+class Response;
+}
+
+struct ZenCliOptions
+{
+ bool IsDebug = false;
+ bool IsVerbose = false;
+
+ // Arguments after " -- " on command line are passed through and not parsed
+ std::string PassthroughArgs;
+ std::vector<std::string> PassthroughV;
+};
+
+class ZenCmdBase
+{
+public:
+ virtual int Run(const ZenCliOptions& GlobalOptions, int argc, char** argv) = 0;
+ virtual cxxopts::Options& Options() = 0;
+
+ bool ParseOptions(int argc, char** argv);
+ static std::string FormatHttpResponse(const cpr::Response& Response);
+ static int MapHttpToCommandReturnCode(const cpr::Response& Response);
+};
diff --git a/src/zen/zen.rc b/src/zen/zen.rc
new file mode 100644
index 000000000..14a9afb70
--- /dev/null
+++ b/src/zen/zen.rc
@@ -0,0 +1,33 @@
+#include "zencore/config.h"
+
+#define APSTUDIO_READONLY_SYMBOLS
+#include "winres.h"
+#undef APSTUDIO_READONLY_SYMBOLS
+
+LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US
+#pragma code_page(1252)
+
+101 ICON "..\\UnrealEngine.ico"
+
+VS_VERSION_INFO VERSIONINFO
+FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0
+PRODUCTVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0
+{
+ BLOCK "StringFileInfo"
+ {
+ BLOCK "040904b0"
+ {
+ VALUE "CompanyName", "Epic Games Inc\0"
+ VALUE "FileDescription", "CLI utility for Zen Storage Service\0"
+ VALUE "FileVersion", ZEN_CFG_VERSION "\0"
+ VALUE "LegalCopyright", "Copyright Epic Games Inc. All Rights Reserved\0"
+ VALUE "OriginalFilename", "zen.exe\0"
+ VALUE "ProductName", "Zen Storage Server\0"
+ VALUE "ProductVersion", ZEN_CFG_VERSION_BUILD_STRING_FULL "\0"
+ }
+ }
+ BLOCK "VarFileInfo"
+ {
+ VALUE "Translation", 0x409, 1200
+ }
+}
diff --git a/src/zencore-test/targetver.h b/src/zencore-test/targetver.h
new file mode 100644
index 000000000..d432d6993
--- /dev/null
+++ b/src/zencore-test/targetver.h
@@ -0,0 +1,10 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+// Including SDKDDKVer.h defines the highest available Windows platform.
+
+// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and
+// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h.
+
+#include <SDKDDKVer.h>
diff --git a/src/zencore-test/xmake.lua b/src/zencore-test/xmake.lua
new file mode 100644
index 000000000..74c7e74a7
--- /dev/null
+++ b/src/zencore-test/xmake.lua
@@ -0,0 +1,8 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("zencore-test")
+ set_kind("binary")
+ add_headerfiles("**.h")
+ add_files("*.cpp")
+ add_deps("zencore")
+ add_packages("vcpkg::doctest")
diff --git a/src/zencore-test/zencore-test.cpp b/src/zencore-test/zencore-test.cpp
new file mode 100644
index 000000000..53413fb25
--- /dev/null
+++ b/src/zencore-test/zencore-test.cpp
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+// zencore-test.cpp : Defines the entry point for the console application.
+//
+
+#include <zencore/logging.h>
+#include <zencore/zencore.h>
+
+#if ZEN_WITH_TESTS
+# define ZEN_TEST_WITH_RUNNER 1
+# include <zencore/testing.h>
+#endif
+
+int
+main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
+{
+#if ZEN_WITH_TESTS
+ zen::zencore_forcelinktests();
+
+ zen::logging::InitializeLogging();
+
+ return ZEN_RUN_TESTS(argc, argv);
+#else
+ return 0;
+#endif
+}
diff --git a/src/zencore/.gitignore b/src/zencore/.gitignore
new file mode 100644
index 000000000..77d39c17e
--- /dev/null
+++ b/src/zencore/.gitignore
@@ -0,0 +1 @@
+include/zencore/config.h
diff --git a/src/zencore/base64.cpp b/src/zencore/base64.cpp
new file mode 100644
index 000000000..b97dfebbf
--- /dev/null
+++ b/src/zencore/base64.cpp
@@ -0,0 +1,107 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/base64.h>
+
+namespace zen {
+
+/** The table used to encode a 6 bit value as an ascii character */
+static const uint8_t EncodingAlphabet[64] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
+ 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
+ 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
+ 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};
+
+/** The table used to convert an ascii character into a 6 bit value */
+#if 0
+static const uint8_t DecodingAlphabet[256] = {
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x00-0x0f
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x10-0x1f
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x3F, // 0x20-0x2f
+ 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x30-0x3f
+ 0xFF, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, // 0x40-0x4f
+ 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x50-0x5f
+ 0xFF, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, // 0x60-0x6f
+ 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x70-0x7f
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x80-0x8f
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x90-0x9f
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xa0-0xaf
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xb0-0xbf
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xc0-0xcf
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xd0-0xdf
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0xe0-0xef
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF // 0xf0-0xff
+};
+#endif // 0
+
+template<typename CharType>
+uint32_t
+Base64::Encode(const uint8_t* Source, uint32_t Length, CharType* Dest)
+{
+ CharType* EncodedBytes = Dest;
+
+ // Loop through the buffer converting 3 bytes of binary data at a time
+ while (Length >= 3)
+ {
+ uint8_t A = *Source++;
+ uint8_t B = *Source++;
+ uint8_t C = *Source++;
+ Length -= 3;
+
+ // The algorithm takes 24 bits of data (3 bytes) and breaks it into 4 6bit chunks represented as ascii
+ uint32_t ByteTriplet = A << 16 | B << 8 | C;
+
+ // Use the 6bit block to find the representation ascii character for it
+ EncodedBytes[3] = EncodingAlphabet[ByteTriplet & 0x3F];
+ ByteTriplet >>= 6;
+ EncodedBytes[2] = EncodingAlphabet[ByteTriplet & 0x3F];
+ ByteTriplet >>= 6;
+ EncodedBytes[1] = EncodingAlphabet[ByteTriplet & 0x3F];
+ ByteTriplet >>= 6;
+ EncodedBytes[0] = EncodingAlphabet[ByteTriplet & 0x3F];
+
+ // Now we can append this buffer to our destination string
+ EncodedBytes += 4;
+ }
+
+ // Since this algorithm operates on blocks, we may need to pad the last chunks
+ if (Length > 0)
+ {
+ uint8_t A = *Source++;
+ uint8_t B = 0;
+ uint8_t C = 0;
+ // Grab the second character if it is a 2 uint8_t finish
+ if (Length == 2)
+ {
+ B = *Source;
+ }
+ uint32_t ByteTriplet = A << 16 | B << 8 | C;
+ // Pad with = to make a 4 uint8_t chunk
+ EncodedBytes[3] = '=';
+ ByteTriplet >>= 6;
+ // If there's only one 1 uint8_t left in the source, then you need 2 pad chars
+ if (Length == 1)
+ {
+ EncodedBytes[2] = '=';
+ }
+ else
+ {
+ EncodedBytes[2] = EncodingAlphabet[ByteTriplet & 0x3F];
+ }
+ // Now encode the remaining bits the same way
+ ByteTriplet >>= 6;
+ EncodedBytes[1] = EncodingAlphabet[ByteTriplet & 0x3F];
+ ByteTriplet >>= 6;
+ EncodedBytes[0] = EncodingAlphabet[ByteTriplet & 0x3F];
+
+ EncodedBytes += 4;
+ }
+
+ // Add a null terminator
+ *EncodedBytes = 0;
+
+ return uint32_t(EncodedBytes - Dest);
+}
+
+template ZENCORE_API uint32_t Base64::Encode<char>(const uint8_t* Source, uint32_t Length, char* Dest);
+template ZENCORE_API uint32_t Base64::Encode<wchar_t>(const uint8_t* Source, uint32_t Length, wchar_t* Dest);
+
+} // namespace zen
diff --git a/src/zencore/blake3.cpp b/src/zencore/blake3.cpp
new file mode 100644
index 000000000..89826ae5d
--- /dev/null
+++ b/src/zencore/blake3.cpp
@@ -0,0 +1,175 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/blake3.h>
+
+#include <zencore/compositebuffer.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/zencore.h>
+
+#include <string.h>
+
+#include "blake3.h"
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+void
+blake3_forcelink()
+{
+}
+
+BLAKE3 BLAKE3::Zero; // Initialized to all zeroes
+
+BLAKE3
+BLAKE3::HashMemory(const void* data, size_t byteCount)
+{
+ BLAKE3 b3;
+
+ blake3_hasher b3h;
+ blake3_hasher_init(&b3h);
+ blake3_hasher_update(&b3h, data, byteCount);
+ blake3_hasher_finalize(&b3h, b3.Hash, sizeof b3.Hash);
+
+ return b3;
+}
+
+BLAKE3
+BLAKE3::HashBuffer(const CompositeBuffer& Buffer)
+{
+ BLAKE3 Hash;
+
+ blake3_hasher Hasher;
+ blake3_hasher_init(&Hasher);
+
+ for (const SharedBuffer& Segment : Buffer.GetSegments())
+ {
+ blake3_hasher_update(&Hasher, Segment.GetData(), Segment.GetSize());
+ }
+
+ blake3_hasher_finalize(&Hasher, Hash.Hash, sizeof Hash.Hash);
+
+ return Hash;
+}
+
+BLAKE3
+BLAKE3::FromHexString(const char* string)
+{
+ BLAKE3 b3;
+
+ ParseHexBytes(string, 2 * sizeof b3.Hash, b3.Hash);
+
+ return b3;
+}
+
+const char*
+BLAKE3::ToHexString(char* outString /* 40 characters + NUL terminator */) const
+{
+ ToHexBytes(Hash, sizeof(BLAKE3), outString);
+ outString[2 * sizeof(BLAKE3)] = '\0';
+
+ return outString;
+}
+
+StringBuilderBase&
+BLAKE3::ToHexString(StringBuilderBase& outBuilder) const
+{
+ char str[65];
+ ToHexString(str);
+
+ outBuilder.AppendRange(str, &str[65]);
+
+ return outBuilder;
+}
+
+BLAKE3Stream::BLAKE3Stream()
+{
+ blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState);
+ static_assert(sizeof(blake3_hasher) <= sizeof(m_HashState));
+ blake3_hasher_init(b3h);
+}
+
+void
+BLAKE3Stream::Reset()
+{
+ blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState);
+ blake3_hasher_init(b3h);
+}
+
+BLAKE3Stream&
+BLAKE3Stream::Append(const void* data, size_t byteCount)
+{
+ blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState);
+ blake3_hasher_update(b3h, data, byteCount);
+
+ return *this;
+}
+
+BLAKE3
+BLAKE3Stream::GetHash()
+{
+ BLAKE3 b3;
+
+ blake3_hasher* b3h = reinterpret_cast<blake3_hasher*>(m_HashState);
+ blake3_hasher_finalize(b3h, b3.Hash, sizeof b3.Hash);
+
+ return b3;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+// doctest::String
+// toString(const BLAKE3& value)
+// {
+// char text[2 * sizeof(BLAKE3) + 1];
+// value.ToHexString(text);
+
+// return text;
+// }
+
+TEST_CASE("BLAKE3")
+{
+ SUBCASE("Basics")
+ {
+ BLAKE3 b3 = BLAKE3::HashMemory(nullptr, 0);
+ CHECK(BLAKE3::FromHexString("af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262") == b3);
+
+ BLAKE3::String_t b3s;
+ std::string b3ss = b3.ToHexString(b3s);
+ CHECK(b3ss == "af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262");
+ }
+
+ SUBCASE("hashes")
+ {
+ CHECK(BLAKE3::FromHexString("00307ced6a8b278d5e3a9f77b138d0e9d2209717c9d45b205f427a73565cc5fb") == BLAKE3::HashMemory("abc123", 6));
+ CHECK(BLAKE3::FromHexString("a7142c8c3905cd11b1e35105c7ac588b75d6798822f71e1145187ad46f3e8df4") ==
+ BLAKE3::HashMemory("1234567890123456789012345678901234567890", 40));
+ CHECK(BLAKE3::FromHexString("70e708532559265c4662d0285e5e0a4be8bd972bd1f255a93ddf342243adc427") ==
+ BLAKE3::HashMemory("The HttpSendHttpResponse function sends an HTTP response to the specified HTTP request.", 87));
+ }
+
+ SUBCASE("streamHashes")
+ {
+ auto streamHash = [](const void* data, size_t dataBytes) -> BLAKE3 {
+ BLAKE3Stream b3s;
+ b3s.Append(data, dataBytes);
+ return b3s.GetHash();
+ };
+
+ CHECK(BLAKE3::FromHexString("00307ced6a8b278d5e3a9f77b138d0e9d2209717c9d45b205f427a73565cc5fb") == streamHash("abc123", 6));
+ CHECK(BLAKE3::FromHexString("a7142c8c3905cd11b1e35105c7ac588b75d6798822f71e1145187ad46f3e8df4") ==
+ streamHash("1234567890123456789012345678901234567890", 40));
+ CHECK(BLAKE3::FromHexString("70e708532559265c4662d0285e5e0a4be8bd972bd1f255a93ddf342243adc427") ==
+ streamHash("The HttpSendHttpResponse function sends an HTTP response to the specified HTTP request.", 87));
+ }
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/compactbinary.cpp b/src/zencore/compactbinary.cpp
new file mode 100644
index 000000000..0db9f02ea
--- /dev/null
+++ b/src/zencore/compactbinary.cpp
@@ -0,0 +1,2299 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/compactbinary.h"
+
+#include <zencore/base64.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/compactbinaryvalue.h>
+#include <zencore/compress.h>
+#include <zencore/endian.h>
+#include <zencore/fmtutils.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/uid.h>
+
+#include <fmt/format.h>
+#include <string_view>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#else
+# include <time.h>
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+const int DaysToMonth[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365};
+
+double
+GetJulianDay(uint64_t Ticks)
+{
+ return (double)(1721425.5 + Ticks / TimeSpan::TicksPerDay);
+}
+
+bool
+IsLeapYear(int Year)
+{
+ if ((Year % 4) == 0)
+ {
+ return (((Year % 100) != 0) || ((Year % 400) == 0));
+ }
+
+ return false;
+}
+
+static constexpr uint64_t
+GetPlatformToDateTimeBiasInSeconds()
+{
+#if ZEN_PLATFORM_WINDOWS
+ const uint64_t PlatformEpochYear = 1601;
+#else
+ const uint64_t PlatformEpochYear = 1970;
+#endif
+ const uint64_t DateTimeEpochYear = 1;
+ return uint64_t(double(PlatformEpochYear - DateTimeEpochYear) * 365.2425) * 86400;
+}
+
+uint64_t
+DateTime::NowTicks()
+{
+ static constexpr uint64_t EpochBias = GetPlatformToDateTimeBiasInSeconds();
+
+#if ZEN_PLATFORM_WINDOWS
+ FILETIME SysTime;
+ GetSystemTimePreciseAsFileTime(&SysTime);
+ return (EpochBias * TimeSpan::TicksPerSecond) + ((uint64_t(SysTime.dwHighDateTime) << 32) | SysTime.dwLowDateTime);
+#else
+ int64_t SecondsSinceUnixEpoch = time(nullptr);
+ return (EpochBias + SecondsSinceUnixEpoch) * TimeSpan::TicksPerSecond;
+#endif
+}
+
+DateTime
+DateTime::Now()
+{
+ return DateTime{NowTicks()};
+}
+
+void
+DateTime::Set(int Year, int Month, int Day, int Hour, int Minute, int Second, int MilliSecond)
+{
+ int TotalDays = 0;
+
+ if ((Month > 2) && IsLeapYear(Year))
+ {
+ ++TotalDays;
+ }
+
+ --Year; // the current year is not a full year yet
+ --Month; // the current month is not a full month yet
+
+ TotalDays += Year * 365;
+ TotalDays += Year / 4; // leap year day every four years...
+ TotalDays -= Year / 100; // ...except every 100 years...
+ TotalDays += Year / 400; // ...but also every 400 years
+ TotalDays += DaysToMonth[Month]; // days in this year up to last month
+ TotalDays += Day - 1; // days in this month minus today
+
+ Ticks = TotalDays * TimeSpan::TicksPerDay + Hour * TimeSpan::TicksPerHour + Minute * TimeSpan::TicksPerMinute +
+ Second * TimeSpan::TicksPerSecond + MilliSecond * TimeSpan::TicksPerMillisecond;
+}
+
+int
+DateTime::GetYear() const
+{
+ int Year, Month, Day;
+ GetDate(Year, Month, Day);
+
+ return Year;
+}
+
+int
+DateTime::GetMonth() const
+{
+ int Year, Month, Day;
+ GetDate(Year, Month, Day);
+
+ return Month;
+}
+
+int
+DateTime::GetDay() const
+{
+ int Year, Month, Day;
+ GetDate(Year, Month, Day);
+
+ return Day;
+}
+
+int
+DateTime::GetHour() const
+{
+ return (int)((Ticks / TimeSpan::TicksPerHour) % 24);
+}
+
+int
+DateTime::GetHour12() const
+{
+ int Hour = GetHour();
+
+ if (Hour < 1)
+ {
+ return 12;
+ }
+
+ if (Hour > 12)
+ {
+ return (Hour - 12);
+ }
+
+ return Hour;
+}
+
+int
+DateTime::GetMinute() const
+{
+ return (int)((Ticks / TimeSpan::TicksPerMinute) % 60);
+}
+
+int
+DateTime::GetSecond() const
+{
+ return (int)((Ticks / TimeSpan::TicksPerSecond) % 60);
+}
+
+int
+DateTime::GetMillisecond() const
+{
+ return (int)((Ticks / TimeSpan::TicksPerMillisecond) % 1000);
+}
+
+void
+DateTime::GetDate(int& Year, int& Month, int& Day) const
+{
+ // Based on FORTRAN code in:
+ // Fliegel, H. F. and van Flandern, T. C.,
+ // Communications of the ACM, Vol. 11, No. 10 (October 1968).
+
+ int i, j, k, l, n;
+
+ l = int(GetJulianDay(Ticks) + 0.5) + 68569;
+ n = 4 * l / 146097;
+ l = l - (146097 * n + 3) / 4;
+ i = 4000 * (l + 1) / 1461001;
+ l = l - 1461 * i / 4 + 31;
+ j = 80 * l / 2447;
+ k = l - 2447 * j / 80;
+ l = j / 11;
+ j = j + 2 - 12 * l;
+ i = 100 * (n - 49) + i + l;
+
+ Year = i;
+ Month = j;
+ Day = k;
+}
+
+std::string
+DateTime::ToString(const char* Format) const
+{
+ ExtendableStringBuilder<32> Result;
+ int Year, Month, Day;
+
+ GetDate(Year, Month, Day);
+
+ if (Format != nullptr)
+ {
+ while (*Format != '\0')
+ {
+ if ((*Format == '%') && (*(++Format) != '\0'))
+ {
+ switch (*Format)
+ {
+ // case 'a': Result.Append(IsMorning() ? TEXT("am") : TEXT("pm")); break;
+ // case 'A': Result.Append(IsMorning() ? TEXT("AM") : TEXT("PM")); break;
+ case 'd':
+ Result.Append(fmt::format("{:02}", Day));
+ break;
+ // case 'D': Result.Appendf(TEXT("%03i"), GetDayOfYear()); break;
+ case 'm':
+ Result.Append(fmt::format("{:02}", Month));
+ break;
+ case 'y':
+ Result.Append(fmt::format("{:02}", Year % 100));
+ break;
+ case 'Y':
+ Result.Append(fmt::format("{:04}", Year));
+ break;
+ case 'h':
+ Result.Append(fmt::format("{:02}", GetHour12()));
+ break;
+ case 'H':
+ Result.Append(fmt::format("{:02}", GetHour()));
+ break;
+ case 'M':
+ Result.Append(fmt::format("{:02}", GetMinute()));
+ break;
+ case 'S':
+ Result.Append(fmt::format("{:02}", GetSecond()));
+ break;
+ case 's':
+ Result.Append(fmt::format("{:03}", GetMillisecond()));
+ break;
+ default:
+ Result.Append(*Format);
+ }
+ }
+ else
+ {
+ Result.Append(*Format);
+ }
+
+ // move to the next one
+ Format++;
+ }
+ }
+
+ return Result.ToString();
+}
+
+std::string
+DateTime::ToIso8601() const
+{
+ return ToString("%Y-%m-%dT%H:%M:%S.%sZ");
+}
+
+void
+TimeSpan::Set(int Days, int Hours, int Minutes, int Seconds, int FractionNano)
+{
+ int64_t TotalTicks = 0;
+
+ TotalTicks += Days * TicksPerDay;
+ TotalTicks += Hours * TicksPerHour;
+ TotalTicks += Minutes * TicksPerMinute;
+ TotalTicks += Seconds * TicksPerSecond;
+ TotalTicks += FractionNano / NanosecondsPerTick;
+
+ Ticks = TotalTicks;
+}
+
+std::string
+TimeSpan::ToString(const char* Format) const
+{
+ StringBuilder<128> Result;
+
+ Result.Append((int64_t(Ticks) < 0) ? '-' : '+');
+
+ while (*Format != '\0')
+ {
+ if ((*Format == '%') && (*++Format != '\0'))
+ {
+ switch (*Format)
+ {
+ case 'd':
+ Result.Append(fmt::format("{}", GetDays()));
+ break;
+ case 'D':
+ Result.Append(fmt::format("{:08}", GetDays()));
+ break;
+ case 'h':
+ Result.Append(fmt::format("{:02}", GetHours()));
+ break;
+ case 'm':
+ Result.Append(fmt::format("{:02}", GetMinutes()));
+ break;
+ case 's':
+ Result.Append(fmt::format("{:02}", GetSeconds()));
+ break;
+ case 'f':
+ Result.Append(fmt::format("{:03}", GetFractionMilli()));
+ break;
+ case 'u':
+ Result.Append(fmt::format("{:06}", GetFractionMicro()));
+ break;
+ case 't':
+ Result.Append(fmt::format("{:07}", GetFractionTicks()));
+ break;
+ case 'n':
+ Result.Append(fmt::format("{:09}", GetFractionNano()));
+ break;
+ default:
+ Result.Append(*Format);
+ }
+ }
+ else
+ {
+ Result.Append(*Format);
+ }
+
+ ++Format;
+ }
+
+ return Result.ToString();
+}
+
+std::string
+TimeSpan::ToString() const
+{
+ if (GetDays() == 0)
+ {
+ return ToString("%h:%m:%s.%f");
+ }
+
+ return ToString("%d.%h:%m:%s.%f");
+}
+
+StringBuilderBase&
+Guid::ToString(StringBuilderBase& Sb) const
+{
+ char Buf[128];
+ snprintf(Buf, sizeof Buf, "%08x-%04x-%04x-%04x-%04x%08x", A, B >> 16, B & 0xFFFF, C >> 16, C & 0xFFFF, D);
+ Sb << Buf;
+
+ return Sb;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace CompactBinaryPrivate {
+ static constexpr const uint8_t GEmptyObjectPayload[] = {uint8_t(CbFieldType::Object), 0x00};
+ static constexpr const uint8_t GEmptyArrayPayload[] = {uint8_t(CbFieldType::Array), 0x01, 0x00};
+} // namespace CompactBinaryPrivate
+
+//////////////////////////////////////////////////////////////////////////
+
+CbFieldView::CbFieldView(const void* DataPointer, CbFieldType FieldType)
+{
+ const uint8_t* Bytes = static_cast<const uint8_t*>(DataPointer);
+ const CbFieldType LocalType = CbFieldTypeOps::HasFieldType(FieldType) ? (CbFieldType(*Bytes++) | CbFieldType::HasFieldType) : FieldType;
+
+ uint32_t NameLenByteCount = 0;
+ const uint64_t NameLen64 = CbFieldTypeOps::HasFieldName(LocalType) ? ReadVarUInt(Bytes, NameLenByteCount) : 0;
+ Bytes += NameLen64 + NameLenByteCount;
+
+ Type = LocalType;
+ NameLen = uint32_t(std::clamp<uint64_t>(NameLen64, 0, ~uint32_t(0)));
+ Payload = Bytes;
+}
+
+void
+CbFieldView::IterateAttachments(std::function<void(CbFieldView)> Visitor) const
+{
+ switch (CbFieldTypeOps::GetType(Type))
+ {
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ return CbObjectView::FromFieldView(*this).IterateAttachments(Visitor);
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ return CbArrayView::FromFieldView(*this).IterateAttachments(Visitor);
+ case CbFieldType::ObjectAttachment:
+ case CbFieldType::BinaryAttachment:
+ return Visitor(*this);
+ default:
+ return;
+ }
+}
+
+CbObjectView
+CbFieldView::AsObjectView()
+{
+ if (CbFieldTypeOps::IsObject(Type))
+ {
+ Error = CbFieldError::None;
+ return CbObjectView::FromFieldView(*this);
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return CbObjectView();
+ }
+}
+
+CbArrayView
+CbFieldView::AsArrayView()
+{
+ if (CbFieldTypeOps::IsArray(Type))
+ {
+ Error = CbFieldError::None;
+ return CbArrayView::FromFieldView(*this);
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return CbArrayView();
+ }
+}
+
+MemoryView
+CbFieldView::AsBinaryView(const MemoryView Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsBinary(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsBinary();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+std::string_view
+CbFieldView::AsString(const std::string_view Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsString(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsString();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+std::u8string_view
+CbFieldView::AsU8String(const std::u8string_view Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsString(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsU8String();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+uint64_t
+CbFieldView::AsInteger(const uint64_t Default, const CompactBinaryPrivate::IntegerParams Params)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsInteger(Accessor.GetType()))
+ {
+ return Accessor.AsInteger(Params, &Error, Default);
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+float
+CbFieldView::AsFloat(const float Default)
+{
+ switch (CbValue Accessor = GetValue(); Accessor.GetType())
+ {
+ case CbFieldType::IntegerPositive:
+ case CbFieldType::IntegerNegative:
+ {
+ const uint64_t IsNegative = uint8_t(Accessor.GetType()) & 1;
+ constexpr uint64_t OutOfRangeMask = ~((uint64_t(1) << /*FLT_MANT_DIG*/ 24) - 1);
+
+ uint32_t MagnitudeByteCount;
+ const int64_t Magnitude = ReadVarUInt(Accessor.GetData(), MagnitudeByteCount) + IsNegative;
+ const uint64_t IsInRange = !(Magnitude & OutOfRangeMask);
+ Error = IsInRange ? CbFieldError::None : CbFieldError::RangeError;
+ return IsInRange ? float(IsNegative ? -Magnitude : Magnitude) : Default;
+ }
+ case CbFieldType::Float32:
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsFloat32();
+ }
+ case CbFieldType::Float64:
+ Error = CbFieldError::RangeError;
+ return Default;
+ default:
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+double
+CbFieldView::AsDouble(const double Default)
+{
+ switch (CbValue Accessor = GetValue(); Accessor.GetType())
+ {
+ case CbFieldType::IntegerPositive:
+ case CbFieldType::IntegerNegative:
+ {
+ const uint64_t IsNegative = uint8_t(Accessor.GetType()) & 1;
+ constexpr uint64_t OutOfRangeMask = ~((uint64_t(1) << /*DBL_MANT_DIG*/ 53) - 1);
+
+ uint32_t MagnitudeByteCount;
+ const int64_t Magnitude = ReadVarUInt(Accessor.GetData(), MagnitudeByteCount) + IsNegative;
+ const uint64_t IsInRange = !(Magnitude & OutOfRangeMask);
+ Error = IsInRange ? CbFieldError::None : CbFieldError::RangeError;
+ return IsInRange ? double(IsNegative ? -Magnitude : Magnitude) : Default;
+ }
+ case CbFieldType::Float32:
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsFloat32();
+ }
+ case CbFieldType::Float64:
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsFloat64();
+ }
+ default:
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+bool
+CbFieldView::AsBool(const bool bDefault)
+{
+ CbValue Accessor = GetValue();
+ const bool IsBool = CbFieldTypeOps::IsBool(Accessor.GetType());
+ Error = IsBool ? CbFieldError::None : CbFieldError::TypeError;
+ return (uint8_t(IsBool) & Accessor.AsBool()) | ((!IsBool) & bDefault);
+}
+
+IoHash
+CbFieldView::AsObjectAttachment(const IoHash& Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsObjectAttachment(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsObjectAttachment();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+IoHash
+CbFieldView::AsBinaryAttachment(const IoHash& Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsBinaryAttachment(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsBinaryAttachment();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+IoHash
+CbFieldView::AsAttachment(const IoHash& Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsAttachment(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsAttachment();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+IoHash
+CbFieldView::AsHash(const IoHash& Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsHash(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsHash();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+Guid
+CbFieldView::AsUuid()
+{
+ return AsUuid(Guid());
+}
+
+Guid
+CbFieldView::AsUuid(const Guid& Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsUuid(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsUuid();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+Oid
+CbFieldView::AsObjectId()
+{
+ return AsObjectId(Oid());
+}
+
+Oid
+CbFieldView::AsObjectId(const Oid& Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsObjectId(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsObjectId();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+CbCustomById
+CbFieldView::AsCustomById(CbCustomById Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsCustomById(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsCustomById();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+CbCustomByName
+CbFieldView::AsCustomByName(CbCustomByName Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsCustomByName(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsCustomByName();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+int64_t
+CbFieldView::AsDateTimeTicks(const int64_t Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsDateTime(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsDateTimeTicks();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+DateTime
+CbFieldView::AsDateTime()
+{
+ return DateTime(AsDateTimeTicks(0));
+}
+
+DateTime
+CbFieldView::AsDateTime(DateTime Default)
+{
+ return DateTime(AsDateTimeTicks(Default.GetTicks()));
+}
+
+int64_t
+CbFieldView::AsTimeSpanTicks(const int64_t Default)
+{
+ if (CbValue Accessor = GetValue(); CbFieldTypeOps::IsTimeSpan(Accessor.GetType()))
+ {
+ Error = CbFieldError::None;
+ return Accessor.AsTimeSpanTicks();
+ }
+ else
+ {
+ Error = CbFieldError::TypeError;
+ return Default;
+ }
+}
+
+TimeSpan
+CbFieldView::AsTimeSpan()
+{
+ return TimeSpan(AsTimeSpanTicks(0));
+}
+
+TimeSpan
+CbFieldView::AsTimeSpan(TimeSpan Default)
+{
+ return TimeSpan(AsTimeSpanTicks(Default.GetTicks()));
+}
+
+uint64_t
+CbFieldView::GetSize() const
+{
+ return sizeof(CbFieldType) + GetViewNoType().GetSize();
+}
+
+uint64_t
+CbFieldView::GetPayloadSize() const
+{
+ switch (CbFieldTypeOps::GetType(Type))
+ {
+ case CbFieldType::None:
+ case CbFieldType::Null:
+ return 0;
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ case CbFieldType::Binary:
+ case CbFieldType::String:
+ {
+ uint32_t PayloadSizeByteCount;
+ const uint64_t PayloadSize = ReadVarUInt(Payload, PayloadSizeByteCount);
+ return PayloadSize + PayloadSizeByteCount;
+ }
+ case CbFieldType::IntegerPositive:
+ case CbFieldType::IntegerNegative:
+ return MeasureVarUInt(Payload);
+ case CbFieldType::Float32:
+ return 4;
+ case CbFieldType::Float64:
+ return 8;
+ case CbFieldType::BoolFalse:
+ case CbFieldType::BoolTrue:
+ return 0;
+ case CbFieldType::ObjectAttachment:
+ case CbFieldType::BinaryAttachment:
+ case CbFieldType::Hash:
+ return 20;
+ case CbFieldType::Uuid:
+ return 16;
+ case CbFieldType::ObjectId:
+ return 12;
+ case CbFieldType::DateTime:
+ case CbFieldType::TimeSpan:
+ return 8;
+ default:
+ return 0;
+ }
+}
+
+IoHash
+CbFieldView::GetHash() const
+{
+ IoHashStream HashStream;
+ GetHash(HashStream);
+ return HashStream.GetHash();
+}
+
+void
+CbFieldView::GetHash(IoHashStream& Hash) const
+{
+ const CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(Type);
+ Hash.Append(&SerializedType, sizeof(SerializedType));
+ auto View = GetViewNoType();
+ Hash.Append(View.GetData(), View.GetSize());
+}
+
+bool
+CbFieldView::Equals(const CbFieldView& Other) const
+{
+ return CbFieldTypeOps::GetSerializedType(Type) == CbFieldTypeOps::GetSerializedType(Other.Type) &&
+ GetViewNoType().EqualBytes(Other.GetViewNoType());
+}
+
+void
+CbFieldView::CopyTo(MutableMemoryView Buffer) const
+{
+ const MemoryView Source = GetViewNoType();
+ ZEN_ASSERT(Buffer.GetSize() == sizeof(CbFieldType) + Source.GetSize());
+ // TEXT("A buffer of %" UINT64_FMT " bytes was provided when %" UINT64_FMT " bytes are required"),
+ // Buffer.GetSize(),
+ // sizeof(CbFieldType) + Source.GetSize());
+ *static_cast<CbFieldType*>(Buffer.GetData()) = CbFieldTypeOps::GetSerializedType(Type);
+ Buffer.RightChopInline(sizeof(CbFieldType));
+ memcpy(Buffer.GetData(), Source.GetData(), Source.GetSize());
+}
+
+void
+CbFieldView::CopyTo(BinaryWriter& Ar) const
+{
+ const MemoryView SourceView = GetViewNoType();
+ CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(Type);
+ const MemoryView TypeView(reinterpret_cast<const uint8_t*>(&SerializedType), sizeof(SerializedType));
+ Ar.Write({TypeView, SourceView});
+}
+
+MemoryView
+CbFieldView::GetView() const
+{
+ const uint32_t TypeSize = CbFieldTypeOps::HasFieldType(Type) ? sizeof(CbFieldType) : 0;
+ const uint32_t NameSize = CbFieldTypeOps::HasFieldName(Type) ? NameLen + MeasureVarUInt(NameLen) : 0;
+ const uint64_t PayloadSize = GetPayloadSize();
+ return MemoryView(static_cast<const uint8_t*>(Payload) - TypeSize - NameSize, TypeSize + NameSize + PayloadSize);
+}
+
+MemoryView
+CbFieldView::GetViewNoType() const
+{
+ const uint32_t NameSize = CbFieldTypeOps::HasFieldName(Type) ? NameLen + MeasureVarUInt(NameLen) : 0;
+ const uint64_t PayloadSize = GetPayloadSize();
+ return MemoryView(static_cast<const uint8_t*>(Payload) - NameSize, NameSize + PayloadSize);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+CbArrayView::CbArrayView() : CbFieldView(CompactBinaryPrivate::GEmptyArrayPayload)
+{
+}
+
+uint64_t
+CbArrayView::Num() const
+{
+ const uint8_t* PayloadBytes = static_cast<const uint8_t*>(GetPayload());
+ PayloadBytes += MeasureVarUInt(PayloadBytes);
+ uint32_t NumByteCount;
+ return ReadVarUInt(PayloadBytes, NumByteCount);
+}
+
+CbFieldViewIterator
+CbArrayView::CreateViewIterator() const
+{
+ const uint8_t* PayloadBytes = static_cast<const uint8_t*>(GetPayload());
+ uint32_t PayloadSizeByteCount;
+ const uint64_t PayloadSize = ReadVarUInt(PayloadBytes, PayloadSizeByteCount);
+ PayloadBytes += PayloadSizeByteCount;
+ const uint64_t NumByteCount = MeasureVarUInt(PayloadBytes);
+ if (PayloadSize > NumByteCount)
+ {
+ const void* const PayloadEnd = PayloadBytes + PayloadSize;
+ PayloadBytes += NumByteCount;
+ const CbFieldType UniformType =
+ CbFieldTypeOps::GetType(GetType()) == CbFieldType::UniformArray ? CbFieldType(*PayloadBytes++) : CbFieldType::HasFieldType;
+ return CbFieldViewIterator::MakeRange(MemoryView(PayloadBytes, PayloadEnd), UniformType);
+ }
+ return CbFieldViewIterator();
+}
+
+void
+CbArrayView::VisitFields(ICbVisitor&)
+{
+}
+
+uint64_t
+CbArrayView::GetSize() const
+{
+ return sizeof(CbFieldType) + GetPayloadSize();
+}
+
+IoHash
+CbArrayView::GetHash() const
+{
+ IoHashStream Hash;
+ GetHash(Hash);
+ return Hash.GetHash();
+}
+
+void
+CbArrayView::GetHash(IoHashStream& HashStream) const
+{
+ const CbFieldType SerializedType = CbFieldTypeOps::GetType(GetType());
+ HashStream.Append(&SerializedType, sizeof(SerializedType));
+ auto _ = GetPayloadView();
+ HashStream.Append(_.GetData(), _.GetSize());
+}
+
+bool
+CbArrayView::Equals(const CbArrayView& Other) const
+{
+ return CbFieldTypeOps::GetType(GetType()) == CbFieldTypeOps::GetType(Other.GetType()) &&
+ GetPayloadView().EqualBytes(Other.GetPayloadView());
+}
+
+void
+CbArrayView::CopyTo(MutableMemoryView Buffer) const
+{
+ const MemoryView Source = GetPayloadView();
+ ZEN_ASSERT(Buffer.GetSize() == sizeof(CbFieldType) + Source.GetSize());
+ // TEXT("Buffer is %" UINT64_FMT " bytes but %" UINT64_FMT " is required."),
+ // Buffer.GetSize(),
+ // sizeof(CbFieldType) + Source.GetSize());
+
+ *static_cast<CbFieldType*>(Buffer.GetData()) = CbFieldTypeOps::GetType(GetType());
+ Buffer.RightChopInline(sizeof(CbFieldType));
+ memcpy(Buffer.GetData(), Source.GetData(), Source.GetSize());
+}
+
+void
+CbArrayView::CopyTo(BinaryWriter& Ar) const
+{
+ const MemoryView SourceView = GetPayloadView();
+ CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(GetType());
+ const MemoryView TypeView(reinterpret_cast<const uint8_t*>(&SerializedType), sizeof(SerializedType));
+ Ar.Write({TypeView, SourceView});
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+CbObjectView::CbObjectView() : CbFieldView(CompactBinaryPrivate::GEmptyObjectPayload)
+{
+}
+
+CbFieldViewIterator
+CbObjectView::CreateViewIterator() const
+{
+ const uint8_t* PayloadBytes = static_cast<const uint8_t*>(GetPayload());
+ uint32_t PayloadSizeByteCount;
+ const uint64_t PayloadSize = ReadVarUInt(PayloadBytes, PayloadSizeByteCount);
+
+ PayloadBytes += PayloadSizeByteCount;
+
+ if (PayloadSize)
+ {
+ const void* const PayloadEnd = PayloadBytes + PayloadSize;
+ const CbFieldType UniformType =
+ CbFieldTypeOps::GetType(GetType()) == CbFieldType::UniformObject ? CbFieldType(*PayloadBytes++) : CbFieldType::HasFieldType;
+ return CbFieldViewIterator::MakeRange(MemoryView(PayloadBytes, PayloadEnd), UniformType);
+ }
+
+ return CbFieldViewIterator();
+}
+
+void
+CbObjectView::VisitFields(ICbVisitor&)
+{
+}
+
+CbFieldView
+CbObjectView::FindView(const std::string_view Name) const
+{
+ for (const CbFieldView& Field : *this)
+ {
+ if (Name == Field.GetName())
+ {
+ return Field;
+ }
+ }
+ return CbFieldView();
+}
+
+CbFieldView
+CbObjectView::FindViewIgnoreCase(const std::string_view Name) const
+{
+ for (const CbFieldView& Field : *this)
+ {
+ if (Name == Field.GetName())
+ {
+ return Field;
+ }
+ }
+ return CbFieldView();
+}
+
+CbObjectView::operator bool() const
+{
+ return GetSize() > sizeof(CompactBinaryPrivate::GEmptyObjectPayload);
+}
+
+uint64_t
+CbObjectView::GetSize() const
+{
+ return sizeof(CbFieldType) + GetPayloadSize();
+}
+
+IoHash
+CbObjectView::GetHash() const
+{
+ IoHashStream Hash;
+ GetHash(Hash);
+ return Hash.GetHash();
+}
+
+void
+CbObjectView::GetHash(IoHashStream& HashStream) const
+{
+ const CbFieldType SerializedType = CbFieldTypeOps::GetType(GetType());
+ HashStream.Append(&SerializedType, sizeof(SerializedType));
+ HashStream.Append(GetPayloadView());
+}
+
+bool
+CbObjectView::Equals(const CbObjectView& Other) const
+{
+ return CbFieldTypeOps::GetType(GetType()) == CbFieldTypeOps::GetType(Other.GetType()) &&
+ GetPayloadView().EqualBytes(Other.GetPayloadView());
+}
+
+void
+CbObjectView::CopyTo(MutableMemoryView Buffer) const
+{
+ const MemoryView Source = GetPayloadView();
+ ZEN_ASSERT(Buffer.GetSize() == (sizeof(CbFieldType) + Source.GetSize()));
+ // TEXT("Buffer is %" UINT64_FMT " bytes but %" UINT64_FMT " is required."),
+ // Buffer.GetSize(),
+ // sizeof(CbFieldType) + Source.GetSize());
+ *static_cast<CbFieldType*>(Buffer.GetData()) = CbFieldTypeOps::GetType(GetType());
+ Buffer.RightChopInline(sizeof(CbFieldType));
+ memcpy(Buffer.GetData(), Source.GetData(), Source.GetSize());
+}
+
+void
+CbObjectView::CopyTo(BinaryWriter& Ar) const
+{
+ const MemoryView SourceView = GetPayloadView();
+ CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(GetType());
+ const MemoryView TypeView(reinterpret_cast<const uint8_t*>(&SerializedType), sizeof(SerializedType));
+ Ar.Write({TypeView, SourceView});
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+template<typename FieldType>
+uint64_t
+TCbFieldIterator<FieldType>::GetRangeSize() const
+{
+ MemoryView View;
+ if (TryGetSerializedRangeView(View))
+ {
+ return View.GetSize();
+ }
+ else
+ {
+ uint64_t Size = 0;
+ for (CbFieldViewIterator It(*this); It; ++It)
+ {
+ Size += It.GetSize();
+ }
+ return Size;
+ }
+}
+
+template<typename FieldType>
+IoHash
+TCbFieldIterator<FieldType>::GetRangeHash() const
+{
+ IoHashStream Hash;
+ GetRangeHash(Hash);
+ return IoHash(Hash.GetHash());
+}
+
+template<typename FieldType>
+void
+TCbFieldIterator<FieldType>::GetRangeHash(IoHashStream& Hash) const
+{
+ MemoryView View;
+ if (TryGetSerializedRangeView(View))
+ {
+ Hash.Append(View.GetData(), View.GetSize());
+ }
+ else
+ {
+ for (CbFieldViewIterator It(*this); It; ++It)
+ {
+ It.GetHash(Hash);
+ }
+ }
+}
+
+template<typename FieldType>
+void
+TCbFieldIterator<FieldType>::CopyRangeTo(MutableMemoryView InBuffer) const
+{
+ MemoryView Source;
+ if (TryGetSerializedRangeView(Source))
+ {
+ ZEN_ASSERT(InBuffer.GetSize() == Source.GetSize());
+ // TEXT("Buffer is %" UINT64_FMT " bytes but %" UINT64_FMT " is required."),
+ // InBuffer.GetSize(),
+ // Source.GetSize());
+ memcpy(InBuffer.GetData(), Source.GetData(), Source.GetSize());
+ }
+ else
+ {
+ for (CbFieldViewIterator It(*this); It; ++It)
+ {
+ const uint64_t Size = It.GetSize();
+ It.CopyTo(InBuffer.Left(Size));
+ InBuffer.RightChopInline(Size);
+ }
+ }
+}
+
+template class TCbFieldIterator<CbFieldView>;
+template class TCbFieldIterator<CbField>;
+
+template<typename FieldType>
+void
+TCbFieldIterator<FieldType>::IterateRangeAttachments(std::function<void(CbFieldView)> Visitor) const
+{
+ if (CbFieldTypeOps::HasFieldType(FieldType::GetType()))
+ {
+ // Always iterate over non-uniform ranges because we do not know if they contain an attachment.
+ for (CbFieldViewIterator It(*this); It; ++It)
+ {
+ if (CbFieldTypeOps::MayContainAttachments(It.GetType()))
+ {
+ It.IterateAttachments(Visitor);
+ }
+ }
+ }
+ else
+ {
+ // Only iterate over uniform ranges if the uniform type may contain an attachment.
+ if (CbFieldTypeOps::MayContainAttachments(FieldType::GetType()))
+ {
+ for (CbFieldViewIterator It(*this); It; ++It)
+ {
+ It.IterateAttachments(Visitor);
+ }
+ }
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+CbFieldIterator
+CbFieldIterator::CloneRange(const CbFieldViewIterator& It)
+{
+ MemoryView View;
+ if (It.TryGetSerializedRangeView(View))
+ {
+ return MakeRange(SharedBuffer::Clone(View));
+ }
+ else
+ {
+ UniqueBuffer Buffer = UniqueBuffer::Alloc(It.GetRangeSize());
+ It.CopyRangeTo(MutableMemoryView(Buffer.GetData(), Buffer.GetSize()));
+ return MakeRange(SharedBuffer(std::move(Buffer)));
+ }
+}
+
+SharedBuffer
+CbFieldIterator::GetRangeBuffer() const
+{
+ const MemoryView RangeView = GetRangeView();
+ const SharedBuffer& OuterBuffer = GetOuterBuffer();
+ return OuterBuffer.GetView() == RangeView ? OuterBuffer : SharedBuffer::MakeView(RangeView, OuterBuffer);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+uint64_t
+MeasureCompactBinary(MemoryView View, CbFieldType Type)
+{
+ uint64_t Size;
+ return TryMeasureCompactBinary(View, Type, Size, Type) ? Size : 0;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+bool
+TryMeasureCompactBinary(MemoryView View, CbFieldType& OutType, uint64_t& OutSize, CbFieldType Type)
+{
+ uint64_t Size = 0;
+
+ if (CbFieldTypeOps::HasFieldType(Type))
+ {
+ if (View.GetSize() == 0)
+ {
+ OutType = CbFieldType::None;
+ OutSize = 1;
+ return false;
+ }
+
+ Type = *static_cast<const CbFieldType*>(View.GetData());
+ View.RightChopInline(1);
+ Size += 1;
+ }
+
+ bool bDynamicSize = false;
+ uint64_t FixedSize = 0;
+ switch (CbFieldTypeOps::GetType(Type))
+ {
+ case CbFieldType::Null:
+ break;
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ case CbFieldType::Binary:
+ case CbFieldType::String:
+ case CbFieldType::IntegerPositive:
+ case CbFieldType::IntegerNegative:
+ bDynamicSize = true;
+ break;
+ case CbFieldType::Float32:
+ FixedSize = 4;
+ break;
+ case CbFieldType::Float64:
+ FixedSize = 8;
+ break;
+ case CbFieldType::BoolFalse:
+ case CbFieldType::BoolTrue:
+ break;
+ case CbFieldType::ObjectAttachment:
+ case CbFieldType::BinaryAttachment:
+ case CbFieldType::Hash:
+ FixedSize = 20;
+ break;
+ case CbFieldType::Uuid:
+ FixedSize = 16;
+ break;
+ case CbFieldType::ObjectId:
+ FixedSize = 12;
+ break;
+ case CbFieldType::DateTime:
+ case CbFieldType::TimeSpan:
+ FixedSize = 8;
+ break;
+ case CbFieldType::None:
+ default:
+ OutType = CbFieldType::None;
+ OutSize = 0;
+ return false;
+ }
+
+ OutType = Type;
+
+ if (CbFieldTypeOps::HasFieldName(Type))
+ {
+ if (View.GetSize() == 0)
+ {
+ OutSize = Size + 1;
+ return false;
+ }
+
+ uint32_t NameLenByteCount = MeasureVarUInt(View.GetData());
+ if (View.GetSize() < NameLenByteCount)
+ {
+ OutSize = Size + NameLenByteCount;
+ return false;
+ }
+
+ const uint64_t NameLen = ReadVarUInt(View.GetData(), NameLenByteCount);
+ const uint64_t NameSize = NameLen + NameLenByteCount;
+
+ if (bDynamicSize && View.GetSize() < NameSize)
+ {
+ OutSize = Size + NameSize;
+ return false;
+ }
+
+ View.RightChopInline(NameSize);
+ Size += NameSize;
+ }
+
+ switch (CbFieldTypeOps::GetType(Type))
+ {
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ case CbFieldType::Binary:
+ case CbFieldType::String:
+ if (View.GetSize() == 0)
+ {
+ OutSize = Size + 1;
+ return false;
+ }
+ else
+ {
+ uint32_t PayloadSizeByteCount = MeasureVarUInt(View.GetData());
+ if (View.GetSize() < PayloadSizeByteCount)
+ {
+ OutSize = Size + PayloadSizeByteCount;
+ return false;
+ }
+ const uint64_t PayloadSize = ReadVarUInt(View.GetData(), PayloadSizeByteCount);
+ OutSize = Size + PayloadSize + PayloadSizeByteCount;
+ }
+ return true;
+
+ case CbFieldType::IntegerPositive:
+ case CbFieldType::IntegerNegative:
+ if (View.GetSize() == 0)
+ {
+ OutSize = Size + 1;
+ return false;
+ }
+ OutSize = Size + MeasureVarUInt(View.GetData());
+ return true;
+
+ default:
+ OutSize = Size + FixedSize;
+ return true;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+CbField
+LoadCompactBinary(BinaryReader& Ar, BufferAllocator Allocator)
+{
+ std::vector<uint8_t> HeaderBytes;
+ CbFieldType FieldType;
+ uint64_t FieldSize = 1;
+
+ for (const int64_t StartPos = Ar.CurrentOffset(); FieldSize > 0;)
+ {
+ // Read in small increments until the total field size is known, to avoid reading too far.
+ const int32_t ReadSize = int32_t(FieldSize - HeaderBytes.size());
+ if (Ar.CurrentOffset() + ReadSize > Ar.GetSize())
+ {
+ break;
+ }
+
+ const size_t ReadOffset = HeaderBytes.size();
+ HeaderBytes.resize(ReadOffset + ReadSize);
+
+ Ar.Read(HeaderBytes.data() + ReadOffset, ReadSize);
+ if (TryMeasureCompactBinary(MakeMemoryView(HeaderBytes), FieldType, FieldSize))
+ {
+ if (FieldSize <= uint64_t(Ar.Size() - StartPos))
+ {
+ UniqueBuffer Buffer = Allocator(FieldSize);
+ ZEN_ASSERT(Buffer.GetSize() == FieldSize);
+ MutableMemoryView View = Buffer.GetMutableView();
+ memcpy(View.GetData(), HeaderBytes.data(), HeaderBytes.size());
+ View.RightChopInline(HeaderBytes.size());
+ if (!View.IsEmpty())
+ {
+ // Read the remainder of the field.
+ Ar.Read(View.GetData(), View.GetSize());
+ }
+ if (ValidateCompactBinary(Buffer, CbValidateMode::Default) == CbValidateError::None)
+ {
+ return CbField(SharedBuffer(std::move(Buffer)));
+ }
+ }
+ break;
+ }
+ }
+ return CbField();
+}
+
+CbObject
+LoadCompactBinaryObject(IoBuffer&& Payload)
+{
+ return CbObject{SharedBuffer(std::move(Payload))};
+}
+
+CbObject
+LoadCompactBinaryObject(const IoBuffer& Payload)
+{
+ return CbObject{SharedBuffer(Payload)};
+}
+
+CbObject
+LoadCompactBinaryObject(CompressedBuffer&& Payload)
+{
+ return CbObject{SharedBuffer(Payload.DecompressToComposite().Flatten())};
+}
+
+CbObject
+LoadCompactBinaryObject(const CompressedBuffer& Payload)
+{
+ return CbObject{SharedBuffer(Payload.DecompressToComposite().Flatten())};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+SaveCompactBinary(BinaryWriter& Ar, const CbFieldView& Field)
+{
+ Field.CopyTo(Ar);
+}
+
+void
+SaveCompactBinary(BinaryWriter& Ar, const CbArrayView& Array)
+{
+ Array.CopyTo(Ar);
+}
+
+void
+SaveCompactBinary(BinaryWriter& Ar, const CbObjectView& Object)
+{
+ Object.CopyTo(Ar);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+class CbJsonWriter
+{
+public:
+ explicit CbJsonWriter(StringBuilderBase& InBuilder) : Builder(InBuilder) { NewLineAndIndent << LINE_TERMINATOR_ANSI; }
+
+ void WriteField(CbFieldView Field)
+ {
+ using namespace std::literals;
+
+ WriteOptionalComma();
+ WriteOptionalNewLine();
+
+ if (std::u8string_view Name = Field.GetU8Name(); !Name.empty())
+ {
+ AppendQuotedString(Name);
+ Builder << ": "sv;
+ }
+
+ switch (CbValue Accessor = Field.GetValue(); Accessor.GetType())
+ {
+ case CbFieldType::Null:
+ Builder << "null"sv;
+ break;
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ {
+ Builder << '{';
+ NewLineAndIndent << '\t';
+ NeedsNewLine = true;
+ for (CbFieldView It : Field)
+ {
+ WriteField(It);
+ }
+ NewLineAndIndent.RemoveSuffix(1);
+ if (NeedsComma)
+ {
+ WriteOptionalNewLine();
+ }
+ Builder << '}';
+ }
+ break;
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ {
+ Builder << '[';
+ NewLineAndIndent << '\t';
+ NeedsNewLine = true;
+ for (CbFieldView It : Field)
+ {
+ WriteField(It);
+ }
+ NewLineAndIndent.RemoveSuffix(1);
+ if (NeedsComma)
+ {
+ WriteOptionalNewLine();
+ }
+ Builder << ']';
+ }
+ break;
+ case CbFieldType::Binary:
+ AppendBase64String(Accessor.AsBinary());
+ break;
+ case CbFieldType::String:
+ AppendQuotedString(Accessor.AsU8String());
+ break;
+ case CbFieldType::IntegerPositive:
+ Builder << Accessor.AsIntegerPositive();
+ break;
+ case CbFieldType::IntegerNegative:
+ Builder << Accessor.AsIntegerNegative();
+ break;
+ case CbFieldType::Float32:
+ {
+ const float Value = Accessor.AsFloat32();
+ if (std::isfinite(Value))
+ {
+ Builder.Append(fmt::format("{:.9g}", Value));
+ }
+ else
+ {
+ Builder << "null"sv;
+ }
+ }
+ break;
+ case CbFieldType::Float64:
+ {
+ const double Value = Accessor.AsFloat64();
+ if (std::isfinite(Value))
+ {
+ Builder.Append(fmt::format("{:.17g}", Value));
+ }
+ else
+ {
+ Builder << "null"sv;
+ }
+ }
+ break;
+ case CbFieldType::BoolFalse:
+ Builder << "false"sv;
+ break;
+ case CbFieldType::BoolTrue:
+ Builder << "true"sv;
+ break;
+ case CbFieldType::ObjectAttachment:
+ case CbFieldType::BinaryAttachment:
+ {
+ Builder << '"';
+ Accessor.AsAttachment().ToHexString(Builder);
+ Builder << '"';
+ }
+ break;
+ case CbFieldType::Hash:
+ {
+ Builder << '"';
+ Accessor.AsHash().ToHexString(Builder);
+ Builder << '"';
+ }
+ break;
+ case CbFieldType::Uuid:
+ {
+ Builder << '"';
+ Accessor.AsUuid().ToString(Builder);
+ Builder << '"';
+ }
+ break;
+ case CbFieldType::DateTime:
+ Builder << '"' << DateTime(Accessor.AsDateTimeTicks()).ToIso8601() << '"';
+ break;
+ case CbFieldType::TimeSpan:
+ {
+ const TimeSpan Span(Accessor.AsTimeSpanTicks());
+ if (Span.GetDays() == 0)
+ {
+ Builder << '"' << Span.ToString("%h:%m:%s.%n") << '"';
+ }
+ else
+ {
+ Builder << '"' << Span.ToString("%d.%h:%m:%s.%n") << '"';
+ }
+ break;
+ }
+ case CbFieldType::ObjectId:
+ Builder << '"';
+ Accessor.AsObjectId().ToString(Builder);
+ Builder << '"';
+ break;
+ case CbFieldType::CustomById:
+ {
+ CbCustomById Custom = Accessor.AsCustomById();
+ Builder << "{ \"Id\": ";
+ Builder << Custom.Id;
+ Builder << ", \"Data\": ";
+ AppendBase64String(Custom.Data);
+ Builder << " }";
+ break;
+ }
+ case CbFieldType::CustomByName:
+ {
+ CbCustomByName Custom = Accessor.AsCustomByName();
+ Builder << "{ \"Name\": ";
+ AppendQuotedString(Custom.Name);
+ Builder << ", \"Data\": ";
+ AppendBase64String(Custom.Data);
+ Builder << " }";
+ break;
+ }
+ default:
+ ZEN_ASSERT(false);
+ break;
+ }
+
+ NeedsComma = true;
+ NeedsNewLine = true;
+ }
+
+private:
+ void WriteOptionalComma()
+ {
+ if (NeedsComma)
+ {
+ NeedsComma = false;
+ Builder << ',';
+ }
+ }
+
+ void WriteOptionalNewLine()
+ {
+ if (NeedsNewLine)
+ {
+ NeedsNewLine = false;
+ Builder << NewLineAndIndent;
+ }
+ }
+
+ void AppendQuotedString(std::u8string_view Value)
+ {
+ using namespace std::literals;
+
+ const AsciiSet EscapeSet(
+ "\\\"\b\f\n\r\t"
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f"
+ "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f");
+
+ Builder << '\"';
+ while (!Value.empty())
+ {
+ std::u8string_view Verbatim = AsciiSet::FindPrefixWithout(Value, EscapeSet);
+ Builder << Verbatim;
+
+ Value = Value.substr(Verbatim.size());
+
+ std::u8string_view Escape = AsciiSet::FindPrefixWith(Value, EscapeSet);
+ for (char Char : Escape)
+ {
+ switch (Char)
+ {
+ case '\\':
+ Builder << "\\\\"sv;
+ break;
+ case '\"':
+ Builder << "\\\""sv;
+ break;
+ case '\b':
+ Builder << "\\b"sv;
+ break;
+ case '\f':
+ Builder << "\\f"sv;
+ break;
+ case '\n':
+ Builder << "\\n"sv;
+ break;
+ case '\r':
+ Builder << "\\r"sv;
+ break;
+ case '\t':
+ Builder << "\\t"sv;
+ break;
+ default:
+ Builder << Char;
+ break;
+ }
+ }
+ Value = Value.substr(Escape.size());
+ }
+ Builder << '\"';
+ }
+
+ void AppendBase64String(MemoryView Value)
+ {
+ Builder << '"';
+ ZEN_ASSERT(Value.GetSize() <= 512 * 1024 * 1024);
+ const uint32_t EncodedSize = Base64::GetEncodedDataSize(uint32_t(Value.GetSize()));
+ const size_t EncodedIndex = Builder.AddUninitialized(size_t(EncodedSize));
+ Base64::Encode(static_cast<const uint8_t*>(Value.GetData()), uint32_t(Value.GetSize()), Builder.Data() + EncodedIndex);
+ }
+
+private:
+ StringBuilderBase& Builder;
+ ExtendableStringBuilder<32> NewLineAndIndent;
+ bool NeedsComma{false};
+ bool NeedsNewLine{false};
+};
+
+void
+CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder)
+{
+ CbJsonWriter Writer(Builder);
+ Writer.WriteField(Object.AsFieldView());
+}
+
+void
+CompactBinaryToJson(const CbArrayView& Array, StringBuilderBase& Builder)
+{
+ CbJsonWriter Writer(Builder);
+ Writer.WriteField(Array.AsFieldView());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+class CbJsonReader
+{
+public:
+ static CbFieldIterator Read(std::string_view JsonText, std::string& Error)
+ {
+ using namespace json11;
+
+ const Json Json = Json::parse(std::string(JsonText), Error);
+
+ if (Error.empty())
+ {
+ CbWriter Writer;
+ if (ReadField(Writer, Json, std::string_view(), Error))
+ {
+ return Writer.Save();
+ }
+ }
+
+ return CbFieldIterator();
+ }
+
+private:
+ static bool ReadField(CbWriter& Writer, const json11::Json& Json, const std::string_view FieldName, std::string& Error)
+ {
+ using namespace json11;
+
+ switch (Json.type())
+ {
+ case Json::Type::OBJECT:
+ {
+ if (FieldName.empty())
+ {
+ Writer.BeginObject();
+ }
+ else
+ {
+ Writer.BeginObject(FieldName);
+ }
+
+ for (const auto& Kv : Json.object_items())
+ {
+ const std::string& Name = Kv.first;
+ const json11::Json& Item = Kv.second;
+
+ if (ReadField(Writer, Item, Name, Error) == false)
+ {
+ return false;
+ }
+ }
+
+ Writer.EndObject();
+ }
+ break;
+ case Json::Type::ARRAY:
+ {
+ if (FieldName.empty())
+ {
+ Writer.BeginArray();
+ }
+ else
+ {
+ Writer.BeginArray(FieldName);
+ }
+
+ for (const json11::Json& Item : Json.array_items())
+ {
+ if (ReadField(Writer, Item, std::string_view(), Error) == false)
+ {
+ return false;
+ }
+ }
+
+ Writer.EndArray();
+ }
+ break;
+ case Json::Type::NUL:
+ {
+ if (FieldName.empty())
+ {
+ Writer.AddNull();
+ }
+ else
+ {
+ Writer.AddNull(FieldName);
+ }
+ }
+ break;
+ case Json::Type::BOOL:
+ {
+ if (FieldName.empty())
+ {
+ Writer.AddBool(Json.bool_value());
+ }
+ else
+ {
+ Writer.AddBool(FieldName, Json.bool_value());
+ }
+ }
+ break;
+ case Json::Type::NUMBER:
+ {
+ if (FieldName.empty())
+ {
+ Writer.AddFloat(Json.number_value());
+ }
+ else
+ {
+ Writer.AddFloat(FieldName, Json.number_value());
+ }
+ }
+ break;
+ case Json::Type::STRING:
+ {
+ Oid Id;
+ if (TryParseObjectId(Json.string_value(), Id))
+ {
+ if (FieldName.empty())
+ {
+ Writer.AddObjectId(Id);
+ }
+ else
+ {
+ Writer.AddObjectId(FieldName, Id);
+ }
+
+ return true;
+ }
+
+ IoHash Hash;
+ if (TryParseIoHash(Json.string_value(), Hash))
+ {
+ if (FieldName.empty())
+ {
+ Writer.AddHash(Hash);
+ }
+ else
+ {
+ Writer.AddHash(FieldName, Hash);
+ }
+
+ return true;
+ }
+
+ if (FieldName.empty())
+ {
+ Writer.AddString(Json.string_value());
+ }
+ else
+ {
+ Writer.AddString(FieldName, Json.string_value());
+ }
+ }
+ break;
+ default:
+ break;
+ }
+
+ return true;
+ }
+
+ static constexpr AsciiSet HexCharSet = AsciiSet("0123456789abcdefABCDEF");
+
+ static bool TryParseObjectId(std::string_view Str, Oid& Id)
+ {
+ using namespace std::literals;
+
+ if (Str.size() == Oid::StringLength && AsciiSet::HasOnly(Str, HexCharSet))
+ {
+ Id = Oid::FromHexString(Str);
+ return true;
+ }
+
+ if (Str.starts_with("0x"sv))
+ {
+ return TryParseObjectId(Str.substr(2), Id);
+ }
+
+ return false;
+ }
+
+ static bool TryParseIoHash(std::string_view Str, IoHash& Hash)
+ {
+ using namespace std::literals;
+
+ if (Str.size() == IoHash::StringLength && AsciiSet::HasOnly(Str, HexCharSet))
+ {
+ Hash = IoHash::FromHexString(Str);
+ return true;
+ }
+
+ if (Str.starts_with("0x"sv))
+ {
+ return TryParseIoHash(Str.substr(2), Hash);
+ }
+
+ return false;
+ }
+};
+
+CbFieldIterator
+LoadCompactBinaryFromJson(std::string_view Json, std::string& Error)
+{
+ if (Json.empty() == false)
+ {
+ return CbJsonReader::Read(Json, Error);
+ }
+
+ return CbFieldIterator();
+}
+
+CbFieldIterator
+LoadCompactBinaryFromJson(std::string_view Json)
+{
+ std::string Error;
+ return LoadCompactBinaryFromJson(Json, Error);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+void
+uson_forcelink()
+{
+}
+
+TEST_CASE("uson")
+{
+ using namespace std::literals;
+
+ SUBCASE("CbField")
+ {
+ constexpr CbFieldView DefaultField;
+ static_assert(!DefaultField.HasName(), "Error in HasName()");
+ static_assert(!DefaultField.HasValue(), "Error in HasValue()");
+ static_assert(!DefaultField.HasError(), "Error in HasError()");
+ static_assert(DefaultField.GetError() == CbFieldError::None, "Error in GetError()");
+
+ CHECK(DefaultField.GetSize() == 1);
+ CHECK(DefaultField.GetName().size() == 0);
+ CHECK(DefaultField.HasName() == false);
+ CHECK(DefaultField.HasValue() == false);
+ CHECK(DefaultField.HasError() == false);
+ CHECK(DefaultField.GetError() == CbFieldError::None);
+
+ const uint8_t Type = (uint8_t)CbFieldType::None;
+ CHECK(DefaultField.GetHash() == IoHash::HashBuffer(&Type, sizeof Type));
+
+ CHECK(DefaultField.GetView() == MemoryView{});
+ MemoryView SerializedView;
+ CHECK(DefaultField.TryGetSerializedView(SerializedView) == false);
+ }
+
+ SUBCASE("CbField(None)")
+ {
+ CbFieldView NoneField(nullptr, CbFieldType::None);
+ CHECK(NoneField.GetSize() == 1);
+ CHECK(NoneField.GetName().size() == 0);
+ CHECK(NoneField.HasName() == false);
+ CHECK(NoneField.HasValue() == false);
+ CHECK(NoneField.HasError() == false);
+ CHECK(NoneField.GetError() == CbFieldError::None);
+ CHECK(NoneField.GetHash() == CbFieldView().GetHash());
+ CHECK(NoneField.GetView() == MemoryView());
+ MemoryView SerializedView;
+ CHECK(NoneField.TryGetSerializedView(SerializedView) == false);
+ }
+
+ SUBCASE("CbField(None|Type|Name)")
+ {
+ constexpr CbFieldType FieldType = CbFieldType::None | CbFieldType::HasFieldName;
+ const char NoneBytes[] = {char(FieldType), 4, 'N', 'a', 'm', 'e'};
+ CbFieldView NoneField(NoneBytes);
+
+ CHECK(NoneField.GetSize() == sizeof(NoneBytes));
+ CHECK(NoneField.GetName().compare("Name"sv) == 0);
+ CHECK(NoneField.HasName() == true);
+ CHECK(NoneField.HasValue() == false);
+ CHECK(NoneField.GetHash() == IoHash::HashBuffer(NoneBytes, sizeof NoneBytes));
+ CHECK(NoneField.GetView() == MemoryView(NoneBytes, sizeof NoneBytes));
+ MemoryView SerializedView;
+ CHECK(NoneField.TryGetSerializedView(SerializedView) == true);
+ CHECK(SerializedView == MemoryView(NoneBytes, sizeof NoneBytes));
+
+ uint8_t CopyBytes[sizeof(NoneBytes)];
+ NoneField.CopyTo(MutableMemoryView(CopyBytes, sizeof CopyBytes));
+ CHECK(MemoryView(NoneBytes, sizeof NoneBytes).EqualBytes(MemoryView(CopyBytes, sizeof CopyBytes)));
+ }
+
+ SUBCASE("CbField(None|Type)")
+ {
+ constexpr CbFieldType FieldType = CbFieldType::None;
+ const char NoneBytes[] = {char(FieldType)};
+ CbFieldView NoneField(NoneBytes);
+
+ CHECK(NoneField.GetSize() == sizeof NoneBytes);
+ CHECK(NoneField.GetName().size() == 0);
+ CHECK(NoneField.HasName() == false);
+ CHECK(NoneField.HasValue() == false);
+ CHECK(NoneField.GetHash() == CbFieldView().GetHash());
+ CHECK(NoneField.GetView() == MemoryView(NoneBytes, sizeof NoneBytes));
+ MemoryView SerializedView;
+ CHECK(NoneField.TryGetSerializedView(SerializedView) == true);
+ CHECK(SerializedView == MemoryView(NoneBytes, sizeof NoneBytes));
+ }
+
+ SUBCASE("CbField(None|Name)")
+ {
+ constexpr CbFieldType FieldType = CbFieldType::None | CbFieldType::HasFieldName;
+ const char NoneBytes[] = {char(FieldType), 4, 'N', 'a', 'm', 'e'};
+ CbFieldView NoneField(NoneBytes + 1, FieldType);
+ CHECK(NoneField.GetSize() == uint64_t(sizeof NoneBytes));
+ CHECK(NoneField.GetName().compare("Name") == 0);
+ CHECK(NoneField.HasName() == true);
+ CHECK(NoneField.HasValue() == false);
+ CHECK(NoneField.GetHash() == IoHash::HashBuffer(NoneBytes, sizeof NoneBytes));
+ CHECK(NoneField.GetView() == MemoryView(NoneBytes + 1, sizeof NoneBytes - 1));
+ MemoryView SerializedView;
+ CHECK(NoneField.TryGetSerializedView(SerializedView) == false);
+
+ uint8_t CopyBytes[sizeof(NoneBytes)];
+ NoneField.CopyTo(MutableMemoryView(CopyBytes, sizeof CopyBytes));
+ CHECK(MemoryView(NoneBytes, sizeof NoneBytes).EqualBytes(MemoryView(CopyBytes, sizeof CopyBytes)));
+ }
+
+ SUBCASE("CbField(None|EmptyName)")
+ {
+ constexpr CbFieldType FieldType = CbFieldType::None | CbFieldType::HasFieldName;
+ const uint8_t NoneBytes[] = {uint8_t(FieldType), 0};
+ CbFieldView NoneField(NoneBytes + 1, FieldType);
+ CHECK(NoneField.GetSize() == sizeof NoneBytes);
+ CHECK(NoneField.GetName().empty() == true);
+ CHECK(NoneField.HasName() == true);
+ CHECK(NoneField.HasValue() == false);
+ CHECK(NoneField.GetHash() == IoHash::HashBuffer(NoneBytes, sizeof NoneBytes));
+ CHECK(NoneField.GetView() == MemoryView(NoneBytes + 1, sizeof NoneBytes - 1));
+ MemoryView SerializedView;
+ CHECK(NoneField.TryGetSerializedView(SerializedView) == false);
+ }
+
+ static_assert(!std::is_constructible<CbFieldView, const CbObjectView&>::value, "Invalid constructor for CbField");
+ static_assert(!std::is_assignable<CbFieldView, const CbObjectView&>::value, "Invalid assignment for CbField");
+ static_assert(!std::is_convertible<CbFieldView, CbObjectView>::value, "Invalid conversion to CbObject");
+ static_assert(!std::is_assignable<CbObjectView, const CbFieldView&>::value, "Invalid assignment for CbObject");
+
+ static_assert(std::is_constructible<CbField>::value, "Missing constructor for CbField");
+ static_assert(std::is_constructible<CbField, const CbField&>::value, "Missing constructor for CbField");
+ static_assert(std::is_constructible<CbField, CbField&&>::value, "Missing constructor for CbField");
+}
+
+TEST_CASE("uson.null")
+{
+ using namespace std::literals;
+
+ SUBCASE("CbField(Null)")
+ {
+ CbFieldView NullField(nullptr, CbFieldType::Null);
+ CHECK(NullField.GetSize() == 1);
+ CHECK(NullField.IsNull() == true);
+ CHECK(NullField.HasValue() == true);
+ CHECK(NullField.HasError() == false);
+ CHECK(NullField.GetError() == CbFieldError::None);
+ const uint8_t Null[]{uint8_t(CbFieldType::Null)};
+ CHECK(NullField.GetHash() == IoHash::HashBuffer(Null, sizeof Null));
+ }
+
+ SUBCASE("CbField(None)")
+ {
+ CbFieldView Field;
+ CHECK(Field.IsNull() == false);
+ }
+}
+
+TEST_CASE("uson.json")
+{
+ SUBCASE("string")
+ {
+ CbObjectWriter Writer;
+ Writer << "KeyOne"
+ << "ValueOne";
+ Writer << "KeyTwo"
+ << "ValueTwo";
+ CbObject Obj = Writer.Save();
+
+ StringBuilder<128> Sb;
+ const char* JsonText = Obj.ToJson(Sb).Data();
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(JsonText, JsonError);
+
+ const std::string ValueOne = Json["KeyOne"].string_value();
+ const std::string ValueTwo = Json["KeyTwo"].string_value();
+
+ CHECK(JsonError.empty());
+ CHECK(ValueOne == "ValueOne");
+ CHECK(ValueTwo == "ValueTwo");
+ }
+
+ SUBCASE("number")
+ {
+ const float ExpectedFloatValue = 21.21f;
+ const double ExpectedDoubleValue = 42.42;
+
+ CbObjectWriter Writer;
+ Writer << "Float" << ExpectedFloatValue;
+ Writer << "Double" << ExpectedDoubleValue;
+
+ CbObject Obj = Writer.Save();
+
+ StringBuilder<128> Sb;
+ const char* JsonText = Obj.ToJson(Sb).Data();
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(JsonText, JsonError);
+
+ const float FloatValue = float(Json["Float"].number_value());
+ const double DoubleValue = Json["Double"].number_value();
+
+ CHECK(JsonError.empty());
+ CHECK(FloatValue == Approx(ExpectedFloatValue));
+ CHECK(DoubleValue == Approx(ExpectedDoubleValue));
+ }
+
+ SUBCASE("number.nan")
+ {
+ const float FloatNan = std::numeric_limits<float>::quiet_NaN();
+ const double DoubleNan = std::numeric_limits<double>::quiet_NaN();
+
+ CbObjectWriter Writer;
+ Writer << "FloatNan" << FloatNan;
+ Writer << "DoubleNan" << DoubleNan;
+
+ CbObject Obj = Writer.Save();
+
+ StringBuilder<128> Sb;
+ const char* JsonText = Obj.ToJson(Sb).Data();
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(JsonText, JsonError);
+
+ const double FloatValue = Json["FloatNan"].number_value();
+ const double DoubleValue = Json["DoubleNan"].number_value();
+
+ CHECK(JsonError.empty());
+ CHECK(FloatValue == 0);
+ CHECK(DoubleValue == 0);
+ }
+}
+
+TEST_CASE("uson.datetime")
+{
+ using namespace std::literals;
+
+ {
+ DateTime D1600(1601, 1, 1);
+ CHECK_EQ(D1600.GetYear(), 1601);
+ CHECK_EQ(D1600.GetMonth(), 1);
+ CHECK_EQ(D1600.GetDay(), 1);
+ CHECK_EQ(D1600.GetHour(), 0);
+ CHECK_EQ(D1600.GetMinute(), 0);
+ CHECK_EQ(D1600.GetSecond(), 0);
+
+ CHECK_EQ(D1600.ToIso8601(), "1601-01-01T00:00:00.000Z"sv);
+ }
+
+ {
+ DateTime D72(1972, 2, 23, 17, 30, 10);
+ CHECK_EQ(D72.GetYear(), 1972);
+ CHECK_EQ(D72.GetMonth(), 2);
+ CHECK_EQ(D72.GetDay(), 23);
+ CHECK_EQ(D72.GetHour(), 17);
+ CHECK_EQ(D72.GetMinute(), 30);
+ CHECK_EQ(D72.GetSecond(), 10);
+ }
+}
+
+TEST_CASE("json.uson")
+{
+ using namespace std::literals;
+ using namespace json11;
+
+ SUBCASE("empty")
+ {
+ CbFieldIterator It = LoadCompactBinaryFromJson(""sv);
+ CHECK(It.HasValue() == false);
+ }
+
+ SUBCASE("object")
+ {
+ const Json JsonObject = Json::object{{"Null", nullptr},
+ {"String", "Value1"},
+ {"Bool", true},
+ {"Number", 46.2},
+ {"Array", Json::array{1, 2, 3}},
+ {"Object",
+ Json::object{
+ {"String", "Value2"},
+ }}};
+
+ CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject();
+
+ CHECK(Cb["Null"].IsNull());
+ CHECK(Cb["String"].AsString() == "Value1"sv);
+ CHECK(Cb["Bool"].AsBool());
+ CHECK(Cb["Number"].AsDouble() == 46.2);
+ CHECK(Cb["Object"].IsObject());
+ CbObjectView Object = Cb["Object"].AsObjectView();
+ CHECK(Object["String"].AsString() == "Value2"sv);
+ }
+
+ SUBCASE("array")
+ {
+ const Json JsonArray = Json::array{42, 43, 44};
+ CbArray Cb = LoadCompactBinaryFromJson(JsonArray.dump()).AsArray();
+
+ auto It = Cb.CreateIterator();
+ CHECK((*It).AsDouble() == 42);
+ It++;
+ CHECK((*It).AsDouble() == 43);
+ It++;
+ CHECK((*It).AsDouble() == 44);
+ }
+
+ SUBCASE("objectid")
+ {
+ const Oid& Id = Oid::NewOid();
+
+ StringBuilder<64> Sb;
+ Id.ToString(Sb);
+
+ Json JsonObject = Json::object{{"value", Sb.ToString()}};
+ CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject();
+
+ CHECK(Cb["value"sv].IsObjectId());
+ CHECK(Cb["value"sv].AsObjectId() == Id);
+
+ Sb.Reset();
+ Sb << "0x";
+ Id.ToString(Sb);
+
+ JsonObject = Json::object{{"value", Sb.ToString()}};
+ Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject();
+
+ CHECK(Cb["value"sv].IsObjectId());
+ CHECK(Cb["value"sv].AsObjectId() == Id);
+ }
+
+ SUBCASE("iohash")
+ {
+ const uint8_t Data[] = {
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ };
+
+ const IoHash Hash = IoHash::HashBuffer(Data, sizeof(Data));
+
+ Json JsonObject = Json::object{{"value", Hash.ToHexString()}};
+ CbObject Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject();
+
+ CHECK(Cb["value"sv].IsHash());
+ CHECK(Cb["value"sv].AsHash() == Hash);
+
+ JsonObject = Json::object{{"value", "0x" + Hash.ToHexString()}};
+ Cb = LoadCompactBinaryFromJson(JsonObject.dump()).AsObject();
+
+ CHECK(Cb["value"sv].IsHash());
+ CHECK(Cb["value"sv].AsHash() == Hash);
+ }
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/compactbinarybuilder.cpp b/src/zencore/compactbinarybuilder.cpp
new file mode 100644
index 000000000..d4ccd434d
--- /dev/null
+++ b/src/zencore/compactbinarybuilder.cpp
@@ -0,0 +1,1545 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/compactbinarybuilder.h"
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/endian.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#define _USE_MATH_DEFINES
+#include <math.h>
+
+namespace zen {
+
+template<typename T>
+uint64_t
+AddUninitialized(std::vector<T>& Vector, uint64_t Count)
+{
+ const uint64_t Offset = Vector.size();
+ Vector.resize(Offset + Count);
+ return Offset;
+}
+
+template<typename T>
+uint64_t
+Append(std::vector<T>& Vector, const T* Data, uint64_t Count)
+{
+ const uint64_t Offset = Vector.size();
+ Vector.resize(Offset + Count);
+
+ memcpy(Vector.data() + Offset, Data, sizeof(T) * Count);
+
+ return Offset;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+enum class CbWriter::StateFlags : uint8_t
+{
+ None = 0,
+ /** Whether a name has been written for the current field. */
+ Name = 1 << 0,
+ /** Whether this state is in the process of writing a field. */
+ Field = 1 << 1,
+ /** Whether this state is for array fields. */
+ Array = 1 << 2,
+ /** Whether this state is for object fields. */
+ Object = 1 << 3,
+};
+
+ENUM_CLASS_FLAGS(CbWriter::StateFlags);
+
+/** Whether the field type can be used in a uniform array or uniform object. */
+static constexpr bool
+IsUniformType(const CbFieldType Type)
+{
+ if (CbFieldTypeOps::HasFieldName(Type))
+ {
+ return true;
+ }
+
+ switch (Type)
+ {
+ case CbFieldType::None:
+ case CbFieldType::Null:
+ case CbFieldType::BoolFalse:
+ case CbFieldType::BoolTrue:
+ return false;
+ default:
+ return true;
+ }
+}
+
+/** Append the payload from the compact binary value to the array and return its type. */
+static inline CbFieldType
+AppendCompactBinary(const CbFieldView& Value, std::vector<uint8_t>& OutData)
+{
+ struct FCopy : public CbFieldView
+ {
+ using CbFieldView::GetPayloadView;
+ using CbFieldView::GetType;
+ };
+ const FCopy& ValueCopy = static_cast<const FCopy&>(Value);
+ const MemoryView SourceView = ValueCopy.GetPayloadView();
+ const uint64_t TargetOffset = OutData.size();
+ OutData.resize(TargetOffset + SourceView.GetSize());
+ memcpy(OutData.data() + TargetOffset, SourceView.GetData(), SourceView.GetSize());
+ return CbFieldTypeOps::GetType(ValueCopy.GetType());
+}
+
+CbWriter::CbWriter()
+{
+ States.emplace_back();
+}
+
+CbWriter::CbWriter(const int64_t InitialSize) : CbWriter()
+{
+ Data.reserve(InitialSize);
+}
+
+CbWriter::~CbWriter()
+{
+}
+
+void
+CbWriter::Reset()
+{
+ Data.resize(0);
+ States.resize(0);
+ States.emplace_back();
+}
+
+CbFieldIterator
+CbWriter::Save()
+{
+ const uint64_t Size = GetSaveSize();
+ UniqueBuffer Buffer = UniqueBuffer::Alloc(Size);
+ const CbFieldViewIterator Output = Save(Buffer);
+
+ SharedBuffer SharedBuf(std::move(Buffer));
+ SharedBuf.MakeImmutable();
+
+ return CbFieldIterator::MakeRangeView(Output, SharedBuf);
+}
+
+CbFieldViewIterator
+CbWriter::Save(const MutableMemoryView Buffer)
+{
+ ZEN_ASSERT(States.size() == 1 && States.back().Flags == StateFlags::None);
+ // TEXT("It is invalid to save while there are incomplete write operations."));
+ ZEN_ASSERT(Data.size() > 0); // TEXT("It is invalid to save when nothing has been written."));
+ ZEN_ASSERT(Buffer.GetSize() == Data.size());
+ // TEXT("Buffer is %" UINT64_FMT " bytes but %" INT64_FMT " is required."),
+ // Buffer.GetSize(),
+ // Data.Num());
+ memcpy(Buffer.GetData(), Data.data(), Data.size());
+ return CbFieldViewIterator::MakeRange(Buffer);
+}
+
+void
+CbWriter::Save(BinaryWriter& Writer)
+{
+ ZEN_ASSERT(States.size() == 1 && States.back().Flags == StateFlags::None);
+ // TEXT("It is invalid to save while there are incomplete write operations."));
+ ZEN_ASSERT(Data.size() > 0); // TEXT("It is invalid to save when nothing has been written."));
+ Writer.Write(Data.data(), Data.size());
+}
+
+uint64_t
+CbWriter::GetSaveSize() const
+{
+ return Data.size();
+}
+
+void
+CbWriter::BeginField()
+{
+ WriterState& State = States.back();
+ if ((State.Flags & StateFlags::Field) == StateFlags::None)
+ {
+ State.Flags |= StateFlags::Field;
+ State.Offset = Data.size();
+ Data.push_back(0);
+ }
+ else
+ {
+ ZEN_ASSERT((State.Flags & StateFlags::Name) == StateFlags::Name);
+ // TEXT("A new field cannot be written until the previous field '%.*hs' is finished."),
+ // GetActiveName().Len(),
+ // GetActiveName().GetData());
+ }
+}
+
+void
+CbWriter::EndField(CbFieldType Type)
+{
+ WriterState& State = States.back();
+
+ if ((State.Flags & StateFlags::Name) == StateFlags::Name)
+ {
+ Type |= CbFieldType::HasFieldName;
+ }
+ else
+ {
+ ZEN_ASSERT((State.Flags & StateFlags::Object) == StateFlags::None);
+ // TEXT("It is invalid to write an object field without a unique non-empty name."));
+ }
+
+ if (State.Count == 0)
+ {
+ State.UniformType = Type;
+ }
+ else if (State.UniformType != Type)
+ {
+ State.UniformType = CbFieldType::None;
+ }
+
+ State.Flags &= ~(StateFlags::Name | StateFlags::Field);
+ ++State.Count;
+ Data[State.Offset] = uint8_t(Type);
+}
+
+ZEN_NOINLINE
+CbWriter&
+CbWriter::SetName(const std::string_view Name)
+{
+ WriterState& State = States.back();
+ ZEN_ASSERT((State.Flags & StateFlags::Array) != StateFlags::Array);
+ // TEXT("It is invalid to write a name for an array field. Name '%.*hs'"),
+ // Name.Len(),
+ // Name.GetData());
+ ZEN_ASSERT(!Name.empty());
+ // TEXT("%s"),
+ //(State.Flags & EStateFlags::Object) == EStateFlags::Object
+ // ? TEXT("It is invalid to write an empty name for an object field. Specify a unique non-empty name.")
+ // : TEXT("It is invalid to write an empty name for a top-level field. Specify a name or avoid this call."));
+ ZEN_ASSERT((State.Flags & (StateFlags::Name | StateFlags::Field)) == StateFlags::None);
+ // TEXT("A new field '%.*hs' cannot be written until the previous field '%.*hs' is finished."),
+ // Name.Len(),
+ // Name.GetData(),
+ // GetActiveName().Len(),
+ // GetActiveName().GetData());
+
+ BeginField();
+ State.Flags |= StateFlags::Name;
+ const uint32_t NameLenByteCount = MeasureVarUInt(uint32_t(Name.size()));
+ const int64_t NameLenOffset = Data.size();
+ Data.resize(NameLenOffset + NameLenByteCount);
+
+ WriteVarUInt(uint64_t(Name.size()), Data.data() + NameLenOffset);
+
+ const uint8_t* NamePtr = reinterpret_cast<const uint8_t*>(Name.data());
+ Data.insert(Data.end(), NamePtr, NamePtr + Name.size());
+ return *this;
+}
+
+void
+CbWriter::SetNameOrAddString(const std::string_view NameOrValue)
+{
+ // A name is only written if it would begin a new field inside of an object.
+ if ((States.back().Flags & (StateFlags::Name | StateFlags::Field | StateFlags::Object)) == StateFlags::Object)
+ {
+ SetName(NameOrValue);
+ }
+ else
+ {
+ AddString(NameOrValue);
+ }
+}
+
+std::string_view
+CbWriter::GetActiveName() const
+{
+ const WriterState& State = States.back();
+ if ((State.Flags & StateFlags::Name) == StateFlags::Name)
+ {
+ const uint8_t* const EncodedName = Data.data() + State.Offset + sizeof(CbFieldType);
+ uint32_t NameLenByteCount;
+ const uint64_t NameLen = ReadVarUInt(EncodedName, NameLenByteCount);
+ const size_t ClampedNameLen = std::clamp<uint64_t>(NameLen, 0, ~uint64_t(0));
+ return std::string_view(reinterpret_cast<const char*>(EncodedName + NameLenByteCount), ClampedNameLen);
+ }
+ return std::string_view();
+}
+
+void
+CbWriter::MakeFieldsUniform(const int64_t FieldBeginOffset, const int64_t FieldEndOffset)
+{
+ MutableMemoryView SourceView(Data.data() + FieldBeginOffset, uint64_t(FieldEndOffset - FieldBeginOffset));
+ MutableMemoryView TargetView = SourceView;
+ TargetView.RightChopInline(sizeof(CbFieldType));
+
+ while (!SourceView.IsEmpty())
+ {
+ const uint64_t FieldSize = MeasureCompactBinary(SourceView) - sizeof(CbFieldType);
+ SourceView.RightChopInline(sizeof(CbFieldType));
+ if (TargetView.GetData() != SourceView.GetData())
+ {
+ memmove(TargetView.GetData(), SourceView.GetData(), FieldSize);
+ }
+ SourceView.RightChopInline(FieldSize);
+ TargetView.RightChopInline(FieldSize);
+ }
+
+ if (!TargetView.IsEmpty())
+ {
+ const auto EraseBegin = Data.begin() + (FieldEndOffset - TargetView.GetSize());
+ const auto EraseEnd = EraseBegin + TargetView.GetSize();
+
+ Data.erase(EraseBegin, EraseEnd);
+ }
+}
+
+void
+CbWriter::AddField(const CbFieldView& Value)
+{
+ ZEN_ASSERT(Value.HasValue()); // , TEXT("It is invalid to write a field with no value."));
+ BeginField();
+ EndField(AppendCompactBinary(Value, Data));
+}
+
+void
+CbWriter::AddField(const CbField& Value)
+{
+ AddField(CbFieldView(Value));
+}
+
+void
+CbWriter::BeginObject()
+{
+ BeginField();
+ States.push_back(WriterState());
+ States.back().Flags |= StateFlags::Object;
+}
+
+void
+CbWriter::EndObject()
+{
+ ZEN_ASSERT(States.size() > 1 && (States.back().Flags & StateFlags::Object) == StateFlags::Object);
+
+ // TEXT("It is invalid to end an object when an object is not at the top of the stack."));
+ ZEN_ASSERT((States.back().Flags & StateFlags::Field) == StateFlags::None);
+ // TEXT("It is invalid to end an object until the previous field is finished."));
+
+ const bool bUniform = IsUniformType(States.back().UniformType);
+ const uint64_t Count = States.back().Count;
+ States.pop_back();
+
+ // Calculate the offset of the payload.
+ const WriterState& State = States.back();
+ int64_t PayloadOffset = State.Offset + 1;
+ if ((State.Flags & StateFlags::Name) == StateFlags::Name)
+ {
+ uint32_t NameLenByteCount;
+ const uint64_t NameLen = ReadVarUInt(Data.data() + PayloadOffset, NameLenByteCount);
+ PayloadOffset += NameLen + NameLenByteCount;
+ }
+
+ // Remove redundant field types for uniform objects.
+ if (bUniform && Count > 1)
+ {
+ MakeFieldsUniform(PayloadOffset, Data.size());
+ }
+
+ // Insert the object size.
+ const uint64_t Size = uint64_t(Data.size() - PayloadOffset);
+ const uint32_t SizeByteCount = MeasureVarUInt(Size);
+ Data.insert(Data.begin() + PayloadOffset, SizeByteCount, 0);
+ WriteVarUInt(Size, Data.data() + PayloadOffset);
+
+ EndField(bUniform ? CbFieldType::UniformObject : CbFieldType::Object);
+}
+
+void
+CbWriter::AddObject(const CbObjectView& Value)
+{
+ BeginField();
+ EndField(AppendCompactBinary(Value.AsFieldView(), Data));
+}
+
+void
+CbWriter::AddObject(const CbObject& Value)
+{
+ AddObject(CbObjectView(Value));
+}
+
+ZEN_NOINLINE
+void
+CbWriter::BeginArray()
+{
+ BeginField();
+ States.push_back(WriterState());
+ States.back().Flags |= StateFlags::Array;
+}
+
+void
+CbWriter::EndArray()
+{
+ ZEN_ASSERT(States.size() > 1 && (States.back().Flags & StateFlags::Array) == StateFlags::Array);
+ // TEXT("Invalid attempt to end an array when an array is not at the top of the stack."));
+ ZEN_ASSERT((States.back().Flags & StateFlags::Field) == StateFlags::None);
+ // TEXT("It is invalid to end an array until the previous field is finished."));
+ const bool bUniform = IsUniformType(States.back().UniformType);
+ const uint64_t Count = States.back().Count;
+ States.pop_back();
+
+ // Calculate the offset of the payload.
+ const WriterState& State = States.back();
+ int64_t PayloadOffset = State.Offset + 1;
+ if ((State.Flags & StateFlags::Name) == StateFlags::Name)
+ {
+ uint32_t NameLenByteCount;
+ const uint64_t NameLen = ReadVarUInt(Data.data() + PayloadOffset, NameLenByteCount);
+ PayloadOffset += NameLen + NameLenByteCount;
+ }
+
+ // Remove redundant field types for uniform arrays.
+ if (bUniform && Count > 1)
+ {
+ MakeFieldsUniform(PayloadOffset, Data.size());
+ }
+
+ // Insert the array size and field count.
+ const uint32_t CountByteCount = MeasureVarUInt(Count);
+ const uint64_t Size = uint64_t(Data.size() - PayloadOffset) + CountByteCount;
+ const uint32_t SizeByteCount = MeasureVarUInt(Size);
+ Data.insert(Data.begin() + PayloadOffset, SizeByteCount + CountByteCount, 0);
+ WriteVarUInt(Size, Data.data() + PayloadOffset);
+ WriteVarUInt(Count, Data.data() + PayloadOffset + SizeByteCount);
+
+ EndField(bUniform ? CbFieldType::UniformArray : CbFieldType::Array);
+}
+
+void
+CbWriter::AddArray(const CbArrayView& Value)
+{
+ BeginField();
+ EndField(AppendCompactBinary(Value.AsFieldView(), Data));
+}
+
+void
+CbWriter::AddArray(const CbArray& Value)
+{
+ AddArray(CbArrayView(Value));
+}
+
+void
+CbWriter::AddNull()
+{
+ BeginField();
+ EndField(CbFieldType::Null);
+}
+
+void
+CbWriter::AddBinary(const void* const Value, const uint64_t Size)
+{
+ const size_t SizeByteCount = MeasureVarUInt(Size);
+ Data.reserve(Data.size() + 1 + SizeByteCount + Size);
+ BeginField();
+ const size_t SizeOffset = Data.size();
+ Data.resize(Data.size() + SizeByteCount);
+ WriteVarUInt(Size, Data.data() + SizeOffset);
+ Data.insert(Data.end(), static_cast<const uint8_t*>(Value), static_cast<const uint8_t*>(Value) + Size);
+ EndField(CbFieldType::Binary);
+}
+
+void
+CbWriter::AddBinary(IoBuffer Buffer)
+{
+ AddBinary(Buffer.Data(), Buffer.Size());
+}
+
+void
+CbWriter::AddBinary(SharedBuffer Buffer)
+{
+ AddBinary(Buffer.GetData(), Buffer.GetSize());
+}
+
+void
+CbWriter::AddBinary(const CompositeBuffer& Buffer)
+{
+ AddBinary(Buffer.Flatten());
+}
+
+void
+CbWriter::AddString(const std::string_view Value)
+{
+ BeginField();
+ const uint64_t Size = uint64_t(Value.size());
+ const uint32_t SizeByteCount = MeasureVarUInt(Size);
+ const int64_t Offset = Data.size();
+
+ Data.resize(Offset + SizeByteCount + Size);
+
+ uint8_t* StringData = Data.data() + Offset;
+ WriteVarUInt(Size, StringData);
+ StringData += SizeByteCount;
+ if (Size > 0)
+ {
+ memcpy(StringData, Value.data(), Value.size() * sizeof(char));
+ }
+ EndField(CbFieldType::String);
+}
+
+void
+CbWriter::AddString(const std::wstring_view Value)
+{
+ BeginField();
+ ExtendableStringBuilder<128> Utf8;
+ WideToUtf8(Value, Utf8);
+
+ const uint32_t Size = uint32_t(Utf8.Size());
+ const uint32_t SizeByteCount = MeasureVarUInt(Size);
+ const int64_t Offset = Data.size();
+ Data.resize(Offset + SizeByteCount + Size);
+ uint8_t* StringData = Data.data() + Offset;
+ WriteVarUInt(Size, StringData);
+ StringData += SizeByteCount;
+ if (Size > 0)
+ {
+ memcpy(reinterpret_cast<char*>(StringData), Utf8.Data(), Utf8.Size());
+ }
+ EndField(CbFieldType::String);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddInteger(const int32_t Value)
+{
+ if (Value >= 0)
+ {
+ return AddInteger(uint32_t(Value));
+ }
+ BeginField();
+ const uint32_t Magnitude = ~uint32_t(Value);
+ const uint32_t MagnitudeByteCount = MeasureVarUInt(Magnitude);
+ const int64_t Offset = Data.size();
+ Data.resize(Offset + MagnitudeByteCount);
+ WriteVarUInt(Magnitude, Data.data() + Offset);
+ EndField(CbFieldType::IntegerNegative);
+}
+
+void
+CbWriter::AddInteger(const int64_t Value)
+{
+ if (Value >= 0)
+ {
+ return AddInteger(uint64_t(Value));
+ }
+ BeginField();
+ const uint64_t Magnitude = ~uint64_t(Value);
+ const uint32_t MagnitudeByteCount = MeasureVarUInt(Magnitude);
+ const uint64_t Offset = AddUninitialized(Data, MagnitudeByteCount);
+ WriteVarUInt(Magnitude, Data.data() + Offset);
+ EndField(CbFieldType::IntegerNegative);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddInteger(const uint32_t Value)
+{
+ BeginField();
+ const uint32_t ValueByteCount = MeasureVarUInt(Value);
+ const uint64_t Offset = AddUninitialized(Data, ValueByteCount);
+ WriteVarUInt(Value, Data.data() + Offset);
+ EndField(CbFieldType::IntegerPositive);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddInteger(const uint64_t Value)
+{
+ BeginField();
+ const uint32_t ValueByteCount = MeasureVarUInt(Value);
+ const uint64_t Offset = AddUninitialized(Data, ValueByteCount);
+ WriteVarUInt(Value, Data.data() + Offset);
+ EndField(CbFieldType::IntegerPositive);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddFloat(const float Value)
+{
+ BeginField();
+ const uint32_t RawValue = FromNetworkOrder(reinterpret_cast<const uint32_t&>(Value));
+ Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint32_t));
+ EndField(CbFieldType::Float32);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddFloat(const double Value)
+{
+ const float Value32 = float(Value);
+ if (Value == double(Value32))
+ {
+ return AddFloat(Value32);
+ }
+ BeginField();
+ const uint64_t RawValue = FromNetworkOrder(reinterpret_cast<const uint64_t&>(Value));
+ Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint64_t));
+ EndField(CbFieldType::Float64);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddBool(const bool bValue)
+{
+ BeginField();
+ EndField(bValue ? CbFieldType::BoolTrue : CbFieldType::BoolFalse);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddObjectAttachment(const IoHash& Value)
+{
+ BeginField();
+ Append(Data, Value.Hash, sizeof Value.Hash);
+ EndField(CbFieldType::ObjectAttachment);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddBinaryAttachment(const IoHash& Value)
+{
+ BeginField();
+ Append(Data, Value.Hash, sizeof Value.Hash);
+ EndField(CbFieldType::BinaryAttachment);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddAttachment(const CbAttachment& Attachment)
+{
+ BeginField();
+ const IoHash& Value = Attachment.GetHash();
+ Append(Data, Value.Hash, sizeof Value.Hash);
+ EndField(CbFieldType::BinaryAttachment);
+}
+
+ZEN_NOINLINE
+void
+CbWriter::AddHash(const IoHash& Value)
+{
+ BeginField();
+ Append(Data, Value.Hash, sizeof Value.Hash);
+ EndField(CbFieldType::Hash);
+}
+
+void
+CbWriter::AddUuid(const Guid& Value)
+{
+ const auto AppendSwappedBytes = [this](uint32_t In) {
+ In = FromNetworkOrder(In);
+ Append(Data, reinterpret_cast<const uint8_t*>(&In), sizeof In);
+ };
+ BeginField();
+ AppendSwappedBytes(Value.A);
+ AppendSwappedBytes(Value.B);
+ AppendSwappedBytes(Value.C);
+ AppendSwappedBytes(Value.D);
+ EndField(CbFieldType::Uuid);
+}
+
+void
+CbWriter::AddObjectId(const Oid& Value)
+{
+ BeginField();
+ Append(Data, reinterpret_cast<const uint8_t*>(&Value.OidBits), sizeof Value.OidBits);
+ EndField(CbFieldType::ObjectId);
+}
+
+void
+CbWriter::AddDateTimeTicks(const int64_t Ticks)
+{
+ BeginField();
+ const uint64_t RawValue = FromNetworkOrder(uint64_t(Ticks));
+ Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint64_t));
+ EndField(CbFieldType::DateTime);
+}
+
+void
+CbWriter::AddDateTime(const DateTime Value)
+{
+ AddDateTimeTicks(Value.GetTicks());
+}
+
+void
+CbWriter::AddTimeSpanTicks(const int64_t Ticks)
+{
+ BeginField();
+ const uint64_t RawValue = FromNetworkOrder(uint64_t(Ticks));
+ Append(Data, reinterpret_cast<const uint8_t*>(&RawValue), sizeof(uint64_t));
+ EndField(CbFieldType::TimeSpan);
+}
+
+void
+CbWriter::AddTimeSpan(const TimeSpan Value)
+{
+ AddTimeSpanTicks(Value.GetTicks());
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+CbWriter&
+operator<<(CbWriter& Writer, const DateTime Value)
+{
+ Writer.AddDateTime(Value);
+ return Writer;
+}
+
+CbWriter&
+operator<<(CbWriter& Writer, const TimeSpan Value)
+{
+ Writer.AddTimeSpan(Value);
+ return Writer;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+void
+usonbuilder_forcelink()
+{
+}
+
+// doctest::String
+// toString(const DateTime&)
+// {
+// // TODO:implement
+// return "";
+// }
+
+// doctest::String
+// toString(const TimeSpan&)
+// {
+// // TODO:implement
+// return "";
+// }
+
+TEST_CASE("usonbuilder.object")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("EmptyObject")
+ {
+ Writer.BeginObject();
+ Writer.EndObject();
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsObject() == true);
+ CHECK(Field.AsObjectView().CreateViewIterator().HasValue() == false);
+ }
+
+ SUBCASE("NamedEmptyObject")
+ {
+ Writer.SetName("Object"sv);
+ Writer.BeginObject();
+ Writer.EndObject();
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsObject() == true);
+ CHECK(Field.AsObjectView().CreateViewIterator().HasValue() == false);
+ }
+
+ SUBCASE("BasicObject")
+ {
+ Writer.BeginObject();
+ Writer.SetName("Integer"sv).AddInteger(0);
+ Writer.SetName("Float"sv).AddFloat(0.0f);
+ Writer.EndObject();
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsObject() == true);
+
+ CbObjectView Object = Field.AsObjectView();
+ CHECK(Object["Integer"sv].IsInteger() == true);
+ CHECK(Object["Float"sv].IsFloat() == true);
+ }
+
+ SUBCASE("UniformObject")
+ {
+ Writer.BeginObject();
+ Writer.SetName("Field1"sv).AddInteger(0);
+ Writer.SetName("Field2"sv).AddInteger(1);
+ Writer.EndObject();
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsObject() == true);
+
+ CbObjectView Object = Field.AsObjectView();
+ CHECK(Object["Field1"sv].IsInteger() == true);
+ CHECK(Object["Field2"sv].IsInteger() == true);
+ }
+}
+
+TEST_CASE("usonbuilder.array")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("EmptyArray")
+ {
+ Writer.BeginArray();
+ Writer.EndArray();
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsArray() == true);
+ CHECK(Field.AsArrayView().Num() == 0);
+ }
+
+ SUBCASE("NamedEmptyArray")
+ {
+ Writer.SetName("Array"sv);
+ Writer.BeginArray();
+ Writer.EndArray();
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsArray() == true);
+ CHECK(Field.AsArrayView().Num() == 0);
+ }
+
+ SUBCASE("BasicArray")
+ {
+ Writer.BeginArray();
+ Writer.AddInteger(0);
+ Writer.AddFloat(0.0f);
+ Writer.EndArray();
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsArray() == true);
+ CbFieldViewIterator Iterator = Field.AsArrayView().CreateViewIterator();
+ CHECK(Iterator.IsInteger() == true);
+ ++Iterator;
+ CHECK(Iterator.IsFloat() == true);
+ ++Iterator;
+ CHECK(Iterator.HasValue() == false);
+ }
+
+ SUBCASE("UniformArray")
+ {
+ Writer.BeginArray();
+ Writer.AddInteger(0);
+ Writer.AddInteger(1);
+ Writer.EndArray();
+
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.IsArray() == true);
+ CbFieldViewIterator Iterator = Field.AsArrayView().CreateViewIterator();
+ CHECK(Iterator.IsInteger() == true);
+ ++Iterator;
+ CHECK(Iterator.IsInteger() == true);
+ ++Iterator;
+ CHECK(Iterator.HasValue() == false);
+ }
+}
+
+TEST_CASE("usonbuilder.null")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("Null")
+ {
+ Writer.AddNull();
+ CbField Field = Writer.Save();
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.HasName() == false);
+ CHECK(Field.IsNull() == true);
+ }
+
+ SUBCASE("NullWithName")
+ {
+ Writer.SetName("Null"sv);
+ Writer.AddNull();
+ CbField Field = Writer.Save();
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ CHECK(Field.HasName() == true);
+ CHECK(Field.GetName().compare("Null"sv) == 0);
+ CHECK(Field.IsNull() == true);
+ }
+
+ SUBCASE("Null Array/Object Uniformity")
+ {
+ Writer.BeginArray();
+ Writer.AddNull();
+ Writer.AddNull();
+ Writer.AddNull();
+ Writer.EndArray();
+
+ Writer.BeginObject();
+ Writer.SetName("N1"sv).AddNull();
+ Writer.SetName("N2"sv).AddNull();
+ Writer.SetName("N3"sv).AddNull();
+ Writer.EndObject();
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ }
+
+ SUBCASE("Null with Save(Buffer)")
+ {
+ constexpr int NullCount = 3;
+ for (int Index = 0; Index < NullCount; ++Index)
+ {
+ Writer.AddNull();
+ }
+ uint8_t Buffer[NullCount]{};
+ MutableMemoryView BufferView(Buffer, sizeof Buffer);
+ CbFieldViewIterator Fields = Writer.Save(BufferView);
+
+ CHECK(ValidateCompactBinaryRange(BufferView, CbValidateMode::All) == CbValidateError::None);
+
+ for (int Index = 0; Index < NullCount; ++Index)
+ {
+ CHECK(Fields.IsNull() == true);
+ ++Fields;
+ }
+ CHECK(Fields.HasValue() == false);
+ }
+}
+
+TEST_CASE("usonbuilder.binary")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+}
+
+TEST_CASE("usonbuilder.string")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("Empty Strings")
+ {
+ Writer.AddString(std::string_view());
+ Writer.AddString(std::wstring_view());
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ for (CbFieldView Field : Fields)
+ {
+ CHECK(Field.HasName() == false);
+ CHECK(Field.IsString() == true);
+ CHECK(Field.AsString().empty() == true);
+ }
+ }
+
+ SUBCASE("Test Basic Strings")
+ {
+ Writer.SetName("String"sv).AddString("Value"sv);
+ Writer.SetName("String"sv).AddString(L"Value"sv);
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ for (CbFieldView Field : Fields)
+ {
+ CHECK(Field.GetName().compare("String"sv) == 0);
+ CHECK(Field.HasName() == true);
+ CHECK(Field.IsString() == true);
+ CHECK(Field.AsString().compare("Value"sv) == 0);
+ }
+ }
+
+ SUBCASE("Long Strings")
+ {
+ constexpr int DotCount = 256;
+ StringBuilder<DotCount + 1> Dots;
+ for (int Index = 0; Index < DotCount; ++Index)
+ {
+ Dots.Append('.');
+ }
+ Writer.AddString(Dots);
+ Writer.AddString(std::wstring().append(256, L'.'));
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ for (CbFieldView Field : Fields)
+ {
+ CHECK((Field.AsString() == std::string_view(Dots)));
+ }
+ }
+
+ SUBCASE("Non-ASCII String")
+ {
+# if ZEN_SIZEOF_WCHAR_T == 2
+ wchar_t Value[2] = {0xd83d, 0xde00};
+# else
+ wchar_t Value[1] = {0x1f600};
+# endif
+
+ Writer.AddString("\xf0\x9f\x98\x80"sv);
+ Writer.AddString(std::wstring_view(Value, ZEN_ARRAY_COUNT(Value)));
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ for (CbFieldView Field : Fields)
+ {
+ CHECK((Field.AsString() == "\xf0\x9f\x98\x80"sv));
+ }
+ }
+}
+
+TEST_CASE("usonbuilder.integer")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ auto TestInt32 = [&Writer](int32_t Value) {
+ Writer.Reset();
+ Writer.AddInteger(Value);
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ CHECK(Field.AsInt32() == Value);
+ CHECK(Field.HasError() == false);
+ };
+
+ auto TestUInt32 = [&Writer](uint32_t Value) {
+ Writer.Reset();
+ Writer.AddInteger(Value);
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ CHECK(Field.AsUInt32() == Value);
+ CHECK(Field.HasError() == false);
+ };
+
+ auto TestInt64 = [&Writer](int64_t Value) {
+ Writer.Reset();
+ Writer.AddInteger(Value);
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ CHECK(Field.AsInt64() == Value);
+ CHECK(Field.HasError() == false);
+ };
+
+ auto TestUInt64 = [&Writer](uint64_t Value) {
+ Writer.Reset();
+ Writer.AddInteger(Value);
+ CbField Field = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Field.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ CHECK(Field.AsUInt64() == Value);
+ CHECK(Field.HasError() == false);
+ };
+
+ TestUInt32(uint32_t(0x00));
+ TestUInt32(uint32_t(0x7f));
+ TestUInt32(uint32_t(0x80));
+ TestUInt32(uint32_t(0xff));
+ TestUInt32(uint32_t(0x0100));
+ TestUInt32(uint32_t(0x7fff));
+ TestUInt32(uint32_t(0x8000));
+ TestUInt32(uint32_t(0xffff));
+ TestUInt32(uint32_t(0x0001'0000));
+ TestUInt32(uint32_t(0x7fff'ffff));
+ TestUInt32(uint32_t(0x8000'0000));
+ TestUInt32(uint32_t(0xffff'ffff));
+
+ TestUInt64(uint64_t(0x0000'0001'0000'0000));
+ TestUInt64(uint64_t(0x7fff'ffff'ffff'ffff));
+ TestUInt64(uint64_t(0x8000'0000'0000'0000));
+ TestUInt64(uint64_t(0xffff'ffff'ffff'ffff));
+
+ TestInt32(int32_t(0x01));
+ TestInt32(int32_t(0x80));
+ TestInt32(int32_t(0x81));
+ TestInt32(int32_t(0x8000));
+ TestInt32(int32_t(0x8001));
+ TestInt32(int32_t(0x7fff'ffff));
+ TestInt32(int32_t(0x8000'0000));
+ TestInt32(int32_t(0x8000'0001));
+
+ TestInt64(int64_t(0x0000'0001'0000'0000));
+ TestInt64(int64_t(0x8000'0000'0000'0000));
+ TestInt64(int64_t(0x7fff'ffff'ffff'ffff));
+ TestInt64(int64_t(0x8000'0000'0000'0001));
+ TestInt64(int64_t(0xffff'ffff'ffff'ffff));
+}
+
+TEST_CASE("usonbuilder.float")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("Float32")
+ {
+ constexpr float Values[] = {
+ 0.0f,
+ 1.0f,
+ -1.0f,
+ 3.14159265358979323846f, // PI
+ 3.402823466e+38f, // FLT_MAX
+ 1.175494351e-38f // FLT_MIN
+ };
+
+ for (float Value : Values)
+ {
+ Writer.AddFloat(Value);
+ }
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ const float* CheckValue = Values;
+ for (CbFieldView Field : Fields)
+ {
+ CHECK(Field.AsFloat() == *CheckValue++);
+ CHECK(Field.HasError() == false);
+ }
+ }
+
+ SUBCASE("Float64")
+ {
+ constexpr double Values[] = {
+ 0.0f,
+ 1.0f,
+ -1.0f,
+ 3.14159265358979323846, // PI
+ 1.9999998807907104,
+ 1.9999999403953552,
+ 3.4028234663852886e38,
+ 6.8056469327705771e38,
+ 2.2250738585072014e-308, // DBL_MIN
+ 1.7976931348623158e+308 // DBL_MAX
+ };
+
+ for (double Value : Values)
+ {
+ Writer.AddFloat(Value);
+ }
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ const double* CheckValue = Values;
+ for (CbFieldView Field : Fields)
+ {
+ CHECK(Field.AsDouble() == *CheckValue++);
+ CHECK(Field.HasError() == false);
+ }
+ }
+}
+
+TEST_CASE("usonbuilder.bool")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("Bool")
+ {
+ Writer.AddBool(true);
+ Writer.AddBool(false);
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ CHECK(Fields.AsBool() == true);
+ CHECK(Fields.HasError() == false);
+ ++Fields;
+ CHECK(Fields.AsBool() == false);
+ CHECK(Fields.HasError() == false);
+ ++Fields;
+ CHECK(Fields.HasValue() == false);
+ }
+
+ SUBCASE("Bool Array/Object Uniformity")
+ {
+ Writer.BeginArray();
+ Writer.AddBool(false);
+ Writer.AddBool(false);
+ Writer.AddBool(false);
+ Writer.EndArray();
+
+ Writer.BeginObject();
+ Writer.SetName("B1"sv).AddBool(false);
+ Writer.SetName("B2"sv).AddBool(false);
+ Writer.SetName("B3"sv).AddBool(false);
+ Writer.EndObject();
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ }
+}
+
+TEST_CASE("usonbuilder.usonattachment")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+}
+
+TEST_CASE("usonbuilder.binaryattachment")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+}
+
+TEST_CASE("usonbuilder.hash")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+}
+
+TEST_CASE("usonbuilder.uuid")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+}
+
+TEST_CASE("usonbuilder.datetime")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ const DateTime Values[] = {DateTime(0), DateTime(2020, 5, 13, 15, 10)};
+ for (DateTime Value : Values)
+ {
+ Writer.AddDateTime(Value);
+ }
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ const DateTime* CheckValue = Values;
+ for (CbFieldView Field : Fields)
+ {
+ CHECK(Field.AsDateTime() == *CheckValue++);
+ CHECK(Field.HasError() == false);
+ }
+}
+
+TEST_CASE("usonbuilder.timespan")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ const TimeSpan Values[] = {TimeSpan(0), TimeSpan(1, 2, 4, 8)};
+ for (TimeSpan Value : Values)
+ {
+ Writer.AddTimeSpan(Value);
+ }
+
+ CbFieldIterator Fields = Writer.Save();
+
+ CHECK(ValidateCompactBinary(Fields.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+
+ const TimeSpan* CheckValue = Values;
+ for (CbFieldView Field : Fields)
+ {
+ CHECK(Field.AsTimeSpan() == *CheckValue++);
+ CHECK(Field.HasError() == false);
+ }
+}
+
+TEST_CASE("usonbuilder.complex")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("complex")
+ {
+ CbObject Object;
+
+ {
+ Writer.BeginObject();
+
+ const uint8_t LocalField[] = {uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName), 1, 'I', 42};
+ Writer.AddField("FieldCopy"sv, CbFieldView(LocalField));
+ Writer.AddField("FieldRefCopy"sv, CbField(SharedBuffer::Clone(MakeMemoryView(LocalField))));
+
+ const uint8_t LocalObject[] = {uint8_t(CbFieldType::Object | CbFieldType::HasFieldName),
+ 1,
+ 'O',
+ 7,
+ uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName),
+ 1,
+ 'I',
+ 42,
+ uint8_t(CbFieldType::Null | CbFieldType::HasFieldName),
+ 1,
+ 'N'};
+ Writer.AddObject("ObjectCopy"sv, CbObjectView(LocalObject));
+ Writer.AddObject("ObjectRefCopy"sv, CbObject(SharedBuffer::Clone(MakeMemoryView(LocalObject))));
+
+ const uint8_t LocalArray[] = {uint8_t(CbFieldType::UniformArray | CbFieldType::HasFieldName),
+ 1,
+ 'A',
+ 4,
+ 2,
+ uint8_t(CbFieldType::IntegerPositive),
+ 42,
+ 21};
+ Writer.AddArray("ArrayCopy"sv, CbArrayView(LocalArray));
+ Writer.AddArray("ArrayRefCopy"sv, CbArray(SharedBuffer::Clone(MakeMemoryView(LocalArray))));
+
+ Writer.AddNull("Null"sv);
+
+ Writer.BeginObject("Binary"sv);
+ {
+ Writer.AddBinary("Empty"sv, MemoryView());
+ Writer.AddBinary("Value"sv, MakeMemoryView("BinaryValue"));
+ Writer.AddBinary("LargeValue"sv, MakeMemoryView(std::wstring().append(256, L'.')));
+ Writer.AddBinary("LargeRefValue"sv, SharedBuffer::Clone(MakeMemoryView(std::wstring().append(256, L'!'))));
+ }
+ Writer.EndObject();
+
+ Writer.BeginObject("Strings"sv);
+ {
+ Writer.AddString("AnsiString"sv, "AnsiValue"sv);
+ Writer.AddString("WideString"sv, std::wstring().append(256, L'.'));
+ Writer.AddString("EmptyAnsiString"sv, std::string_view());
+ Writer.AddString("EmptyWideString"sv, std::wstring_view());
+ Writer.AddString("AnsiStringLiteral", "AnsiValue");
+ Writer.AddString("WideStringLiteral", L"AnsiValue");
+ }
+ Writer.EndObject();
+
+ Writer.BeginArray("Integers"sv);
+ {
+ Writer.AddInteger(int32_t(-1));
+ Writer.AddInteger(int64_t(-1));
+ Writer.AddInteger(uint32_t(1));
+ Writer.AddInteger(uint64_t(1));
+ Writer.AddInteger(std::numeric_limits<int32_t>::min());
+ Writer.AddInteger(std::numeric_limits<int32_t>::max());
+ Writer.AddInteger(std::numeric_limits<uint32_t>::max());
+ Writer.AddInteger(std::numeric_limits<int64_t>::min());
+ Writer.AddInteger(std::numeric_limits<int64_t>::max());
+ Writer.AddInteger(std::numeric_limits<uint64_t>::max());
+ }
+ Writer.EndArray();
+
+ Writer.BeginArray("UniformIntegers"sv);
+ {
+ Writer.AddInteger(0);
+ Writer.AddInteger(std::numeric_limits<int32_t>::max());
+ Writer.AddInteger(std::numeric_limits<uint32_t>::max());
+ Writer.AddInteger(std::numeric_limits<int64_t>::max());
+ Writer.AddInteger(std::numeric_limits<uint64_t>::max());
+ }
+ Writer.EndArray();
+
+ Writer.AddFloat("Float32"sv, 1.0f);
+ Writer.AddFloat("Float64as32"sv, 2.0);
+ Writer.AddFloat("Float64"sv, 3.0e100);
+
+ Writer.AddBool("False"sv, false);
+ Writer.AddBool("True"sv, true);
+
+ Writer.AddObjectAttachment("ObjectAttachment"sv, IoHash());
+ Writer.AddBinaryAttachment("BinaryAttachment"sv, IoHash());
+ Writer.AddAttachment("Attachment"sv, CbAttachment());
+
+ Writer.AddHash("Hash"sv, IoHash());
+ Writer.AddUuid("Uuid"sv, Guid());
+
+ Writer.AddDateTimeTicks("DateTimeZero"sv, 0);
+ Writer.AddDateTime("DateTime2020"sv, DateTime(2020, 5, 13, 15, 10));
+
+ Writer.AddTimeSpanTicks("TimeSpanZero"sv, 0);
+ Writer.AddTimeSpan("TimeSpan"sv, TimeSpan(1, 2, 4, 8));
+
+ Writer.BeginObject("NestedObjects"sv);
+ {
+ Writer.BeginObject("Empty"sv);
+ Writer.EndObject();
+
+ Writer.BeginObject("Null"sv);
+ Writer.AddNull("Null"sv);
+ Writer.EndObject();
+ }
+ Writer.EndObject();
+
+ Writer.BeginArray("NestedArrays"sv);
+ {
+ Writer.BeginArray();
+ Writer.EndArray();
+
+ Writer.BeginArray();
+ Writer.AddNull();
+ Writer.AddNull();
+ Writer.AddNull();
+ Writer.EndArray();
+
+ Writer.BeginArray();
+ Writer.AddBool(false);
+ Writer.AddBool(false);
+ Writer.AddBool(false);
+ Writer.EndArray();
+
+ Writer.BeginArray();
+ Writer.AddBool(true);
+ Writer.AddBool(true);
+ Writer.AddBool(true);
+ Writer.EndArray();
+ }
+ Writer.EndArray();
+
+ Writer.BeginArray("ArrayOfObjects"sv);
+ {
+ Writer.BeginObject();
+ Writer.EndObject();
+
+ Writer.BeginObject();
+ Writer.AddNull("Null"sv);
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ Writer.BeginArray("LargeArray"sv);
+ for (int Index = 0; Index < 256; ++Index)
+ {
+ Writer.AddInteger(Index - 128);
+ }
+ Writer.EndArray();
+
+ Writer.BeginArray("LargeUniformArray"sv);
+ for (int Index = 0; Index < 256; ++Index)
+ {
+ Writer.AddInteger(Index);
+ }
+ Writer.EndArray();
+
+ Writer.BeginArray("NestedUniformArray"sv);
+ for (int Index = 0; Index < 16; ++Index)
+ {
+ Writer.BeginArray();
+ for (int Value = 0; Value < 4; ++Value)
+ {
+ Writer.AddInteger(Value);
+ }
+ Writer.EndArray();
+ }
+ Writer.EndArray();
+
+ Writer.EndObject();
+ Object = Writer.Save().AsObject();
+ }
+ CHECK(ValidateCompactBinary(Object.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ }
+}
+
+TEST_CASE("usonbuilder.stream")
+{
+ using namespace std::literals;
+
+ FixedCbWriter<256> Writer;
+
+ SUBCASE("basic")
+ {
+ CbObject Object;
+ {
+ Writer.BeginObject();
+
+ const uint8_t LocalField[] = {uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName), 1, 'I', 42};
+ Writer << "FieldCopy"sv << CbFieldView(LocalField);
+
+ const uint8_t LocalObject[] = {uint8_t(CbFieldType::Object | CbFieldType::HasFieldName),
+ 1,
+ 'O',
+ 7,
+ uint8_t(CbFieldType::IntegerPositive | CbFieldType::HasFieldName),
+ 1,
+ 'I',
+ 42,
+ uint8_t(CbFieldType::Null | CbFieldType::HasFieldName),
+ 1,
+ 'N'};
+ Writer << "ObjectCopy"sv << CbObjectView(LocalObject);
+
+ const uint8_t LocalArray[] = {uint8_t(CbFieldType::UniformArray | CbFieldType::HasFieldName),
+ 1,
+ 'A',
+ 4,
+ 2,
+ uint8_t(CbFieldType::IntegerPositive),
+ 42,
+ 21};
+ Writer << "ArrayCopy"sv << CbArrayView(LocalArray);
+
+ Writer << "Null"sv << nullptr;
+
+ Writer << "Strings"sv;
+ Writer.BeginObject();
+ Writer << "AnsiString"sv
+ << "AnsiValue"sv
+ << "AnsiStringLiteral"sv
+ << "AnsiValue"
+ << "WideString"sv << L"WideValue"sv << "WideStringLiteral"sv << L"WideValue";
+ Writer.EndObject();
+
+ Writer << "Integers"sv;
+ Writer.BeginArray();
+ Writer << int32_t(-1) << int64_t(-1) << uint32_t(1) << uint64_t(1);
+ Writer.EndArray();
+
+ Writer << "Float32"sv << 1.0f;
+ Writer << "Float64"sv << 2.0;
+
+ Writer << "False"sv << false << "True"sv << true;
+
+ Writer << "Attachment"sv << CbAttachment();
+
+ Writer << "Hash"sv << IoHash();
+ Writer << "Uuid"sv << Guid();
+
+ Writer << "DateTime"sv << DateTime(2020, 5, 13, 15, 10);
+ Writer << "TimeSpan"sv << TimeSpan(1, 2, 4, 8);
+
+ Writer << "LiteralName" << nullptr;
+
+ Writer.EndObject();
+ Object = Writer.Save().AsObject();
+ }
+
+ CHECK(ValidateCompactBinary(Object.GetBuffer(), CbValidateMode::All) == CbValidateError::None);
+ }
+}
+#endif
+
+} // namespace zen
diff --git a/src/zencore/compactbinarypackage.cpp b/src/zencore/compactbinarypackage.cpp
new file mode 100644
index 000000000..a4fa38a1d
--- /dev/null
+++ b/src/zencore/compactbinarypackage.cpp
@@ -0,0 +1,1350 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/compactbinarypackage.h"
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/endian.h>
+#include <zencore/stream.h>
+#include <zencore/testing.h>
+
+namespace zen {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+CbAttachment::CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash) : CbAttachment(InValue.MakeOwned(), Hash)
+{
+}
+
+CbAttachment::CbAttachment(const SharedBuffer& InValue) : CbAttachment(CompositeBuffer(InValue))
+{
+}
+
+CbAttachment::CbAttachment(const SharedBuffer& InValue, const IoHash& InHash) : CbAttachment(CompositeBuffer(InValue), InHash)
+{
+}
+
+CbAttachment::CbAttachment(const CompositeBuffer& InValue)
+: Hash(InValue.IsNull() ? IoHash::Zero : IoHash::HashBuffer(InValue))
+, Value(InValue)
+{
+ if (std::get<CompositeBuffer>(Value).IsNull())
+ {
+ Value.emplace<std::nullptr_t>();
+ }
+}
+
+CbAttachment::CbAttachment(CompositeBuffer&& InValue)
+: Hash(InValue.IsNull() ? IoHash::Zero : IoHash::HashBuffer(InValue))
+, Value(std::move(InValue))
+
+{
+ if (std::get<CompositeBuffer>(Value).IsNull())
+ {
+ Value.emplace<std::nullptr_t>();
+ }
+}
+
+CbAttachment::CbAttachment(CompositeBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(InValue)
+{
+ if (std::get<CompositeBuffer>(Value).IsNull())
+ {
+ Value.emplace<std::nullptr_t>();
+ }
+}
+
+CbAttachment::CbAttachment(CompressedBuffer&& InValue, const IoHash& InHash) : Hash(InHash), Value(InValue)
+{
+ if (std::get<CompressedBuffer>(Value).IsNull())
+ {
+ Value.emplace<std::nullptr_t>();
+ }
+}
+
+CbAttachment::CbAttachment(const CbObject& InValue, const IoHash* const InHash)
+{
+ auto SetValue = [&](const CbObject& ValueToSet) {
+ if (InHash)
+ {
+ Value.emplace<CbObject>(ValueToSet);
+ Hash = *InHash;
+ }
+ else
+ {
+ Value.emplace<CbObject>(ValueToSet);
+ Hash = ValueToSet.GetHash();
+ }
+ };
+
+ MemoryView View;
+ if (!InValue.IsOwned() || !InValue.TryGetSerializedView(View))
+ {
+ SetValue(CbObject::Clone(InValue));
+ }
+ else
+ {
+ SetValue(InValue);
+ }
+}
+
+bool
+CbAttachment::TryLoad(IoBuffer& InBuffer, BufferAllocator Allocator)
+{
+ BinaryReader Reader(InBuffer.Data(), InBuffer.Size());
+
+ return TryLoad(Reader, Allocator);
+}
+
+bool
+CbAttachment::TryLoad(CbFieldIterator& Fields)
+{
+ if (const CbObjectView ObjectView = Fields.AsObjectView(); !Fields.HasError())
+ {
+ // Is a null object or object not prefixed with a precomputed hash value
+ Value.emplace<CbObject>(CbObject(ObjectView, Fields.GetOuterBuffer()));
+ Hash = ObjectView.GetHash();
+ ++Fields;
+ }
+ else if (const IoHash ObjectAttachmentHash = Fields.AsObjectAttachment(); !Fields.HasError())
+ {
+ // Is an object
+ ++Fields;
+ const CbObjectView InnerObjectView = Fields.AsObjectView();
+ if (Fields.HasError())
+ {
+ return false;
+ }
+ Value.emplace<CbObject>(CbObject(InnerObjectView, Fields.GetOuterBuffer()));
+ Hash = ObjectAttachmentHash;
+ ++Fields;
+ }
+ else if (const IoHash BinaryAttachmentHash = Fields.AsBinaryAttachment(); !Fields.HasError())
+ {
+ // Is an uncompressed binary blob
+ ++Fields;
+ MemoryView BinaryView = Fields.AsBinaryView();
+ if (Fields.HasError())
+ {
+ return false;
+ }
+ Value.emplace<CompositeBuffer>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()));
+ Hash = BinaryAttachmentHash;
+ ++Fields;
+ }
+ else if (MemoryView BinaryView = Fields.AsBinaryView(); !Fields.HasError())
+ {
+ if (BinaryView.GetSize() > 0)
+ {
+ // Is a compressed binary blob
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed =
+ CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()), RawHash, RawSize).MakeOwned();
+ Value.emplace<CompressedBuffer>(Compressed);
+ Hash = RawHash;
+ ++Fields;
+ }
+ else
+ {
+ // Is an uncompressed empty binary blob
+ Value.emplace<CompositeBuffer>(SharedBuffer::MakeView(BinaryView, Fields.GetOuterBuffer()));
+ Hash = IoHash::HashBuffer(nullptr, 0);
+ ++Fields;
+ }
+ }
+ else
+ {
+ return false;
+ }
+
+ return true;
+}
+
+static bool
+TryLoad_ArchiveFieldIntoAttachment(CbAttachment& TargetAttachment, CbField&& Field, BinaryReader& Reader, BufferAllocator Allocator)
+{
+ if (const CbObjectView ObjectView = Field.AsObjectView(); !Field.HasError())
+ {
+ // Is a null object or object not prefixed with a precomputed hash value
+ TargetAttachment = CbAttachment(CbObject(ObjectView, std::move(Field)), ObjectView.GetHash());
+ }
+ else if (const IoHash ObjectAttachmentHash = Field.AsObjectAttachment(); !Field.HasError())
+ {
+ // Is an object
+ Field = LoadCompactBinary(Reader, Allocator);
+ if (!Field.IsObject())
+ {
+ return false;
+ }
+ TargetAttachment = CbAttachment(std::move(Field).AsObject(), ObjectAttachmentHash);
+ }
+ else if (const IoHash BinaryAttachmentHash = Field.AsBinaryAttachment(); !Field.HasError())
+ {
+ // Is an uncompressed binary blob
+ Field = LoadCompactBinary(Reader, Allocator);
+ SharedBuffer Buffer = Field.AsBinary();
+ if (Field.HasError())
+ {
+ return false;
+ }
+ TargetAttachment = CbAttachment(CompositeBuffer(Buffer), BinaryAttachmentHash);
+ }
+ else if (SharedBuffer Buffer = Field.AsBinary(); !Field.HasError())
+ {
+ if (Buffer.GetSize() > 0)
+ {
+ // Is a compressed binary blob
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(std::move(Buffer), RawHash, RawSize);
+ TargetAttachment = CbAttachment(Compressed, RawHash);
+ }
+ else
+ {
+ // Is an uncompressed empty binary blob
+ TargetAttachment = CbAttachment(CompositeBuffer(Buffer), IoHash::HashBuffer(nullptr, 0));
+ }
+ }
+ else
+ {
+ return false;
+ }
+
+ return true;
+}
+
+bool
+CbAttachment::TryLoad(BinaryReader& Reader, BufferAllocator Allocator)
+{
+ CbField Field = LoadCompactBinary(Reader, Allocator);
+ return TryLoad_ArchiveFieldIntoAttachment(*this, std::move(Field), Reader, Allocator);
+}
+
+void
+CbAttachment::Save(CbWriter& Writer) const
+{
+ if (const CbObject* Object = std::get_if<CbObject>(&Value))
+ {
+ if (*Object)
+ {
+ Writer.AddObjectAttachment(Hash);
+ }
+ Writer.AddObject(*Object);
+ }
+ else if (const CompositeBuffer* Binary = std::get_if<CompositeBuffer>(&Value))
+ {
+ if (Binary->GetSize() > 0)
+ {
+ Writer.AddBinaryAttachment(Hash);
+ }
+ Writer.AddBinary(*Binary);
+ }
+ else if (const CompressedBuffer* Compressed = std::get_if<CompressedBuffer>(&Value))
+ {
+ Writer.AddBinary(Compressed->GetCompressed());
+ }
+}
+
+void
+CbAttachment::Save(BinaryWriter& Writer) const
+{
+ CbWriter TempWriter;
+ Save(TempWriter);
+ TempWriter.Save(Writer);
+}
+
+bool
+CbAttachment::IsNull() const
+{
+ return std::holds_alternative<std::nullptr_t>(Value);
+}
+
+bool
+CbAttachment::IsBinary() const
+{
+ return std::holds_alternative<CompositeBuffer>(Value);
+}
+
+bool
+CbAttachment::IsCompressedBinary() const
+{
+ return std::holds_alternative<CompressedBuffer>(Value);
+}
+
+bool
+CbAttachment::IsObject() const
+{
+ return std::holds_alternative<CbObject>(Value);
+}
+
+IoHash
+CbAttachment::GetHash() const
+{
+ return Hash;
+}
+
+CompositeBuffer
+CbAttachment::AsCompositeBinary() const
+{
+ if (const CompositeBuffer* BinValue = std::get_if<CompositeBuffer>(&Value))
+ {
+ return *BinValue;
+ }
+
+ return CompositeBuffer::Null;
+}
+
+SharedBuffer
+CbAttachment::AsBinary() const
+{
+ if (const CompositeBuffer* BinValue = std::get_if<CompositeBuffer>(&Value))
+ {
+ return BinValue->Flatten();
+ }
+
+ return {};
+}
+
+CompressedBuffer
+CbAttachment::AsCompressedBinary() const
+{
+ if (const CompressedBuffer* CompValue = std::get_if<CompressedBuffer>(&Value))
+ {
+ return *CompValue;
+ }
+
+ return CompressedBuffer::Null;
+}
+
+/** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */
+CbObject
+CbAttachment::AsObject() const
+{
+ if (const CbObject* ObjectValue = std::get_if<CbObject>(&Value))
+ {
+ return *ObjectValue;
+ }
+
+ return {};
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void
+CbPackage::SetObject(CbObject InObject, const IoHash* InObjectHash, AttachmentResolver* InResolver)
+{
+ if (InObject)
+ {
+ Object = InObject.IsOwned() ? std::move(InObject) : CbObject::Clone(InObject);
+ if (InObjectHash)
+ {
+ ObjectHash = *InObjectHash;
+ ZEN_ASSERT_SLOW(ObjectHash == Object.GetHash());
+ }
+ else
+ {
+ ObjectHash = Object.GetHash();
+ }
+ if (InResolver)
+ {
+ GatherAttachments(Object, *InResolver);
+ }
+ }
+ else
+ {
+ Object.Reset();
+ ObjectHash = IoHash::Zero;
+ }
+}
+
+void
+CbPackage::AddAttachment(const CbAttachment& Attachment, AttachmentResolver* Resolver)
+{
+ if (!Attachment.IsNull())
+ {
+ auto It = std::lower_bound(begin(Attachments), end(Attachments), Attachment);
+ if (It != Attachments.end() && *It == Attachment)
+ {
+ CbAttachment& Existing = *It;
+ Existing = Attachment;
+ }
+ else
+ {
+ Attachments.insert(It, Attachment);
+ }
+
+ if (Attachment.IsObject() && Resolver)
+ {
+ GatherAttachments(Attachment.AsObject(), *Resolver);
+ }
+ }
+}
+
+void
+CbPackage::AddAttachments(std::span<const CbAttachment> InAttachments)
+{
+ if (InAttachments.empty())
+ {
+ return;
+ }
+ // Assume we have no duplicates!
+ Attachments.insert(Attachments.end(), InAttachments.begin(), InAttachments.end());
+ std::sort(Attachments.begin(), Attachments.end());
+ ZEN_ASSERT_SLOW(std::unique(Attachments.begin(), Attachments.end()) == Attachments.end());
+}
+
+int32_t
+CbPackage::RemoveAttachment(const IoHash& Hash)
+{
+ return gsl::narrow_cast<int32_t>(
+ std::erase_if(Attachments, [&Hash](const CbAttachment& Attachment) -> bool { return Attachment.GetHash() == Hash; }));
+}
+
+bool
+CbPackage::Equals(const CbPackage& Package) const
+{
+ return ObjectHash == Package.ObjectHash && Attachments == Package.Attachments;
+}
+
+const CbAttachment*
+CbPackage::FindAttachment(const IoHash& Hash) const
+{
+ auto It = std::find_if(begin(Attachments), end(Attachments), [&Hash](const CbAttachment& Attachment) -> bool {
+ return Attachment.GetHash() == Hash;
+ });
+
+ if (It == end(Attachments))
+ return nullptr;
+
+ return &*It;
+}
+
+void
+CbPackage::GatherAttachments(const CbObject& Value, AttachmentResolver Resolver)
+{
+ Value.IterateAttachments([this, &Resolver](CbFieldView Field) {
+ const IoHash& Hash = Field.AsAttachment();
+
+ if (SharedBuffer Buffer = Resolver(Hash))
+ {
+ if (Field.IsObjectAttachment())
+ {
+ AddAttachment(CbAttachment(CbObject(std::move(Buffer)), Hash), &Resolver);
+ }
+ else
+ {
+ AddAttachment(CbAttachment(std::move(Buffer)));
+ }
+ }
+ });
+}
+
+bool
+CbPackage::TryLoad(IoBuffer InBuffer, BufferAllocator Allocator, AttachmentResolver* Mapper)
+{
+ BinaryReader Reader(InBuffer.Data(), InBuffer.Size());
+
+ return TryLoad(Reader, Allocator, Mapper);
+}
+
+bool
+CbPackage::TryLoad(CbFieldIterator& Fields)
+{
+ *this = CbPackage();
+
+ while (Fields)
+ {
+ if (Fields.IsNull())
+ {
+ ++Fields;
+ break;
+ }
+ else if (IoHash Hash = Fields.AsHash(); !Fields.HasError() && !Fields.IsAttachment())
+ {
+ ++Fields;
+ CbObjectView ObjectView = Fields.AsObjectView();
+ if (Fields.HasError() || Hash != ObjectView.GetHash())
+ {
+ return false;
+ }
+ Object = CbObject(ObjectView, Fields.GetOuterBuffer());
+ Object.MakeOwned();
+ ObjectHash = Hash;
+ ++Fields;
+ }
+ else
+ {
+ CbAttachment Attachment;
+ if (!Attachment.TryLoad(Fields))
+ {
+ return false;
+ }
+ AddAttachment(Attachment);
+ }
+ }
+ return true;
+}
+
+bool
+CbPackage::TryLoad(BinaryReader& Reader, BufferAllocator Allocator, AttachmentResolver* Mapper)
+{
+ // TODO: this needs to re-grow the ability to accept a reference to an attachment which is
+ // not embedded
+
+ ZEN_UNUSED(Mapper);
+
+#if 1
+ *this = CbPackage();
+ for (;;)
+ {
+ CbField Field = LoadCompactBinary(Reader, Allocator);
+ if (!Field)
+ {
+ return false;
+ }
+
+ if (Field.IsNull())
+ {
+ return true;
+ }
+ else if (IoHash Hash = Field.AsHash(); !Field.HasError() && !Field.IsAttachment())
+ {
+ Field = LoadCompactBinary(Reader, Allocator);
+ CbObjectView ObjectView = Field.AsObjectView();
+ if (Field.HasError() || Hash != ObjectView.GetHash())
+ {
+ return false;
+ }
+ Object = CbObject(ObjectView, Field.GetOuterBuffer());
+ ObjectHash = Hash;
+ }
+ else
+ {
+ CbAttachment Attachment;
+ if (!TryLoad_ArchiveFieldIntoAttachment(Attachment, std::move(Field), Reader, Allocator))
+ {
+ return false;
+ }
+ AddAttachment(Attachment);
+ }
+ }
+#else
+ uint8_t StackBuffer[64];
+ const auto StackAllocator = [&Allocator, &StackBuffer](uint64_t Size) -> UniqueBuffer {
+ if (Size <= sizeof(StackBuffer))
+ {
+ return UniqueBuffer::MakeMutableView(StackBuffer, Size);
+ }
+
+ return Allocator(Size);
+ };
+
+ *this = CbPackage();
+
+ for (;;)
+ {
+ CbField ValueField = LoadCompactBinary(Reader, StackAllocator);
+ if (!ValueField)
+ {
+ return false;
+ }
+ if (ValueField.IsNull())
+ {
+ return true;
+ }
+ else if (ValueField.IsBinary())
+ {
+ const MemoryView View = ValueField.AsBinaryView();
+ if (View.GetSize() > 0)
+ {
+ SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned();
+ CbField HashField = LoadCompactBinary(Reader, StackAllocator);
+ const IoHash& Hash = HashField.AsAttachment();
+ ZEN_ASSERT(!HashField.HasError(), "Attachments must be a non-empty binary value with a content hash.");
+ if (HashField.IsObjectAttachment())
+ {
+ AddAttachment(CbAttachment(CbObject(std::move(Buffer)), Hash));
+ }
+ else
+ {
+ AddAttachment(CbAttachment(std::move(Buffer), Hash));
+ }
+ }
+ }
+ else if (ValueField.IsHash())
+ {
+ const IoHash Hash = ValueField.AsHash();
+
+ ZEN_ASSERT(Mapper);
+
+ AddAttachment(CbAttachment((*Mapper)(Hash), Hash));
+ }
+ else
+ {
+ Object = ValueField.AsObject();
+ if (ValueField.HasError())
+ {
+ return false;
+ }
+ Object.MakeOwned();
+ if (Object)
+ {
+ CbField HashField = LoadCompactBinary(Reader, StackAllocator);
+ ObjectHash = HashField.AsObjectAttachment();
+ if (HashField.HasError() || Object.GetHash() != ObjectHash)
+ {
+ return false;
+ }
+ }
+ else
+ {
+ Object.Reset();
+ }
+ }
+ }
+#endif
+}
+
+void
+CbPackage::Save(CbWriter& Writer) const
+{
+ if (Object)
+ {
+ Writer.AddHash(ObjectHash);
+ Writer.AddObject(Object);
+ }
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Attachment.Save(Writer);
+ }
+ Writer.AddNull();
+}
+
+void
+CbPackage::Save(BinaryWriter& StreamWriter) const
+{
+ CbWriter Writer;
+ Save(Writer);
+ Writer.Save(StreamWriter);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Legacy package serialization support
+//
+
+namespace legacy {
+
+ void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer)
+ {
+ if (Attachment.IsObject())
+ {
+ CbObject Object = Attachment.AsObject();
+ Writer.AddBinary(Object.GetBuffer());
+ if (Object)
+ {
+ Writer.AddObjectAttachment(Attachment.GetHash());
+ }
+ }
+ else if (Attachment.IsBinary())
+ {
+ Writer.AddBinary(Attachment.AsBinary());
+ Writer.AddBinaryAttachment(Attachment.GetHash());
+ }
+ else if (Attachment.IsCompressedBinary())
+ {
+ Writer.AddBinary(Attachment.AsCompressedBinary().GetCompressed());
+ Writer.AddBinaryAttachment(Attachment.GetHash());
+ }
+ else if (Attachment.IsNull())
+ {
+ Writer.AddBinary(MemoryView());
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Compressed binary is not supported in this serialization format");
+ }
+ }
+
+ void SaveCbPackage(const CbPackage& Package, CbWriter& Writer)
+ {
+ if (const CbObject& RootObject = Package.GetObject())
+ {
+ Writer.AddObject(RootObject);
+ Writer.AddObjectAttachment(Package.GetObjectHash());
+ }
+ for (const CbAttachment& Attachment : Package.GetAttachments())
+ {
+ SaveCbAttachment(Attachment, Writer);
+ }
+ Writer.AddNull();
+ }
+
+ void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar)
+ {
+ CbWriter Writer;
+ SaveCbPackage(Package, Writer);
+ Writer.Save(Ar);
+ }
+
+ bool TryLoadCbPackage(CbPackage& Package, IoBuffer InBuffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper)
+ {
+ BinaryReader Reader(InBuffer.Data(), InBuffer.Size());
+
+ return TryLoadCbPackage(Package, Reader, Allocator, Mapper);
+ }
+
+ bool TryLoadCbPackage(CbPackage& Package, BinaryReader& Reader, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper)
+ {
+ Package = CbPackage();
+ for (;;)
+ {
+ CbField ValueField = LoadCompactBinary(Reader, Allocator);
+ if (!ValueField)
+ {
+ return false;
+ }
+ if (ValueField.IsNull())
+ {
+ return true;
+ }
+ if (ValueField.IsBinary())
+ {
+ const MemoryView View = ValueField.AsBinaryView();
+ if (View.GetSize() > 0)
+ {
+ SharedBuffer Buffer = SharedBuffer::MakeView(View, ValueField.GetOuterBuffer()).MakeOwned();
+ CbField HashField = LoadCompactBinary(Reader, Allocator);
+ const IoHash& Hash = HashField.AsAttachment();
+ if (HashField.HasError())
+ {
+ return false;
+ }
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(Buffer, RawHash, RawSize))
+ {
+ if (RawHash != Hash)
+ {
+ return false;
+ }
+ Package.AddAttachment(CbAttachment(Compressed, Hash));
+ }
+ else
+ {
+ if (IoHash::HashBuffer(Buffer) != Hash)
+ {
+ return false;
+ }
+ if (HashField.IsObjectAttachment())
+ {
+ Package.AddAttachment(CbAttachment(CbObject(std::move(Buffer)), Hash));
+ }
+ else
+ {
+ Package.AddAttachment(CbAttachment(CompositeBuffer(std::move(Buffer)), Hash));
+ }
+ }
+ }
+ }
+ else if (ValueField.IsHash())
+ {
+ const IoHash Hash = ValueField.AsHash();
+
+ ZEN_ASSERT(Mapper);
+ if (SharedBuffer AttachmentData = (*Mapper)(Hash))
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Compressed = CompressedBuffer::FromCompressed(AttachmentData, RawHash, RawSize))
+ {
+ if (RawHash != Hash)
+ {
+ return false;
+ }
+ Package.AddAttachment(CbAttachment(Compressed, Hash));
+ }
+ else
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(AttachmentData.GetView(), CbValidateMode::All);
+ if (ValidationResult != CbValidateError::None)
+ {
+ return false;
+ }
+ Package.AddAttachment(CbAttachment(CbObject(std::move(AttachmentData)), Hash));
+ }
+ }
+ }
+ else
+ {
+ CbObject Object = ValueField.AsObject();
+ if (ValueField.HasError())
+ {
+ return false;
+ }
+
+ if (Object)
+ {
+ CbField HashField = LoadCompactBinary(Reader, Allocator);
+ IoHash ObjectHash = HashField.AsObjectAttachment();
+ if (HashField.HasError() || Object.GetHash() != ObjectHash)
+ {
+ return false;
+ }
+ Package.SetObject(Object, ObjectHash);
+ }
+ }
+ }
+ }
+
+} // namespace legacy
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+void
+usonpackage_forcelink()
+{
+}
+
+TEST_CASE("usonpackage")
+{
+ using namespace std::literals;
+
+ const auto TestSaveLoadValidate = [&](const char* Test, const CbAttachment& Attachment) {
+ ZEN_UNUSED(Test);
+
+ CbWriter Writer;
+ Attachment.Save(Writer);
+ CbFieldIterator Fields = Writer.Save();
+
+ BinaryWriter StreamWriter;
+ Attachment.Save(StreamWriter);
+
+ CHECK(MakeMemoryView(StreamWriter).EqualBytes(Fields.GetRangeBuffer().GetView()));
+ CHECK(ValidateCompactBinaryRange(MakeMemoryView(StreamWriter), CbValidateMode::All) == CbValidateError::None);
+ CHECK(ValidateObjectAttachment(MakeMemoryView(StreamWriter), CbValidateMode::All) == CbValidateError::None);
+
+ CbAttachment FromFields;
+ FromFields.TryLoad(Fields);
+ CHECK(!bool(Fields));
+ CHECK(FromFields == Attachment);
+
+ CbAttachment FromArchive;
+ BinaryReader Reader(MakeMemoryView(StreamWriter));
+ FromArchive.TryLoad(Reader);
+ CHECK(Reader.CurrentOffset() == Reader.Size());
+ CHECK(FromArchive == Attachment);
+ };
+
+ SUBCASE("Empty Attachment")
+ {
+ CbAttachment Attachment;
+ CHECK(Attachment.IsNull());
+ CHECK_FALSE(bool(Attachment));
+ CHECK_FALSE(bool(Attachment.AsBinary()));
+ CHECK_FALSE(bool(Attachment.AsObject()));
+ CHECK_FALSE(Attachment.IsBinary());
+ CHECK_FALSE(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == IoHash::Zero);
+ }
+
+ SUBCASE("Binary Attachment")
+ {
+ const SharedBuffer Buffer = SharedBuffer::Clone(MakeMemoryView<uint8_t>({0, 1, 2, 3}));
+ CbAttachment Attachment(Buffer);
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK(bool(Attachment));
+ CHECK(Attachment.AsBinary() == Buffer);
+ CHECK_FALSE(bool(Attachment.AsObject()));
+ CHECK(Attachment.IsBinary());
+ CHECK_FALSE(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == IoHash::HashBuffer(Buffer));
+ TestSaveLoadValidate("Binary", Attachment);
+ }
+
+ SUBCASE("Object Attachment")
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer << "Name"sv << 42;
+ Writer.EndObject();
+ CbObject Object = Writer.Save().AsObject();
+ CbAttachment Attachment(Object);
+
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK(bool(Attachment));
+ CHECK(Attachment.AsBinary() == SharedBuffer());
+ CHECK(Attachment.AsObject().Equals(Object));
+ CHECK_FALSE(Attachment.IsBinary());
+ CHECK(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == Object.GetHash());
+ TestSaveLoadValidate("Object", Attachment);
+ }
+
+ SUBCASE("Binary View")
+ {
+ const uint8_t Value[]{0, 1, 2, 3};
+ SharedBuffer Buffer = SharedBuffer::MakeView(MakeMemoryView(Value));
+ CbAttachment Attachment(Buffer);
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK(bool(Attachment));
+ CHECK(Attachment.AsBinary().GetView().EqualBytes(Buffer.GetView()));
+ CHECK_FALSE(bool(Attachment.AsObject()));
+ CHECK(Attachment.IsBinary());
+ CHECK_FALSE(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == IoHash::HashBuffer(Buffer));
+ }
+
+ SUBCASE("Object View")
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer << "Name"sv << 42;
+ Writer.EndObject();
+ CbObject Object = Writer.Save().AsObject();
+ CbObject ObjectView = CbObject::MakeView(Object);
+ CbAttachment Attachment(ObjectView);
+
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK(bool(Attachment));
+
+ CHECK(Attachment.AsBinary() != ObjectView.GetBuffer());
+ CHECK(Attachment.AsObject().Equals(Object));
+ CHECK_FALSE(Attachment.IsBinary());
+ CHECK(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == IoHash(Object.GetHash()));
+ }
+
+ SUBCASE("Binary Load from View")
+ {
+ const uint8_t Value[]{0, 1, 2, 3};
+ const SharedBuffer Buffer = SharedBuffer::MakeView(MakeMemoryView(Value));
+ CbAttachment Attachment(Buffer);
+
+ CbWriter Writer;
+ Attachment.Save(Writer);
+ CbFieldIterator Fields = Writer.Save();
+ CbFieldIterator FieldsView = CbFieldIterator::MakeRangeView(CbFieldViewIterator(Fields));
+ Attachment.TryLoad(FieldsView);
+
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK(bool(Attachment));
+ CHECK_FALSE(FieldsView.GetRangeBuffer().GetView().Contains(Attachment.AsBinary().GetView()));
+ CHECK(Attachment.AsBinary().GetView().EqualBytes(Buffer.GetView()));
+ CHECK_FALSE(Attachment.AsObject());
+ CHECK(Attachment.IsBinary());
+ CHECK_FALSE(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == IoHash::HashBuffer(MakeMemoryView(Value)));
+ }
+
+ SUBCASE("Object Load from View")
+ {
+ CbWriter ValueWriter;
+ ValueWriter.BeginObject();
+ ValueWriter << "Name"sv << 42;
+ ValueWriter.EndObject();
+ const CbObject Value = ValueWriter.Save().AsObject();
+
+ CHECK(ValidateCompactBinaryRange(Value.GetView(), CbValidateMode::All) == CbValidateError::None);
+ CbAttachment Attachment(Value);
+
+ CbWriter Writer;
+ Attachment.Save(Writer);
+ CbFieldIterator Fields = Writer.Save();
+ CbFieldIterator FieldsView = CbFieldIterator::MakeRangeView(CbFieldViewIterator(Fields));
+
+ Attachment.TryLoad(FieldsView);
+ MemoryView View;
+
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK(bool(Attachment));
+ CHECK(Attachment.AsBinary().GetView().EqualBytes(MemoryView()));
+ CHECK_FALSE((!Attachment.AsObject().TryGetSerializedView(View) || FieldsView.GetOuterBuffer().GetView().Contains(View)));
+ CHECK_FALSE(Attachment.IsBinary());
+ CHECK(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == Value.GetHash());
+ }
+
+ SUBCASE("Binary Null")
+ {
+ const CbAttachment Attachment(SharedBuffer{});
+
+ CHECK(Attachment.IsNull());
+ CHECK_FALSE(Attachment.IsBinary());
+ CHECK_FALSE(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == IoHash::Zero);
+ }
+
+ SUBCASE("Binary Empty")
+ {
+ const CbAttachment Attachment(UniqueBuffer::Alloc(0).MoveToShared());
+
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK(Attachment.IsBinary());
+ CHECK_FALSE(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == IoHash::HashBuffer(SharedBuffer{}));
+ }
+
+ SUBCASE("Object Empty")
+ {
+ const CbAttachment Attachment(CbObject{});
+
+ CHECK_FALSE(Attachment.IsNull());
+ CHECK_FALSE(Attachment.IsBinary());
+ CHECK(Attachment.IsObject());
+ CHECK(Attachment.GetHash() == CbObject().GetHash());
+ }
+}
+
+TEST_CASE("usonpackage.serialization")
+{
+ using namespace std::literals;
+
+ const auto TestSaveLoadValidate = [&](const char* Test, CbPackage& InOutPackage) {
+ ZEN_UNUSED(Test);
+
+ CbWriter Writer;
+ InOutPackage.Save(Writer);
+ CbFieldIterator Fields = Writer.Save();
+
+ BinaryWriter MemStream;
+ InOutPackage.Save(MemStream);
+
+ CHECK(MakeMemoryView(MemStream).EqualBytes(Fields.GetRangeBuffer().GetView()));
+ CHECK(ValidateCompactBinaryRange(MakeMemoryView(MemStream), CbValidateMode::All) == CbValidateError::None);
+ CHECK(ValidateCompactBinaryPackage(MakeMemoryView(MemStream), CbValidateMode::All) == CbValidateError::None);
+
+ CbPackage FromFields;
+ FromFields.TryLoad(Fields);
+ CHECK_FALSE(bool(Fields));
+ CHECK(FromFields == InOutPackage);
+
+ CbPackage FromArchive;
+ BinaryReader ReadAr(MakeMemoryView(MemStream));
+ FromArchive.TryLoad(ReadAr);
+ CHECK(ReadAr.CurrentOffset() == ReadAr.Size());
+ CHECK(FromArchive == InOutPackage);
+ InOutPackage = FromArchive;
+ };
+
+ SUBCASE("Empty")
+ {
+ CbPackage Package;
+ CHECK(Package.IsNull());
+ CHECK_FALSE(bool(Package));
+ CHECK(Package.GetAttachments().size() == 0);
+ TestSaveLoadValidate("Empty", Package);
+ }
+
+ SUBCASE("Object Only")
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer << "Field" << 42;
+ Writer.EndObject();
+
+ const CbObject Object = Writer.Save().AsObject();
+ CbPackage Package(Object);
+ CHECK_FALSE(Package.IsNull());
+ CHECK(bool(Package));
+ CHECK(Package.GetAttachments().size() == 0);
+ CHECK(Package.GetObject().GetOuterBuffer() == Object.GetOuterBuffer());
+ CHECK(Package.GetObject()["Field"].AsInt32() == 42);
+ CHECK(Package.GetObjectHash() == Package.GetObject().GetHash());
+ TestSaveLoadValidate("Object", Package);
+ }
+
+ // Object View Only
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer << "Field" << 42;
+ Writer.EndObject();
+
+ const CbObject Object = Writer.Save().AsObject();
+ CbPackage Package(CbObject::MakeView(Object));
+ CHECK_FALSE(Package.IsNull());
+ CHECK(bool(Package));
+ CHECK(Package.GetAttachments().size() == 0);
+ CHECK(Package.GetObject().GetOuterBuffer() != Object.GetOuterBuffer());
+ CHECK(Package.GetObject()["Field"].AsInt32() == 42);
+ CHECK(Package.GetObjectHash() == Package.GetObject().GetHash());
+ TestSaveLoadValidate("Object", Package);
+ }
+
+ // Attachment Only
+ {
+ CbObject Object1;
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer << "Field1" << 42;
+ Writer.EndObject();
+ Object1 = Writer.Save().AsObject();
+ }
+ CbObject Object2;
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer << "Field2" << 42;
+ Writer.EndObject();
+ Object2 = Writer.Save().AsObject();
+ }
+
+ CbPackage Package;
+ Package.AddAttachment(CbAttachment(Object1));
+ Package.AddAttachment(CbAttachment(Object2.GetBuffer()));
+
+ CHECK_FALSE(Package.IsNull());
+ CHECK(bool(Package));
+ CHECK(Package.GetAttachments().size() == 2);
+ CHECK(Package.GetObject().Equals(CbObject()));
+ CHECK(Package.GetObjectHash() == IoHash::Zero);
+ TestSaveLoadValidate("Attachments", Package);
+
+ const CbAttachment* const Object1Attachment = Package.FindAttachment(Object1.GetHash());
+ const CbAttachment* const Object2Attachment = Package.FindAttachment(Object2.GetHash());
+
+ CHECK((Object1Attachment && Object1Attachment->AsObject().Equals(Object1)));
+ CHECK((Object2Attachment && Object2Attachment->AsBinary().GetView().EqualBytes(Object2.GetBuffer().GetView())));
+
+ SharedBuffer Object1ClonedBuffer = SharedBuffer::Clone(Object1.GetOuterBuffer());
+ Package.AddAttachment(CbAttachment(Object1ClonedBuffer));
+ Package.AddAttachment(CbAttachment(CbObject::Clone(Object2)));
+
+ CHECK(Package.GetAttachments().size() == 2);
+ CHECK(Package.FindAttachment(Object1.GetHash()) == Object1Attachment);
+ CHECK(Package.FindAttachment(Object2.GetHash()) == Object2Attachment);
+
+ CHECK((Object1Attachment && Object1Attachment->AsBinary() == Object1ClonedBuffer));
+ CHECK((Object2Attachment && Object2Attachment->AsObject().Equals(Object2)));
+
+ CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments())));
+ }
+
+ // Shared Values
+ const uint8_t Level4Values[]{0, 1, 2, 3};
+ SharedBuffer Level4 = SharedBuffer::MakeView(MakeMemoryView(Level4Values));
+ const IoHash Level4Hash = IoHash::HashBuffer(Level4);
+
+ CbObject Level3;
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer.AddBinaryAttachment("Level4", Level4Hash);
+ Writer.EndObject();
+ Level3 = Writer.Save().AsObject();
+ }
+ const IoHash Level3Hash = Level3.GetHash();
+
+ CbObject Level2;
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer.AddObjectAttachment("Level3", Level3Hash);
+ Writer.EndObject();
+ Level2 = Writer.Save().AsObject();
+ }
+ const IoHash Level2Hash = Level2.GetHash();
+
+ CbObject Level1;
+ {
+ CbWriter Writer;
+ Writer.BeginObject();
+ Writer.AddObjectAttachment("Level2", Level2Hash);
+ Writer.EndObject();
+ Level1 = Writer.Save().AsObject();
+ }
+ const IoHash Level1Hash = Level1.GetHash();
+
+ const auto Resolver = [&Level2, &Level2Hash, &Level3, &Level3Hash, &Level4, &Level4Hash](const IoHash& Hash) -> SharedBuffer {
+ return Hash == Level2Hash ? Level2.GetOuterBuffer()
+ : Hash == Level3Hash ? Level3.GetOuterBuffer()
+ : Hash == Level4Hash ? Level4
+ : SharedBuffer();
+ };
+
+ // Object + Attachments
+ {
+ CbPackage Package;
+ Package.SetObject(Level1, Level1Hash, Resolver);
+
+ CHECK_FALSE(Package.IsNull());
+ CHECK(bool(Package));
+ CHECK(Package.GetAttachments().size() == 3);
+ CHECK(Package.GetObject().GetBuffer() == Level1.GetBuffer());
+ CHECK(Package.GetObjectHash() == Level1Hash);
+ TestSaveLoadValidate("Object+Attachments", Package);
+
+ const CbAttachment* const Level2Attachment = Package.FindAttachment(Level2Hash);
+ const CbAttachment* const Level3Attachment = Package.FindAttachment(Level3Hash);
+ const CbAttachment* const Level4Attachment = Package.FindAttachment(Level4Hash);
+ CHECK((Level2Attachment && Level2Attachment->AsObject().Equals(Level2)));
+ CHECK((Level3Attachment && Level3Attachment->AsObject().Equals(Level3)));
+ REQUIRE(Level4Attachment);
+ CHECK(Level4Attachment->AsBinary() != Level4);
+ CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView()));
+
+ CHECK(std::is_sorted(begin(Package.GetAttachments()), end(Package.GetAttachments())));
+
+ const CbPackage PackageCopy = Package;
+ CHECK(PackageCopy == Package);
+
+ CHECK(Package.RemoveAttachment(Level1Hash) == 0);
+ CHECK(Package.RemoveAttachment(Level2Hash) == 1);
+ CHECK(Package.RemoveAttachment(Level3Hash) == 1);
+ CHECK(Package.RemoveAttachment(Level4Hash) == 1);
+ CHECK(Package.RemoveAttachment(Level4Hash) == 0);
+ CHECK(Package.GetAttachments().size() == 0);
+
+ CHECK(PackageCopy != Package);
+ Package = PackageCopy;
+ CHECK(PackageCopy == Package);
+ Package.SetObject(CbObject());
+ CHECK(PackageCopy != Package);
+ CHECK(Package.GetObjectHash() == IoHash());
+ }
+
+ // Out of Order
+ {
+ CbWriter Writer;
+ CbAttachment Attachment2(Level2, Level2Hash);
+ Attachment2.Save(Writer);
+ CbAttachment Attachment4(Level4);
+ Attachment4.Save(Writer);
+ Writer.AddHash(Level1Hash);
+ Writer.AddObject(Level1);
+ CbAttachment Attachment3(Level3, Level3Hash);
+ Attachment3.Save(Writer);
+ Writer.AddNull();
+
+ CbFieldIterator Fields = Writer.Save();
+ CbPackage FromFields;
+ FromFields.TryLoad(Fields);
+
+ const CbAttachment* const Level2Attachment = FromFields.FindAttachment(Level2Hash);
+ REQUIRE(Level2Attachment);
+ const CbAttachment* const Level3Attachment = FromFields.FindAttachment(Level3Hash);
+ REQUIRE(Level3Attachment);
+ const CbAttachment* const Level4Attachment = FromFields.FindAttachment(Level4Hash);
+ REQUIRE(Level4Attachment);
+
+ CHECK(FromFields.GetObject().Equals(Level1));
+ CHECK(FromFields.GetObject().GetOuterBuffer() == Fields.GetOuterBuffer());
+ CHECK(FromFields.GetObjectHash() == Level1Hash);
+
+ const MemoryView FieldsOuterBufferView = Fields.GetOuterBuffer().GetView();
+
+ CHECK(Level2Attachment->AsObject().Equals(Level2));
+ CHECK(Level2Attachment->GetHash() == Level2Hash);
+
+ CHECK(Level3Attachment->AsObject().Equals(Level3));
+ CHECK(Level3Attachment->GetHash() == Level3Hash);
+
+ CHECK(Level4Attachment->AsBinary().GetView().EqualBytes(Level4.GetView()));
+ CHECK(FieldsOuterBufferView.Contains(Level4Attachment->AsBinary().GetView()));
+ CHECK(Level4Attachment->GetHash() == Level4Hash);
+
+ BinaryWriter WriteStream;
+ Writer.Save(WriteStream);
+ CbPackage FromArchive;
+ BinaryReader ReadAr(MakeMemoryView(WriteStream));
+ FromArchive.TryLoad(ReadAr);
+
+ Writer.Reset();
+ FromArchive.Save(Writer);
+ CbFieldIterator Saved = Writer.Save();
+
+ CHECK(Saved.AsHash() == Level1Hash);
+ ++Saved;
+ CHECK(Saved.AsObject().Equals(Level1));
+ ++Saved;
+ CHECK_EQ(Saved.AsObjectAttachment(), Level2Hash);
+ ++Saved;
+ CHECK(Saved.AsObject().Equals(Level2));
+ ++Saved;
+ CHECK_EQ(Saved.AsObjectAttachment(), Level3Hash);
+ ++Saved;
+ CHECK(Saved.AsObject().Equals(Level3));
+ ++Saved;
+ CHECK_EQ(Saved.AsBinaryAttachment(), Level4Hash);
+ ++Saved;
+ SharedBuffer SavedLevel4Buffer = SharedBuffer::MakeView(Saved.AsBinaryView());
+ CHECK(SavedLevel4Buffer.GetView().EqualBytes(Level4.GetView()));
+ ++Saved;
+ CHECK(Saved.IsNull());
+ ++Saved;
+ CHECK(!Saved);
+ }
+
+ // Null Attachment
+ {
+ CbAttachment NullAttachment;
+ CbPackage Package;
+ Package.AddAttachment(NullAttachment);
+ CHECK(Package.IsNull());
+ CHECK_FALSE(bool(Package));
+ CHECK(Package.GetAttachments().size() == 0);
+ CHECK_FALSE(Package.FindAttachment(NullAttachment));
+ }
+
+ // Resolve After Merge
+ {
+ bool bResolved = false;
+ CbPackage Package;
+ Package.AddAttachment(CbAttachment(Level3.GetBuffer()));
+ Package.AddAttachment(CbAttachment(Level3), [&bResolved](const IoHash& Hash) -> SharedBuffer {
+ ZEN_UNUSED(Hash);
+ bResolved = true;
+ return SharedBuffer();
+ });
+ CHECK(bResolved);
+ }
+}
+
+TEST_CASE("usonpackage.invalidpackage")
+{
+ const auto TestLoad = [](std::initializer_list<uint8_t> RawData, BufferAllocator Allocator = UniqueBuffer::Alloc) {
+ const MemoryView RawView = MakeMemoryView(RawData);
+ CbPackage FromArchive;
+ BinaryReader ReadAr(RawView);
+ CHECK_FALSE(FromArchive.TryLoad(ReadAr, Allocator));
+ };
+ const auto AllocFail = [](uint64_t) -> UniqueBuffer {
+ FAIL_CHECK("Allocation is not expected");
+ return UniqueBuffer();
+ };
+ SUBCASE("Empty") { TestLoad({}, AllocFail); }
+ SUBCASE("Invalid Initial Field")
+ {
+ TestLoad({uint8_t(CbFieldType::None)});
+ TestLoad({uint8_t(CbFieldType::Array)});
+ TestLoad({uint8_t(CbFieldType::UniformArray)});
+ TestLoad({uint8_t(CbFieldType::Binary)});
+ TestLoad({uint8_t(CbFieldType::String)});
+ TestLoad({uint8_t(CbFieldType::IntegerPositive)});
+ TestLoad({uint8_t(CbFieldType::IntegerNegative)});
+ TestLoad({uint8_t(CbFieldType::Float32)});
+ TestLoad({uint8_t(CbFieldType::Float64)});
+ TestLoad({uint8_t(CbFieldType::BoolFalse)});
+ TestLoad({uint8_t(CbFieldType::BoolTrue)});
+ TestLoad({uint8_t(CbFieldType::ObjectAttachment)});
+ TestLoad({uint8_t(CbFieldType::BinaryAttachment)});
+ TestLoad({uint8_t(CbFieldType::Uuid)});
+ TestLoad({uint8_t(CbFieldType::DateTime)});
+ TestLoad({uint8_t(CbFieldType::TimeSpan)});
+ TestLoad({uint8_t(CbFieldType::ObjectId)});
+ TestLoad({uint8_t(CbFieldType::CustomById)});
+ TestLoad({uint8_t(CbFieldType::CustomByName)});
+ }
+ SUBCASE("Size Out Of Bounds")
+ {
+ TestLoad({uint8_t(CbFieldType::Object), 1}, AllocFail);
+ TestLoad({uint8_t(CbFieldType::Object), 0xff, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0}, AllocFail);
+ }
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/compactbinaryvalidation.cpp b/src/zencore/compactbinaryvalidation.cpp
new file mode 100644
index 000000000..02148d96a
--- /dev/null
+++ b/src/zencore/compactbinaryvalidation.cpp
@@ -0,0 +1,664 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/compactbinaryvalidation.h"
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/endian.h>
+#include <zencore/memory.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#include <algorithm>
+
+namespace zen {
+
+namespace CbValidationPrivate {
+
+ template<typename T>
+ static constexpr inline T ReadUnaligned(const void* const Memory)
+ {
+#if ZEN_PLATFORM_SUPPORTS_UNALIGNED_LOADS
+ return *static_cast<const T*>(Memory);
+#else
+ T Value;
+ memcpy(&Value, Memory, sizeof(Value));
+ return Value;
+#endif
+ }
+
+} // namespace CbValidationPrivate
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * Adds the given error(s) to the error mask.
+ *
+ * This function exists to make validation errors easier to debug by providing one location to set a breakpoint.
+ */
+ZEN_NOINLINE static void
+AddError(CbValidateError& OutError, const CbValidateError InError)
+{
+ OutError |= InError;
+}
+
+/**
+ * Validate and read a field type from the view.
+ *
+ * A type argument with the HasFieldType flag indicates that the type will not be read from the view.
+ */
+static CbFieldType
+ValidateCbFieldType(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType Type = CbFieldType::HasFieldType)
+{
+ ZEN_UNUSED(Mode);
+ if (CbFieldTypeOps::HasFieldType(Type))
+ {
+ if (View.GetSize() >= 1)
+ {
+ Type = *static_cast<const CbFieldType*>(View.GetData());
+ View += 1;
+ if (CbFieldTypeOps::HasFieldType(Type))
+ {
+ AddError(Error, CbValidateError::InvalidType);
+ }
+ }
+ else
+ {
+ AddError(Error, CbValidateError::OutOfBounds);
+ View.Reset();
+ return CbFieldType::None;
+ }
+ }
+
+ if (CbFieldTypeOps::GetSerializedType(Type) != Type)
+ {
+ AddError(Error, CbValidateError::InvalidType);
+ View.Reset();
+ }
+
+ return Type;
+}
+
+/**
+ * Validate and read an unsigned integer from the view.
+ *
+ * Modifies the view to start at the end of the value, and adds error flags if applicable.
+ */
+static uint64_t
+ValidateCbUInt(MemoryView& View, CbValidateMode Mode, CbValidateError& Error)
+{
+ if (View.GetSize() > 0 && View.GetSize() >= MeasureVarUInt(View.GetData()))
+ {
+ uint32_t ValueByteCount;
+ const uint64_t Value = ReadVarUInt(View.GetData(), ValueByteCount);
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Format) && ValueByteCount > MeasureVarUInt(Value))
+ {
+ AddError(Error, CbValidateError::InvalidInteger);
+ }
+ View += ValueByteCount;
+ return Value;
+ }
+ else
+ {
+ AddError(Error, CbValidateError::OutOfBounds);
+ View.Reset();
+ return 0;
+ }
+}
+
+/**
+ * Validate a 64-bit floating point value from the view.
+ *
+ * Modifies the view to start at the end of the value, and adds error flags if applicable.
+ */
+static void
+ValidateCbFloat64(MemoryView& View, CbValidateMode Mode, CbValidateError& Error)
+{
+ if (View.GetSize() >= sizeof(double))
+ {
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Format))
+ {
+ const uint64_t RawValue = FromNetworkOrder(CbValidationPrivate::ReadUnaligned<uint64_t>(View.GetData()));
+ const double Value = reinterpret_cast<const double&>(RawValue);
+ if (Value == double(float(Value)))
+ {
+ AddError(Error, CbValidateError::InvalidFloat);
+ }
+ }
+ View += sizeof(double);
+ }
+ else
+ {
+ AddError(Error, CbValidateError::OutOfBounds);
+ View.Reset();
+ }
+}
+
+/**
+ * Validate and read a string from the view.
+ *
+ * Modifies the view to start at the end of the string, and adds error flags if applicable.
+ */
+static std::string_view
+ValidateCbString(MemoryView& View, CbValidateMode Mode, CbValidateError& Error)
+{
+ const uint64_t NameSize = ValidateCbUInt(View, Mode, Error);
+ if (View.GetSize() >= NameSize)
+ {
+ const std::string_view Name(static_cast<const char*>(View.GetData()), static_cast<int32_t>(NameSize));
+ View += NameSize;
+ return Name;
+ }
+ else
+ {
+ AddError(Error, CbValidateError::OutOfBounds);
+ View.Reset();
+ return std::string_view();
+ }
+}
+
+static CbFieldView ValidateCbField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType ExternalType);
+
+/** A type that checks whether all validated fields are of the same type. */
+class CbUniformFieldsValidator
+{
+public:
+ inline explicit CbUniformFieldsValidator(CbFieldType InExternalType) : ExternalType(InExternalType) {}
+
+ inline CbFieldView ValidateField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error)
+ {
+ const void* const FieldData = View.GetData();
+ if (CbFieldView Field = ValidateCbField(View, Mode, Error, ExternalType))
+ {
+ ++FieldCount;
+ if (CbFieldTypeOps::HasFieldType(ExternalType))
+ {
+ const CbFieldType FieldType = *static_cast<const CbFieldType*>(FieldData);
+ if (FieldCount == 1)
+ {
+ FirstType = FieldType;
+ }
+ else if (FieldType != FirstType)
+ {
+ bUniform = false;
+ }
+ }
+ return Field;
+ }
+
+ // It may not safe to check for uniformity if the field was invalid.
+ bUniform = false;
+ return CbFieldView();
+ }
+
+ inline bool IsUniform() const { return FieldCount > 0 && bUniform; }
+
+private:
+ uint32_t FieldCount = 0;
+ bool bUniform = true;
+ CbFieldType FirstType = CbFieldType::None;
+ CbFieldType ExternalType;
+};
+
+static void
+ValidateCbObject(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType ObjectType)
+{
+ const uint64_t Size = ValidateCbUInt(View, Mode, Error);
+ MemoryView ObjectView = View.Left(Size);
+ View += Size;
+
+ if (Size > 0)
+ {
+ std::vector<std::string_view> Names;
+
+ const bool bUniformObject = CbFieldTypeOps::GetType(ObjectType) == CbFieldType::UniformObject;
+ const CbFieldType ExternalType = bUniformObject ? ValidateCbFieldType(ObjectView, Mode, Error) : CbFieldType::HasFieldType;
+ CbUniformFieldsValidator UniformValidator(ExternalType);
+ do
+ {
+ if (CbFieldView Field = UniformValidator.ValidateField(ObjectView, Mode, Error))
+ {
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Names))
+ {
+ if (Field.HasName())
+ {
+ Names.push_back(Field.GetName());
+ }
+ else
+ {
+ AddError(Error, CbValidateError::MissingName);
+ }
+ }
+ }
+ } while (!ObjectView.IsEmpty());
+
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Names) && Names.size() > 1)
+ {
+ std::sort(begin(Names), end(Names), [](std::string_view L, std::string_view R) { return L.compare(R) < 0; });
+
+ for (const std::string_view *NamesIt = Names.data(), *NamesEnd = NamesIt + Names.size() - 1; NamesIt != NamesEnd; ++NamesIt)
+ {
+ if (NamesIt[0] == NamesIt[1])
+ {
+ AddError(Error, CbValidateError::DuplicateName);
+ break;
+ }
+ }
+ }
+
+ if (!bUniformObject && EnumHasAnyFlags(Mode, CbValidateMode::Format) && UniformValidator.IsUniform())
+ {
+ AddError(Error, CbValidateError::NonUniformObject);
+ }
+ }
+}
+
+static void
+ValidateCbArray(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, CbFieldType ArrayType)
+{
+ const uint64_t Size = ValidateCbUInt(View, Mode, Error);
+ MemoryView ArrayView = View.Left(Size);
+ View += Size;
+
+ const uint64_t Count = ValidateCbUInt(ArrayView, Mode, Error);
+ const uint64_t FieldsSize = ArrayView.GetSize();
+ const bool bUniformArray = CbFieldTypeOps::GetType(ArrayType) == CbFieldType::UniformArray;
+ const CbFieldType ExternalType = bUniformArray ? ValidateCbFieldType(ArrayView, Mode, Error) : CbFieldType::HasFieldType;
+ CbUniformFieldsValidator UniformValidator(ExternalType);
+
+ for (uint64_t Index = 0; Index < Count; ++Index)
+ {
+ if (CbFieldView Field = UniformValidator.ValidateField(ArrayView, Mode, Error))
+ {
+ if (Field.HasName() && EnumHasAnyFlags(Mode, CbValidateMode::Names))
+ {
+ AddError(Error, CbValidateError::ArrayName);
+ }
+ }
+ }
+
+ if (!bUniformArray && EnumHasAnyFlags(Mode, CbValidateMode::Format) && UniformValidator.IsUniform() && FieldsSize > Count)
+ {
+ AddError(Error, CbValidateError::NonUniformArray);
+ }
+}
+
+static CbFieldView
+ValidateCbField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error, const CbFieldType ExternalType = CbFieldType::HasFieldType)
+{
+ const MemoryView FieldView = View;
+ const CbFieldType Type = ValidateCbFieldType(View, Mode, Error, ExternalType);
+ [[maybe_unused]] const std::string_view Name =
+ CbFieldTypeOps::HasFieldName(Type) ? ValidateCbString(View, Mode, Error) : std::string_view();
+
+ auto ValidateFixedPayload = [&View, &Error](uint32_t PayloadSize) {
+ if (View.GetSize() >= PayloadSize)
+ {
+ View += PayloadSize;
+ }
+ else
+ {
+ AddError(Error, CbValidateError::OutOfBounds);
+ View.Reset();
+ }
+ };
+
+ if (EnumHasAnyFlags(Error, CbValidateError::OutOfBounds | CbValidateError::InvalidType))
+ {
+ return CbFieldView();
+ }
+
+ switch (CbFieldType FieldType = CbFieldTypeOps::GetType(Type))
+ {
+ default:
+ case CbFieldType::None:
+ AddError(Error, CbValidateError::InvalidType);
+ View.Reset();
+ break;
+ case CbFieldType::Null:
+ case CbFieldType::BoolFalse:
+ case CbFieldType::BoolTrue:
+ if (FieldView == View)
+ {
+ // Reset the view because a zero-sized field can cause infinite field iteration.
+ AddError(Error, CbValidateError::InvalidType);
+ View.Reset();
+ }
+ break;
+ case CbFieldType::Object:
+ case CbFieldType::UniformObject:
+ ValidateCbObject(View, Mode, Error, FieldType);
+ break;
+ case CbFieldType::Array:
+ case CbFieldType::UniformArray:
+ ValidateCbArray(View, Mode, Error, FieldType);
+ break;
+ case CbFieldType::Binary:
+ {
+ const uint64_t ValueSize = ValidateCbUInt(View, Mode, Error);
+ if (View.GetSize() < ValueSize)
+ {
+ AddError(Error, CbValidateError::OutOfBounds);
+ View.Reset();
+ }
+ else
+ {
+ View += ValueSize;
+ }
+ break;
+ }
+ case CbFieldType::String:
+ ValidateCbString(View, Mode, Error);
+ break;
+ case CbFieldType::IntegerPositive:
+ ValidateCbUInt(View, Mode, Error);
+ break;
+ case CbFieldType::IntegerNegative:
+ ValidateCbUInt(View, Mode, Error);
+ break;
+ case CbFieldType::Float32:
+ ValidateFixedPayload(4);
+ break;
+ case CbFieldType::Float64:
+ ValidateCbFloat64(View, Mode, Error);
+ break;
+ case CbFieldType::ObjectAttachment:
+ case CbFieldType::BinaryAttachment:
+ case CbFieldType::Hash:
+ ValidateFixedPayload(20);
+ break;
+ case CbFieldType::Uuid:
+ ValidateFixedPayload(16);
+ break;
+ case CbFieldType::DateTime:
+ case CbFieldType::TimeSpan:
+ ValidateFixedPayload(8);
+ break;
+ case CbFieldType::ObjectId:
+ ValidateFixedPayload(12);
+ break;
+ case CbFieldType::CustomById:
+ case CbFieldType::CustomByName:
+ ZEN_NOT_IMPLEMENTED(); // TODO: FIX!
+ break;
+ }
+
+ if (EnumHasAnyFlags(Error, CbValidateError::OutOfBounds | CbValidateError::InvalidType))
+ {
+ return CbFieldView();
+ }
+
+ return CbFieldView(FieldView.GetData(), ExternalType);
+}
+
+static CbFieldView
+ValidateCbPackageField(MemoryView& View, CbValidateMode Mode, CbValidateError& Error)
+{
+ if (View.IsEmpty())
+ {
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+ return CbFieldView();
+ }
+ if (CbFieldView Field = ValidateCbField(View, Mode, Error))
+ {
+ if (Field.HasName() && EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+ return Field;
+ }
+ return CbFieldView();
+}
+
+static IoHash
+ValidateCbPackageAttachment(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error)
+{
+ if (const CbObjectView ObjectView = Value.AsObjectView(); !Value.HasError())
+ {
+ return ObjectView.GetHash();
+ }
+
+ if (const IoHash ObjectAttachmentHash = Value.AsObjectAttachment(); !Value.HasError())
+ {
+ if (CbFieldView ObjectField = ValidateCbPackageField(View, Mode, Error))
+ {
+ const CbObjectView InnerObjectView = ObjectField.AsObjectView();
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && ObjectField.HasError())
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+ else if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (ObjectAttachmentHash != InnerObjectView.GetHash()))
+ {
+ AddError(Error, CbValidateError::InvalidPackageHash);
+ }
+ return ObjectAttachmentHash;
+ }
+ }
+ else if (const IoHash BinaryAttachmentHash = Value.AsBinaryAttachment(); !Value.HasError())
+ {
+ if (CbFieldView BinaryField = ValidateCbPackageField(View, Mode, Error))
+ {
+ const MemoryView BinaryView = BinaryField.AsBinaryView();
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryField.HasError())
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+ else
+ {
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && BinaryView.IsEmpty())
+ {
+ AddError(Error, CbValidateError::NullPackageAttachment);
+ }
+ if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (BinaryAttachmentHash != IoHash::HashBuffer(BinaryView)))
+ {
+ AddError(Error, CbValidateError::InvalidPackageHash);
+ }
+ }
+ return BinaryAttachmentHash;
+ }
+ }
+ else if (const MemoryView BinaryView = Value.AsBinaryView(); !Value.HasError())
+ {
+ if (BinaryView.GetSize() > 0)
+ {
+ IoHash DecodedHash;
+ uint64_t DecodedRawSize;
+ CompressedBuffer Buffer = CompressedBuffer::FromCompressed(SharedBuffer::MakeView(BinaryView), DecodedHash, DecodedRawSize);
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package) && Buffer.IsNull())
+ {
+ AddError(Error, CbValidateError::NullPackageAttachment);
+ }
+ if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (DecodedHash != IoHash::HashBuffer(Buffer.DecompressToComposite())))
+ {
+ AddError(Error, CbValidateError::InvalidPackageHash);
+ }
+ return DecodedHash;
+ }
+ else
+ {
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ AddError(Error, CbValidateError::NullPackageAttachment);
+ }
+ return IoHash::HashBuffer(MemoryView());
+ }
+ }
+ else
+ {
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+ }
+
+ return IoHash();
+}
+
+static IoHash
+ValidateCbPackageObject(CbFieldView& Value, MemoryView& View, CbValidateMode Mode, CbValidateError& Error)
+{
+ if (IoHash RootObjectHash = Value.AsHash(); !Value.HasError() && !Value.IsAttachment())
+ {
+ CbFieldView RootObjectField = ValidateCbPackageField(View, Mode, Error);
+
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ if (RootObjectField.HasError())
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+ }
+
+ const CbObjectView RootObjectView = RootObjectField.AsObjectView();
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ if (!RootObjectView)
+ {
+ AddError(Error, CbValidateError::NullPackageObject);
+ }
+ }
+
+ if (EnumHasAnyFlags(Mode, CbValidateMode::PackageHash) && (RootObjectHash != RootObjectView.GetHash()))
+ {
+ AddError(Error, CbValidateError::InvalidPackageHash);
+ }
+
+ return RootObjectHash;
+ }
+ else
+ {
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+ }
+
+ return IoHash();
+}
+
+CbValidateError
+ValidateCompactBinary(MemoryView View, CbValidateMode Mode, CbFieldType Type)
+{
+ CbValidateError Error = CbValidateError::None;
+ if (EnumHasAnyFlags(Mode, CbValidateMode::All))
+ {
+ ValidateCbField(View, Mode, Error, Type);
+ if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding))
+ {
+ AddError(Error, CbValidateError::Padding);
+ }
+ }
+ return Error;
+}
+
+CbValidateError
+ValidateCompactBinaryRange(MemoryView View, CbValidateMode Mode)
+{
+ CbValidateError Error = CbValidateError::None;
+ if (EnumHasAnyFlags(Mode, CbValidateMode::All))
+ {
+ while (!View.IsEmpty())
+ {
+ ValidateCbField(View, Mode, Error);
+ }
+ }
+ return Error;
+}
+
+CbValidateError
+ValidateObjectAttachment(MemoryView View, CbValidateMode Mode)
+{
+ CbValidateError Error = CbValidateError::None;
+ if (EnumHasAnyFlags(Mode, CbValidateMode::All))
+ {
+ if (CbFieldView Value = ValidateCbPackageField(View, Mode, Error))
+ {
+ ValidateCbPackageAttachment(Value, View, Mode, Error);
+ }
+ if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding))
+ {
+ AddError(Error, CbValidateError::Padding);
+ }
+ }
+ return Error;
+}
+
+CbValidateError
+ValidateCompactBinaryPackage(MemoryView View, CbValidateMode Mode)
+{
+ std::vector<IoHash> Attachments;
+ CbValidateError Error = CbValidateError::None;
+ if (EnumHasAnyFlags(Mode, CbValidateMode::All))
+ {
+ uint32_t ObjectCount = 0;
+ while (CbFieldView Value = ValidateCbPackageField(View, Mode, Error))
+ {
+ if (Value.IsHash() && !Value.IsAttachment())
+ {
+ ValidateCbPackageObject(Value, View, Mode, Error);
+ if (++ObjectCount > 1 && EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ AddError(Error, CbValidateError::MultiplePackageObjects);
+ }
+ }
+ else if (Value.IsBinary() || Value.IsAttachment() || Value.IsObject())
+ {
+ const IoHash Hash = ValidateCbPackageAttachment(Value, View, Mode, Error);
+ if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ Attachments.push_back(Hash);
+ }
+ }
+ else if (Value.IsNull())
+ {
+ break;
+ }
+ else if (EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ AddError(Error, CbValidateError::InvalidPackageFormat);
+ }
+
+ if (EnumHasAnyFlags(Error, CbValidateError::OutOfBounds))
+ {
+ break;
+ }
+ }
+
+ if (!View.IsEmpty() && EnumHasAnyFlags(Mode, CbValidateMode::Padding))
+ {
+ AddError(Error, CbValidateError::Padding);
+ }
+
+ if (Attachments.size() && EnumHasAnyFlags(Mode, CbValidateMode::Package))
+ {
+ std::sort(begin(Attachments), end(Attachments));
+ for (const IoHash *It = Attachments.data(), *End = It + Attachments.size() - 1; It != End; ++It)
+ {
+ if (It[0] == It[1])
+ {
+ AddError(Error, CbValidateError::DuplicateAttachments);
+ break;
+ }
+ }
+ }
+ }
+ return Error;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+void
+usonvalidation_forcelink()
+{
+}
+
+TEST_CASE("usonvalidation")
+{
+ SUBCASE("Basic") {}
+}
+#endif
+
+} // namespace zen
diff --git a/src/zencore/compositebuffer.cpp b/src/zencore/compositebuffer.cpp
new file mode 100644
index 000000000..735020451
--- /dev/null
+++ b/src/zencore/compositebuffer.cpp
@@ -0,0 +1,446 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/compositebuffer.h>
+
+#include <zencore/sharedbuffer.h>
+#include <zencore/testing.h>
+
+namespace zen {
+
+const CompositeBuffer CompositeBuffer::Null;
+
+void
+CompositeBuffer::Reset()
+{
+ m_Segments.clear();
+}
+
+uint64_t
+CompositeBuffer::GetSize() const
+{
+ uint64_t Accum = 0;
+
+ for (const SharedBuffer& It : m_Segments)
+ {
+ Accum += It.GetSize();
+ }
+
+ return Accum;
+}
+
+bool
+CompositeBuffer::IsOwned() const
+{
+ for (const SharedBuffer& It : m_Segments)
+ {
+ if (It.IsOwned() == false)
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+CompositeBuffer
+CompositeBuffer::MakeOwned() const&
+{
+ return CompositeBuffer(*this).MakeOwned();
+}
+
+CompositeBuffer
+CompositeBuffer::MakeOwned() &&
+{
+ for (SharedBuffer& Segment : m_Segments)
+ {
+ Segment = std::move(Segment).MakeOwned();
+ }
+ return std::move(*this);
+}
+
+SharedBuffer
+CompositeBuffer::Flatten() const&
+{
+ switch (m_Segments.size())
+ {
+ case 0:
+ return SharedBuffer();
+ case 1:
+ return m_Segments[0];
+ default:
+ UniqueBuffer Buffer = UniqueBuffer::Alloc(GetSize());
+ MutableMemoryView OutView = Buffer.GetMutableView();
+
+ for (const SharedBuffer& Segment : m_Segments)
+ {
+ OutView.CopyFrom(Segment.GetView());
+ OutView += Segment.GetSize();
+ }
+
+ return Buffer.MoveToShared();
+ }
+}
+
+SharedBuffer
+CompositeBuffer::Flatten() &&
+{
+ return m_Segments.size() == 1 ? std::move(m_Segments[0]) : std::as_const(*this).Flatten();
+}
+
+CompositeBuffer
+CompositeBuffer::Mid(uint64_t Offset, uint64_t Size) const
+{
+ const uint64_t BufferSize = GetSize();
+ Offset = Min(Offset, BufferSize);
+ Size = Min(Size, BufferSize - Offset);
+ CompositeBuffer Buffer;
+ IterateRange(Offset, Size, [&Buffer](MemoryView View, const SharedBuffer& ViewOuter) {
+ Buffer.m_Segments.push_back(SharedBuffer::MakeView(View, ViewOuter));
+ });
+ return Buffer;
+}
+
+MemoryView
+CompositeBuffer::ViewOrCopyRange(uint64_t Offset, uint64_t Size, UniqueBuffer& CopyBuffer) const
+{
+ MemoryView View;
+ IterateRange(Offset, Size, [Size, &View, &CopyBuffer, WriteView = MutableMemoryView()](MemoryView Segment) mutable {
+ if (Size == Segment.GetSize())
+ {
+ View = Segment;
+ }
+ else
+ {
+ if (WriteView.IsEmpty())
+ {
+ if (CopyBuffer.GetSize() < Size)
+ {
+ CopyBuffer = UniqueBuffer::Alloc(Size);
+ }
+ View = WriteView = CopyBuffer.GetMutableView().Left(Size);
+ }
+ WriteView = WriteView.CopyFrom(Segment);
+ }
+ });
+ return View;
+}
+
+CompositeBuffer::Iterator
+CompositeBuffer::GetIterator(uint64_t Offset) const
+{
+ size_t SegmentCount = m_Segments.size();
+ size_t SegmentIndex = 0;
+ while (SegmentIndex < SegmentCount)
+ {
+ size_t SegmentSize = m_Segments[SegmentIndex].GetSize();
+ if (Offset < SegmentSize)
+ {
+ return {.SegmentIndex = SegmentIndex, .OffsetInSegment = Offset};
+ }
+ Offset -= SegmentSize;
+ SegmentIndex++;
+ }
+ return {.SegmentIndex = ~0ull, .OffsetInSegment = ~0ull};
+}
+
+MemoryView
+CompositeBuffer::ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& CopyBuffer) const
+{
+ // We use a sub range IoBuffer when we want to copy data from a segment.
+ // This means we will only materialize that range of the segment when doing
+ // GetView() rather than the full segment.
+ // A hot path for this code is when we call CompressedBuffer::FromCompressed which
+ // is only interested in reading the header (first 64 bytes or so) and then throws
+ // away the materialized data.
+ MutableMemoryView WriteView;
+ size_t SegmentCount = m_Segments.size();
+ ZEN_ASSERT(It.SegmentIndex < SegmentCount);
+ uint64_t SizeLeft = Size;
+ while (SizeLeft > 0 && It.SegmentIndex < SegmentCount)
+ {
+ const SharedBuffer& Segment = m_Segments[It.SegmentIndex];
+ size_t SegmentSize = Segment.GetSize();
+ if (Size == SizeLeft && Size <= (SegmentSize - It.OffsetInSegment))
+ {
+ IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, SizeLeft);
+ MemoryView View = SubSegment.GetView();
+ It.OffsetInSegment += SizeLeft;
+ ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize);
+ if (It.OffsetInSegment == SegmentSize)
+ {
+ It.SegmentIndex++;
+ It.OffsetInSegment = 0;
+ }
+ return View;
+ }
+ if (WriteView.GetSize() == 0)
+ {
+ if (CopyBuffer.GetSize() < Size)
+ {
+ CopyBuffer = UniqueBuffer::Alloc(Size);
+ }
+ WriteView = CopyBuffer.GetMutableView();
+ }
+ size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft);
+ IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize);
+ MemoryView ReadView = SubSegment.GetView();
+ WriteView = WriteView.CopyFrom(ReadView);
+ It.OffsetInSegment += CopySize;
+ ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize);
+ if (It.OffsetInSegment == SegmentSize)
+ {
+ It.SegmentIndex++;
+ It.OffsetInSegment = 0;
+ }
+ SizeLeft -= CopySize;
+ }
+ return CopyBuffer.GetView().Left(Size - SizeLeft);
+}
+
+void
+CompositeBuffer::CopyTo(MutableMemoryView WriteView, Iterator& It) const
+{
+ // We use a sub range IoBuffer when we want to copy data from a segment.
+ // This means we will only materialize that range of the segment when doing
+ // GetView() rather than the full segment.
+ // A hot path for this code is when we call CompressedBuffer::FromCompressed which
+ // is only interested in reading the header (first 64 bytes or so) and then throws
+ // away the materialized data.
+
+ size_t SizeLeft = WriteView.GetSize();
+ size_t SegmentCount = m_Segments.size();
+ ZEN_ASSERT(It.SegmentIndex < SegmentCount);
+ while (WriteView.GetSize() > 0 && It.SegmentIndex < SegmentCount)
+ {
+ const SharedBuffer& Segment = m_Segments[It.SegmentIndex];
+ size_t SegmentSize = Segment.GetSize();
+ size_t CopySize = zen::Min(SegmentSize - It.OffsetInSegment, SizeLeft);
+ IoBuffer SubSegment(Segment.AsIoBuffer(), It.OffsetInSegment, CopySize);
+ MemoryView ReadView = SubSegment.GetView();
+ WriteView = WriteView.CopyFrom(ReadView);
+ It.OffsetInSegment += CopySize;
+ ZEN_ASSERT_SLOW(It.OffsetInSegment <= SegmentSize);
+ if (It.OffsetInSegment == SegmentSize)
+ {
+ It.SegmentIndex++;
+ It.OffsetInSegment = 0;
+ }
+ SizeLeft -= CopySize;
+ }
+}
+
+void
+CompositeBuffer::CopyTo(MutableMemoryView Target, uint64_t Offset) const
+{
+ IterateRange(Offset, Target.GetSize(), [Target](MemoryView View, [[maybe_unused]] const SharedBuffer& ViewOuter) mutable {
+ Target = Target.CopyFrom(View);
+ });
+}
+
+void
+CompositeBuffer::IterateRange(uint64_t Offset, uint64_t Size, std::function<void(MemoryView View)> Visitor) const
+{
+ IterateRange(Offset, Size, [Visitor](MemoryView View, [[maybe_unused]] const SharedBuffer& ViewOuter) { Visitor(View); });
+}
+
+void
+CompositeBuffer::IterateRange(uint64_t Offset,
+ uint64_t Size,
+ std::function<void(MemoryView View, const SharedBuffer& ViewOuter)> Visitor) const
+{
+ ZEN_ASSERT(Offset + Size <= GetSize());
+ for (const SharedBuffer& Segment : m_Segments)
+ {
+ if (const uint64_t SegmentSize = Segment.GetSize(); Offset <= SegmentSize)
+ {
+ const MemoryView View = Segment.GetView().Mid(Offset, Size);
+ Offset = 0;
+ if (Size == 0 || !View.IsEmpty())
+ {
+ Visitor(View, Segment);
+ }
+ Size -= View.GetSize();
+ if (Size == 0)
+ {
+ break;
+ }
+ }
+ else
+ {
+ Offset -= SegmentSize;
+ }
+ }
+}
+
+#if ZEN_WITH_TESTS
+TEST_CASE("CompositeBuffer Null")
+{
+ CompositeBuffer Buffer;
+ CHECK(Buffer.IsNull());
+ CHECK(Buffer.IsOwned());
+ CHECK(Buffer.MakeOwned().IsNull());
+ CHECK(Buffer.Flatten().IsNull());
+ CHECK(Buffer.Mid(0, 0).IsNull());
+ CHECK(Buffer.GetSize() == 0);
+ CHECK(Buffer.GetSegments().size() == 0);
+
+ UniqueBuffer CopyBuffer;
+ CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer).IsEmpty());
+ CHECK(CopyBuffer.IsNull());
+
+ MutableMemoryView CopyView;
+ Buffer.CopyTo(CopyView);
+
+ uint32_t VisitCount = 0;
+ Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; });
+ CHECK(VisitCount == 0);
+}
+
+TEST_CASE("CompositeBuffer Empty")
+{
+ const uint8_t EmptyArray[]{0};
+ const SharedBuffer EmptyView = SharedBuffer::MakeView(EmptyArray, 0);
+ CompositeBuffer Buffer(EmptyView);
+ CHECK(Buffer.IsNull() == false);
+ CHECK(Buffer.IsOwned() == false);
+ CHECK(Buffer.MakeOwned().IsNull() == false);
+ CHECK(Buffer.MakeOwned().IsOwned() == true);
+ CHECK(Buffer.Flatten() == EmptyView);
+ CHECK(Buffer.Mid(0, 0).Flatten() == EmptyView);
+ CHECK(Buffer.GetSize() == 0);
+ CHECK(Buffer.GetSegments().size() == 1);
+ CHECK(Buffer.GetSegments()[0] == EmptyView);
+
+ UniqueBuffer CopyBuffer;
+ CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer) == EmptyView.GetView());
+ CHECK(CopyBuffer.IsNull());
+
+ MutableMemoryView CopyView;
+ Buffer.CopyTo(CopyView);
+
+ uint32_t VisitCount = 0;
+ Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; });
+ CHECK(VisitCount == 1);
+}
+
+TEST_CASE("CompositeBuffer Empty[1]")
+{
+ const uint8_t EmptyArray[1]{};
+ const SharedBuffer EmptyView1 = SharedBuffer::MakeView(EmptyArray, 0);
+ const SharedBuffer EmptyView2 = SharedBuffer::MakeView(EmptyArray + 1, 0);
+ CompositeBuffer Buffer(EmptyView1, EmptyView2);
+ CHECK(Buffer.Mid(0, 0).Flatten() == EmptyView1);
+ CHECK(Buffer.GetSize() == 0);
+ CHECK(Buffer.GetSegments().size() == 2);
+ CHECK(Buffer.GetSegments()[0] == EmptyView1);
+ CHECK(Buffer.GetSegments()[1] == EmptyView2);
+
+ UniqueBuffer CopyBuffer;
+ CHECK(Buffer.ViewOrCopyRange(0, 0, CopyBuffer) == EmptyView1.GetView());
+ CHECK(CopyBuffer.IsNull());
+
+ MutableMemoryView CopyView;
+ Buffer.CopyTo(CopyView);
+
+ uint32_t VisitCount = 0;
+ Buffer.IterateRange(0, 0, [&VisitCount](MemoryView) { ++VisitCount; });
+ CHECK(VisitCount == 1);
+}
+
+TEST_CASE("CompositeBuffer Flat")
+{
+ const uint8_t FlatArray[]{1, 2, 3, 4, 5, 6, 7, 8};
+ const SharedBuffer FlatView = SharedBuffer::Clone(MakeMemoryView(FlatArray));
+ CompositeBuffer Buffer(FlatView);
+
+ CHECK(Buffer.IsNull() == false);
+ CHECK(Buffer.IsOwned() == true);
+ CHECK(Buffer.Flatten() == FlatView);
+ CHECK(Buffer.MakeOwned().Flatten() == FlatView);
+ CHECK(Buffer.Mid(0).Flatten() == FlatView);
+ CHECK(Buffer.Mid(4).Flatten().GetView() == FlatView.GetView().Mid(4));
+ CHECK(Buffer.Mid(8).Flatten().GetView() == FlatView.GetView().Mid(8));
+ CHECK(Buffer.Mid(4, 2).Flatten().GetView() == FlatView.GetView().Mid(4, 2));
+ CHECK(Buffer.Mid(8, 0).Flatten().GetView() == FlatView.GetView().Mid(8, 0));
+ CHECK(Buffer.GetSize() == sizeof(FlatArray));
+ CHECK(Buffer.GetSegments().size() == 1);
+ CHECK(Buffer.GetSegments()[0] == FlatView);
+
+ UniqueBuffer CopyBuffer;
+ CHECK(Buffer.ViewOrCopyRange(0, sizeof(FlatArray), CopyBuffer) == FlatView.GetView());
+ CHECK(CopyBuffer.IsNull());
+
+ uint8_t CopyArray[sizeof(FlatArray) - 3];
+ Buffer.CopyTo(MakeMutableMemoryView(CopyArray), 3);
+ CHECK(MakeMemoryView(CopyArray).EqualBytes(MakeMemoryView(FlatArray) + 3));
+
+ uint32_t VisitCount = 0;
+ Buffer.IterateRange(0, sizeof(FlatArray), [&VisitCount](MemoryView) { ++VisitCount; });
+ CHECK(VisitCount == 1);
+}
+
+TEST_CASE("CompositeBuffer Composite")
+{
+ const uint8_t FlatArray[]{1, 2, 3, 4, 5, 6, 7, 8};
+ const SharedBuffer FlatView1 = SharedBuffer::MakeView(MakeMemoryView(FlatArray).Left(4));
+ const SharedBuffer FlatView2 = SharedBuffer::MakeView(MakeMemoryView(FlatArray).Right(4));
+ CompositeBuffer Buffer(FlatView1, FlatView2);
+
+ CHECK(Buffer.IsNull() == false);
+ CHECK(Buffer.IsOwned() == false);
+ CHECK(Buffer.Flatten().GetView().EqualBytes(MakeMemoryView(FlatArray)));
+ CHECK(Buffer.Mid(2, 4).Flatten().GetView().EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4)));
+ CHECK(Buffer.Mid(0, 4).Flatten() == FlatView1);
+ CHECK(Buffer.Mid(4, 4).Flatten() == FlatView2);
+ CHECK(Buffer.GetSize() == sizeof(FlatArray));
+ CHECK(Buffer.GetSegments().size() == 2);
+ CHECK(Buffer.GetSegments()[0] == FlatView1);
+ CHECK(Buffer.GetSegments()[1] == FlatView2);
+
+ UniqueBuffer CopyBuffer;
+
+ CHECK(Buffer.ViewOrCopyRange(0, 4, CopyBuffer) == FlatView1.GetView());
+ CHECK(CopyBuffer.IsNull() == true);
+ CHECK(Buffer.ViewOrCopyRange(4, 4, CopyBuffer) == FlatView2.GetView());
+ CHECK(CopyBuffer.IsNull() == true);
+ CHECK(Buffer.ViewOrCopyRange(3, 2, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(3, 2)));
+ CHECK(CopyBuffer.GetSize() == 2);
+ CHECK(Buffer.ViewOrCopyRange(1, 6, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(1, 6)));
+ CHECK(CopyBuffer.GetSize() == 6);
+ CHECK(Buffer.ViewOrCopyRange(2, 4, CopyBuffer).EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4)));
+ CHECK(CopyBuffer.GetSize() == 6);
+
+ uint8_t CopyArray[4];
+ Buffer.CopyTo(MakeMutableMemoryView(CopyArray), 2);
+ CHECK(MakeMemoryView(CopyArray).EqualBytes(MakeMemoryView(FlatArray).Mid(2, 4)));
+
+ uint32_t VisitCount = 0;
+ Buffer.IterateRange(0, sizeof(FlatArray), [&VisitCount](MemoryView) { ++VisitCount; });
+ CHECK(VisitCount == 2);
+
+ const auto TestIterateRange =
+ [&Buffer](uint64_t Offset, uint64_t Size, MemoryView ExpectedView, const SharedBuffer& ExpectedViewOuter) {
+ uint32_t VisitCount = 0;
+ MemoryView ActualView;
+ SharedBuffer ActualViewOuter;
+ Buffer.IterateRange(Offset, Size, [&VisitCount, &ActualView, &ActualViewOuter](MemoryView View, const SharedBuffer& ViewOuter) {
+ ++VisitCount;
+ ActualView = View;
+ ActualViewOuter = ViewOuter;
+ });
+ CHECK(VisitCount == 1);
+ CHECK(ActualView == ExpectedView);
+ CHECK(ActualViewOuter == ExpectedViewOuter);
+ };
+ TestIterateRange(0, 4, MakeMemoryView(FlatArray).Mid(0, 4), FlatView1);
+ TestIterateRange(4, 0, MakeMemoryView(FlatArray).Mid(4, 0), FlatView1);
+ TestIterateRange(4, 4, MakeMemoryView(FlatArray).Mid(4, 4), FlatView2);
+ TestIterateRange(8, 0, MakeMemoryView(FlatArray).Mid(8, 0), FlatView2);
+}
+
+void
+compositebuffer_forcelink()
+{
+}
+#endif
+
+} // namespace zen
diff --git a/src/zencore/compress.cpp b/src/zencore/compress.cpp
new file mode 100644
index 000000000..632e0e8f3
--- /dev/null
+++ b/src/zencore/compress.cpp
@@ -0,0 +1,1353 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/compress.h>
+
+#include <zencore/blake3.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/crc32.h>
+#include <zencore/endian.h>
+#include <zencore/iohash.h>
+#include <zencore/testing.h>
+
+#include "../../thirdparty/Oodle/include/oodle2.h"
+#if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "oo2core_win64.lib")
+#endif
+
+#include <lz4.h>
+#include <functional>
+#include <limits>
+
+namespace zen::detail {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+static constexpr uint64_t DefaultBlockSize = 256 * 1024;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/** Method used to compress the data in a compressed buffer. */
+enum class CompressionMethod : uint8_t
+{
+ /** Header is followed by one uncompressed block. */
+ None = 0,
+ /** Header is followed by an array of compressed block sizes then the compressed blocks. */
+ Oodle = 3,
+ /** Header is followed by an array of compressed block sizes then the compressed blocks. */
+ LZ4 = 4,
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/** Header used on every compressed buffer. Always stored in big-endian format. */
+struct BufferHeader
+{
+ static constexpr uint32_t ExpectedMagic = 0xb7756362; // <dot>ucb
+
+ uint32_t Magic = ExpectedMagic; // A magic number to identify a compressed buffer. Always 0xb7756362.
+ uint32_t Crc32 = 0; // A CRC-32 used to check integrity of the buffer. Uses the polynomial 0x04c11db7
+ CompressionMethod Method =
+ CompressionMethod::None; // The method used to compress the buffer. Affects layout of data following the header
+ uint8_t Compressor = 0; // The method-specific compressor used to compress the buffer.
+ uint8_t CompressionLevel = 0; // The method-specific compression level used to compress the buffer.
+ uint8_t BlockSizeExponent = 0; // The power of two size of every uncompressed block except the last. Size is 1 << BlockSizeExponent
+ uint32_t BlockCount = 0; // The number of blocks that follow the header
+ uint64_t TotalRawSize = 0; // The total size of the uncompressed data
+ uint64_t TotalCompressedSize = 0; // The total size of the compressed data including the header
+ BLAKE3 RawHash; // The hash of the uncompressed data
+
+ /** Checks validity of the buffer based on the magic number, method, and CRC-32. */
+ static bool IsValid(const CompositeBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize);
+ static bool IsValid(const SharedBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+ {
+ return IsValid(CompositeBuffer(CompressedData), OutRawHash, OutRawSize);
+ }
+
+ /** Read a header from a buffer that is at least sizeof(BufferHeader) without any validation. */
+ static BufferHeader Read(const CompositeBuffer& CompressedData)
+ {
+ BufferHeader Header;
+ if (sizeof(BufferHeader) <= CompressedData.GetSize())
+ {
+ // if (CompressedData.GetSegments()[0].AsIoBuffer().IsWholeFile())
+ // {
+ // ZEN_ASSERT(true);
+ // }
+ CompositeBuffer::Iterator It;
+ CompressedData.CopyTo(MakeMutableMemoryView(&Header, &Header + 1), It);
+ Header.ByteSwap();
+ }
+ return Header;
+ }
+
+ /**
+ * Write a header to a memory view that is at least sizeof(BufferHeader).
+ *
+ * @param HeaderView View of the header to write, including any method-specific header data.
+ */
+ void Write(MutableMemoryView HeaderView) const
+ {
+ BufferHeader Header = *this;
+ Header.ByteSwap();
+ HeaderView.CopyFrom(MakeMemoryView(&Header, &Header + 1));
+ Header.ByteSwap();
+ Header.Crc32 = CalculateCrc32(HeaderView);
+ Header.ByteSwap();
+ HeaderView.CopyFrom(MakeMemoryView(&Header, &Header + 1));
+ }
+
+ void ByteSwap()
+ {
+ Magic = zen::ByteSwap(Magic);
+ Crc32 = zen::ByteSwap(Crc32);
+ BlockCount = zen::ByteSwap(BlockCount);
+ TotalRawSize = zen::ByteSwap(TotalRawSize);
+ TotalCompressedSize = zen::ByteSwap(TotalCompressedSize);
+ }
+
+ /** Calculate the CRC-32 from a view of a header including any method-specific header data. */
+ static uint32_t CalculateCrc32(MemoryView HeaderView)
+ {
+ uint32_t Crc32 = 0;
+ constexpr uint64_t MethodOffset = offsetof(BufferHeader, Method);
+ for (MemoryView View = HeaderView + MethodOffset; const uint64_t ViewSize = View.GetSize();)
+ {
+ const int32_t Size = static_cast<int32_t>(zen::Min<uint64_t>(ViewSize, /* INT_MAX */ 2147483647u));
+ Crc32 = zen::MemCrc32(View.GetData(), Size, Crc32);
+ View += Size;
+ }
+ return Crc32;
+ }
+};
+
+static_assert(sizeof(BufferHeader) == 64, "BufferHeader is the wrong size.");
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+class BaseEncoder
+{
+public:
+ virtual CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t BlockSize = DefaultBlockSize) const = 0;
+};
+
+class BaseDecoder
+{
+public:
+ virtual CompositeBuffer Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const = 0;
+ virtual bool TryDecompressTo(const BufferHeader& Header,
+ const CompositeBuffer& CompressedData,
+ MutableMemoryView RawView,
+ uint64_t RawOffset) const = 0;
+ virtual uint64_t GetHeaderSize(const BufferHeader& Header) const = 0;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+class NoneEncoder final : public BaseEncoder
+{
+public:
+ [[nodiscard]] CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t /* BlockSize */) const final
+ {
+ BufferHeader Header;
+ Header.Method = CompressionMethod::None;
+ Header.BlockCount = 1;
+ Header.TotalRawSize = RawData.GetSize();
+ Header.TotalCompressedSize = Header.TotalRawSize + sizeof(BufferHeader);
+ Header.RawHash = BLAKE3::HashBuffer(RawData);
+
+ UniqueBuffer HeaderData = UniqueBuffer::Alloc(sizeof(BufferHeader));
+ Header.Write(HeaderData);
+ return CompositeBuffer(HeaderData.MoveToShared(), RawData.MakeOwned());
+ }
+};
+
+class NoneDecoder final : public BaseDecoder
+{
+public:
+ [[nodiscard]] CompositeBuffer Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const final
+ {
+ if (Header.Method == CompressionMethod::None && Header.TotalCompressedSize == CompressedData.GetSize() &&
+ Header.TotalCompressedSize == Header.TotalRawSize + sizeof(BufferHeader))
+ {
+ return CompressedData.Mid(sizeof(BufferHeader), Header.TotalRawSize).MakeOwned();
+ }
+ return CompositeBuffer();
+ }
+
+ [[nodiscard]] bool TryDecompressTo(const BufferHeader& Header,
+ const CompositeBuffer& CompressedData,
+ MutableMemoryView RawView,
+ uint64_t RawOffset) const final
+ {
+ if (Header.Method == CompressionMethod::None && RawOffset + RawView.GetSize() <= Header.TotalRawSize &&
+ Header.TotalCompressedSize == CompressedData.GetSize() &&
+ Header.TotalCompressedSize == Header.TotalRawSize + sizeof(BufferHeader))
+ {
+ CompressedData.CopyTo(RawView, sizeof(BufferHeader) + RawOffset);
+ return true;
+ }
+ return false;
+ }
+
+ [[nodiscard]] uint64_t GetHeaderSize(const BufferHeader&) const final { return sizeof(BufferHeader); }
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+class BlockEncoder : public BaseEncoder
+{
+public:
+ CompositeBuffer Compress(const CompositeBuffer& RawData, uint64_t BlockSize = DefaultBlockSize) const final;
+
+protected:
+ virtual CompressionMethod GetMethod() const = 0;
+ virtual uint8_t GetCompressor() const = 0;
+ virtual uint8_t GetCompressionLevel() const = 0;
+ virtual uint64_t CompressBlockBound(uint64_t RawSize) const = 0;
+ virtual bool CompressBlock(MutableMemoryView& CompressedData, MemoryView RawData) const = 0;
+
+private:
+ uint64_t GetCompressedBlocksBound(uint64_t BlockCount, uint64_t BlockSize, uint64_t RawSize) const
+ {
+ switch (BlockCount)
+ {
+ case 0:
+ return 0;
+ case 1:
+ return CompressBlockBound(RawSize);
+ default:
+ return CompressBlockBound(BlockSize) - BlockSize + RawSize;
+ }
+ }
+};
+
+CompositeBuffer
+BlockEncoder::Compress(const CompositeBuffer& RawData, const uint64_t BlockSize) const
+{
+ ZEN_ASSERT(IsPow2(BlockSize) && (BlockSize <= (1u << 31)));
+
+ const uint64_t RawSize = RawData.GetSize();
+ BLAKE3Stream RawHash;
+
+ const uint64_t BlockCount = RoundUp(RawSize, BlockSize) / BlockSize;
+ ZEN_ASSERT(BlockCount <= ~uint32_t(0));
+
+ // Allocate the buffer for the header, metadata, and compressed blocks.
+ const uint64_t MetaSize = BlockCount * sizeof(uint32_t);
+ const uint64_t CompressedDataSize = sizeof(BufferHeader) + MetaSize + GetCompressedBlocksBound(BlockCount, BlockSize, RawSize);
+ UniqueBuffer CompressedData = UniqueBuffer::Alloc(CompressedDataSize);
+
+ // Compress the raw data in blocks and store the raw data for incompressible blocks.
+ std::vector<uint32_t> CompressedBlockSizes;
+ CompressedBlockSizes.reserve(BlockCount);
+ uint64_t CompressedSize = 0;
+ {
+ UniqueBuffer RawBlockCopy;
+ MutableMemoryView CompressedBlocksView = CompressedData.GetMutableView() + sizeof(BufferHeader) + MetaSize;
+
+ CompositeBuffer::Iterator It = RawData.GetIterator(0);
+
+ for (uint64_t RawOffset = 0; RawOffset < RawSize;)
+ {
+ const uint64_t RawBlockSize = zen::Min(RawSize - RawOffset, BlockSize);
+ const MemoryView RawBlock = RawData.ViewOrCopyRange(It, RawBlockSize, RawBlockCopy);
+ RawHash.Append(RawBlock);
+
+ MutableMemoryView CompressedBlock = CompressedBlocksView;
+ if (!CompressBlock(CompressedBlock, RawBlock))
+ {
+ return CompositeBuffer();
+ }
+
+ uint64_t CompressedBlockSize = CompressedBlock.GetSize();
+ if (RawBlockSize <= CompressedBlockSize)
+ {
+ CompressedBlockSize = RawBlockSize;
+ CompressedBlocksView = CompressedBlocksView.CopyFrom(RawBlock);
+ }
+ else
+ {
+ CompressedBlocksView += CompressedBlockSize;
+ }
+
+ CompressedBlockSizes.push_back(static_cast<uint32_t>(CompressedBlockSize));
+ CompressedSize += CompressedBlockSize;
+ RawOffset += RawBlockSize;
+ }
+ }
+
+ // Return an uncompressed buffer if the compressed data is larger than the raw data.
+ if (RawSize <= MetaSize + CompressedSize)
+ {
+ CompressedData.Reset();
+ return NoneEncoder().Compress(RawData, BlockSize);
+ }
+
+ // Write the header and calculate the CRC-32.
+ for (uint32_t& Size : CompressedBlockSizes)
+ {
+ Size = ByteSwap(Size);
+ }
+ CompressedData.GetMutableView().Mid(sizeof(BufferHeader), MetaSize).CopyFrom(MakeMemoryView(CompressedBlockSizes));
+
+ BufferHeader Header;
+ Header.Method = GetMethod();
+ Header.Compressor = GetCompressor();
+ Header.CompressionLevel = GetCompressionLevel();
+ Header.BlockSizeExponent = static_cast<uint8_t>(zen::FloorLog2_64(BlockSize));
+ Header.BlockCount = static_cast<uint32_t>(BlockCount);
+ Header.TotalRawSize = RawSize;
+ Header.TotalCompressedSize = sizeof(BufferHeader) + MetaSize + CompressedSize;
+ Header.RawHash = RawHash.GetHash();
+ Header.Write(CompressedData.GetMutableView().Left(sizeof(BufferHeader) + MetaSize));
+
+ const MemoryView CompositeView = CompressedData.GetView().Left(Header.TotalCompressedSize);
+ return CompositeBuffer(SharedBuffer::MakeView(CompositeView, CompressedData.MoveToShared()));
+}
+
+class BlockDecoder : public BaseDecoder
+{
+public:
+ CompositeBuffer Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const final;
+ [[nodiscard]] bool TryDecompressTo(const BufferHeader& Header,
+ const CompositeBuffer& CompressedData,
+ MutableMemoryView RawView,
+ uint64_t RawOffset) const final;
+ [[nodiscard]] uint64_t GetHeaderSize(const BufferHeader& Header) const final
+ {
+ return sizeof(BufferHeader) + sizeof(uint32_t) * uint64_t(Header.BlockCount);
+ }
+
+protected:
+ virtual bool DecompressBlock(MutableMemoryView RawData, MemoryView CompressedData) const = 0;
+};
+
+CompositeBuffer
+BlockDecoder::Decompress(const BufferHeader& Header, const CompositeBuffer& CompressedData) const
+{
+ if (Header.BlockCount == 0 || Header.TotalCompressedSize != CompressedData.GetSize())
+ {
+ return CompositeBuffer();
+ }
+
+ // The raw data cannot reference the compressed data unless it is owned.
+ // An empty raw buffer requires an empty segment, which this path creates.
+ if (!CompressedData.IsOwned() || Header.TotalRawSize == 0)
+ {
+ UniqueBuffer Buffer = UniqueBuffer::Alloc(Header.TotalRawSize);
+ return TryDecompressTo(Header, CompressedData, Buffer, 0) ? CompositeBuffer(Buffer.MoveToShared()) : CompositeBuffer();
+ }
+
+ std::vector<uint32_t> CompressedBlockSizes;
+ CompressedBlockSizes.resize(Header.BlockCount);
+ CompressedData.CopyTo(MakeMutableMemoryView(CompressedBlockSizes), sizeof(BufferHeader));
+
+ for (uint32_t& Size : CompressedBlockSizes)
+ {
+ Size = ByteSwap(Size);
+ }
+
+ // Allocate the buffer for the raw blocks that were compressed.
+ SharedBuffer RawData;
+ MutableMemoryView RawDataView;
+ const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent;
+ {
+ uint64_t RawDataSize = 0;
+ uint64_t RemainingRawSize = Header.TotalRawSize;
+ for (const uint32_t CompressedBlockSize : CompressedBlockSizes)
+ {
+ const uint64_t RawBlockSize = zen::Min(RemainingRawSize, BlockSize);
+ if (CompressedBlockSize < BlockSize)
+ {
+ RawDataSize += RawBlockSize;
+ }
+ RemainingRawSize -= RawBlockSize;
+ }
+ UniqueBuffer RawDataBuffer = UniqueBuffer::Alloc(RawDataSize);
+ RawDataView = RawDataBuffer;
+ RawData = RawDataBuffer.MoveToShared();
+ }
+
+ // Decompress the compressed data in blocks and reference the uncompressed blocks.
+ uint64_t PendingCompressedSegmentOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t);
+ uint64_t PendingCompressedSegmentSize = 0;
+ uint64_t PendingRawSegmentOffset = 0;
+ uint64_t PendingRawSegmentSize = 0;
+ std::vector<SharedBuffer> Segments;
+
+ const auto CommitPendingCompressedSegment =
+ [&PendingCompressedSegmentOffset, &PendingCompressedSegmentSize, &CompressedData, &Segments] {
+ if (PendingCompressedSegmentSize)
+ {
+ CompressedData.IterateRange(PendingCompressedSegmentOffset,
+ PendingCompressedSegmentSize,
+ [&Segments](MemoryView View, const SharedBuffer& ViewOuter) {
+ Segments.push_back(SharedBuffer::MakeView(View, ViewOuter));
+ });
+ PendingCompressedSegmentOffset += PendingCompressedSegmentSize;
+ PendingCompressedSegmentSize = 0;
+ }
+ };
+
+ const auto CommitPendingRawSegment = [&PendingRawSegmentOffset, &PendingRawSegmentSize, &RawData, &Segments] {
+ if (PendingRawSegmentSize)
+ {
+ const MemoryView PendingSegment = RawData.GetView().Mid(PendingRawSegmentOffset, PendingRawSegmentSize);
+ Segments.push_back(SharedBuffer::MakeView(PendingSegment, RawData));
+ PendingRawSegmentOffset += PendingRawSegmentSize;
+ PendingRawSegmentSize = 0;
+ }
+ };
+
+ UniqueBuffer CompressedBlockCopy;
+ uint64_t RemainingRawSize = Header.TotalRawSize;
+ uint64_t RemainingCompressedSize = CompressedData.GetSize();
+ for (const uint32_t CompressedBlockSize : CompressedBlockSizes)
+ {
+ if (RemainingCompressedSize < CompressedBlockSize)
+ {
+ return CompositeBuffer();
+ }
+
+ const uint64_t RawBlockSize = zen::Min(RemainingRawSize, BlockSize);
+ if (RawBlockSize == CompressedBlockSize)
+ {
+ CommitPendingRawSegment();
+ PendingCompressedSegmentSize += RawBlockSize;
+ }
+ else
+ {
+ CommitPendingCompressedSegment();
+ const MemoryView CompressedBlock =
+ CompressedData.ViewOrCopyRange(PendingCompressedSegmentOffset, CompressedBlockSize, CompressedBlockCopy);
+ if (!DecompressBlock(RawDataView.Left(RawBlockSize), CompressedBlock))
+ {
+ return CompositeBuffer();
+ }
+ PendingCompressedSegmentOffset += CompressedBlockSize;
+ PendingRawSegmentSize += RawBlockSize;
+ RawDataView += RawBlockSize;
+ }
+
+ RemainingCompressedSize -= CompressedBlockSize;
+ RemainingRawSize -= RawBlockSize;
+ }
+
+ CommitPendingCompressedSegment();
+ CommitPendingRawSegment();
+
+ return CompositeBuffer(std::move(Segments));
+}
+
+bool
+BlockDecoder::TryDecompressTo(const BufferHeader& Header,
+ const CompositeBuffer& CompressedData,
+ MutableMemoryView RawView,
+ uint64_t RawOffset) const
+{
+ if (Header.TotalRawSize < RawOffset + RawView.GetSize() || Header.TotalCompressedSize != CompressedData.GetSize())
+ {
+ return false;
+ }
+
+ const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent;
+
+ UniqueBuffer BlockSizeBuffer;
+ MemoryView BlockSizeView = CompressedData.ViewOrCopyRange(sizeof(BufferHeader), Header.BlockCount * sizeof(uint32_t), BlockSizeBuffer);
+ std::span<uint32_t const> CompressedBlockSizes(reinterpret_cast<const uint32_t*>(BlockSizeView.GetData()), Header.BlockCount);
+
+ UniqueBuffer CompressedBlockCopy;
+ UniqueBuffer UncompressedBlockCopy;
+
+ const size_t FirstBlockIndex = uint64_t(RawOffset / BlockSize);
+ const size_t LastBlockIndex = uint64_t((RawOffset + RawView.GetSize() - 1) / BlockSize);
+ const uint64_t LastBlockSize = BlockSize - ((Header.BlockCount * BlockSize) - Header.TotalRawSize);
+ uint64_t OffsetInFirstBlock = RawOffset % BlockSize;
+ uint64_t CompressedOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t);
+ uint64_t RemainingRawSize = RawView.GetSize();
+
+ for (size_t BlockIndex = 0; BlockIndex < FirstBlockIndex; BlockIndex++)
+ {
+ const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]);
+ CompressedOffset += CompressedBlockSize;
+ }
+
+ for (size_t BlockIndex = FirstBlockIndex; BlockIndex <= LastBlockIndex; BlockIndex++)
+ {
+ const uint64_t UncompressedBlockSize = BlockIndex == Header.BlockCount - 1 ? LastBlockSize : BlockSize;
+ const uint32_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]);
+ const bool IsCompressed = CompressedBlockSize < UncompressedBlockSize;
+
+ const uint64_t BytesToUncompress = OffsetInFirstBlock > 0 ? zen::Min(RawView.GetSize(), UncompressedBlockSize - OffsetInFirstBlock)
+ : zen::Min(RemainingRawSize, BlockSize);
+
+ MemoryView CompressedBlock = CompressedData.ViewOrCopyRange(CompressedOffset, CompressedBlockSize, CompressedBlockCopy);
+
+ if (IsCompressed)
+ {
+ MutableMemoryView UncompressedBlock = RawView.Left(BytesToUncompress);
+
+ const bool IsAligned = BytesToUncompress == UncompressedBlockSize;
+ if (!IsAligned)
+ {
+ // Decompress to a temporary buffer when the first or the last block reads are not aligned with the block boundaries.
+ if (UncompressedBlockCopy.IsNull())
+ {
+ UncompressedBlockCopy = UniqueBuffer::Alloc(BlockSize);
+ }
+ UncompressedBlock = UncompressedBlockCopy.GetMutableView().Mid(0, UncompressedBlockSize);
+ }
+
+ if (!DecompressBlock(UncompressedBlock, CompressedBlock))
+ {
+ return false;
+ }
+
+ if (!IsAligned)
+ {
+ RawView.CopyFrom(UncompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress));
+ }
+ }
+ else
+ {
+ RawView.CopyFrom(CompressedBlock.Mid(OffsetInFirstBlock, BytesToUncompress));
+ }
+
+ OffsetInFirstBlock = 0;
+ RemainingRawSize -= BytesToUncompress;
+ CompressedOffset += CompressedBlockSize;
+ RawView += BytesToUncompress;
+ }
+
+ return RemainingRawSize == 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct OodleInit
+{
+ OodleInit()
+ {
+ OodleConfigValues Config;
+ Oodle_GetConfigValues(&Config);
+ // Always read/write Oodle v9 binary data.
+ Config.m_OodleLZ_BackwardsCompatible_MajorVersion = 9;
+ Oodle_SetConfigValues(&Config);
+ }
+};
+
+OodleInit InitOodle;
+
+class OodleEncoder final : public BlockEncoder
+{
+public:
+ OodleEncoder(OodleCompressor InCompressor, OodleCompressionLevel InCompressionLevel)
+ : Compressor(InCompressor)
+ , CompressionLevel(InCompressionLevel)
+ {
+ }
+
+protected:
+ CompressionMethod GetMethod() const final { return CompressionMethod::Oodle; }
+ uint8_t GetCompressor() const final { return static_cast<uint8_t>(Compressor); }
+ uint8_t GetCompressionLevel() const final { return static_cast<uint8_t>(CompressionLevel); }
+
+ uint64_t CompressBlockBound(uint64_t RawSize) const final
+ {
+ return static_cast<uint64_t>(OodleLZ_GetCompressedBufferSizeNeeded(OodleLZ_Compressor_Kraken, static_cast<OO_SINTa>(RawSize)));
+ }
+
+ bool CompressBlock(MutableMemoryView& CompressedData, MemoryView RawData) const final
+ {
+ const OodleLZ_Compressor LZCompressor = GetOodleLZCompressor(Compressor);
+ const OodleLZ_CompressionLevel LZCompressionLevel = GetOodleLZCompressionLevel(CompressionLevel);
+ if (LZCompressor == OodleLZ_Compressor_Invalid || LZCompressionLevel == OodleLZ_CompressionLevel_Invalid ||
+ LZCompressionLevel == OodleLZ_CompressionLevel_None)
+ {
+ return false;
+ }
+
+ const OO_SINTa RawSize = static_cast<OO_SINTa>(RawData.GetSize());
+ if (static_cast<OO_SINTa>(CompressedData.GetSize()) < OodleLZ_GetCompressedBufferSizeNeeded(LZCompressor, RawSize))
+ {
+ return false;
+ }
+
+ const OO_SINTa Size = OodleLZ_Compress(LZCompressor, RawData.GetData(), RawSize, CompressedData.GetData(), LZCompressionLevel);
+ CompressedData.LeftInline(static_cast<uint64_t>(Size));
+ return Size > 0;
+ }
+
+ static OodleLZ_Compressor GetOodleLZCompressor(OodleCompressor Compressor)
+ {
+ switch (Compressor)
+ {
+ case OodleCompressor::Selkie:
+ return OodleLZ_Compressor_Selkie;
+ case OodleCompressor::Mermaid:
+ return OodleLZ_Compressor_Mermaid;
+ case OodleCompressor::Kraken:
+ return OodleLZ_Compressor_Kraken;
+ case OodleCompressor::Leviathan:
+ return OodleLZ_Compressor_Leviathan;
+ case OodleCompressor::NotSet:
+ default:
+ return OodleLZ_Compressor_Invalid;
+ }
+ }
+
+ static OodleLZ_CompressionLevel GetOodleLZCompressionLevel(OodleCompressionLevel Level)
+ {
+ const int IntLevel = (int)Level;
+ if (IntLevel < (int)OodleLZ_CompressionLevel_Min || IntLevel > (int)OodleLZ_CompressionLevel_Max)
+ {
+ return OodleLZ_CompressionLevel_Invalid;
+ }
+ return OodleLZ_CompressionLevel(IntLevel);
+ }
+
+private:
+ const OodleCompressor Compressor;
+ const OodleCompressionLevel CompressionLevel;
+};
+
+class OodleDecoder final : public BlockDecoder
+{
+protected:
+ bool DecompressBlock(MutableMemoryView RawData, MemoryView CompressedData) const final
+ {
+ const OO_SINTa RawSize = static_cast<OO_SINTa>(RawData.GetSize());
+ const OO_SINTa Size = OodleLZ_Decompress(CompressedData.GetData(),
+ static_cast<OO_SINTa>(CompressedData.GetSize()),
+ RawData.GetData(),
+ RawSize,
+ OodleLZ_FuzzSafe_Yes,
+ OodleLZ_CheckCRC_Yes,
+ OodleLZ_Verbosity_None);
+ return Size == RawSize;
+ }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+class LZ4Decoder final : public BlockDecoder
+{
+protected:
+ bool DecompressBlock(MutableMemoryView RawData, MemoryView CompressedData) const final
+ {
+ if (CompressedData.GetSize() <= std::numeric_limits<int>::max())
+ {
+ const int Size = LZ4_decompress_safe(static_cast<const char*>(CompressedData.GetData()),
+ static_cast<char*>(RawData.GetData()),
+ static_cast<int>(CompressedData.GetSize()),
+ static_cast<int>(zen::Min<uint64_t>(RawData.GetSize(), uint64_t(LZ4_MAX_INPUT_SIZE))));
+ return static_cast<uint64_t>(Size) == RawData.GetSize();
+ }
+ return false;
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+static const BaseDecoder*
+GetDecoder(CompressionMethod Method)
+{
+ static NoneDecoder None;
+ static OodleDecoder Oodle;
+ static LZ4Decoder LZ4;
+
+ switch (Method)
+ {
+ default:
+ return nullptr;
+ case CompressionMethod::None:
+ return &None;
+ case CompressionMethod::Oodle:
+ return &Oodle;
+ case CompressionMethod::LZ4:
+ return &LZ4;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+BufferHeader::IsValid(const CompositeBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ uint64_t Size = CompressedData.GetSize();
+ if (Size < sizeof(BufferHeader))
+ {
+ return false;
+ }
+ const size_t StackBufferSize = 256;
+ uint8_t StackBuffer[StackBufferSize];
+ uint64_t ReadSize = Min(Size, StackBufferSize);
+ BufferHeader* Header = reinterpret_cast<BufferHeader*>(StackBuffer);
+ {
+ CompositeBuffer::Iterator It;
+ CompressedData.CopyTo(MutableMemoryView(StackBuffer, StackBuffer + StackBufferSize), It);
+ }
+ Header->ByteSwap();
+ if (Header->Magic != BufferHeader::ExpectedMagic)
+ {
+ return false;
+ }
+ const BaseDecoder* const Decoder = GetDecoder(Header->Method);
+ if (!Decoder)
+ {
+ return false;
+ }
+ uint32_t Crc32 = Header->Crc32;
+ OutRawHash = IoHash::FromBLAKE3(Header->RawHash);
+ OutRawSize = Header->TotalRawSize;
+ uint64_t HeaderSize = Decoder->GetHeaderSize(*Header);
+ Header->ByteSwap();
+
+ if (HeaderSize > ReadSize)
+ {
+ // 0.004% of cases on a Fortnite hot cache cook
+ UniqueBuffer HeaderCopy = UniqueBuffer::Alloc(HeaderSize);
+ CompositeBuffer::Iterator It;
+ CompressedData.CopyTo(HeaderCopy.GetMutableView(), It);
+ const MemoryView HeaderView = HeaderCopy.GetView();
+ if (Crc32 != BufferHeader::CalculateCrc32(HeaderView))
+ {
+ return false;
+ }
+ }
+ else
+ {
+ MemoryView FullHeaderView(StackBuffer, StackBuffer + HeaderSize);
+ if (Crc32 != BufferHeader::CalculateCrc32(FullHeaderView))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+template<typename BufferType>
+inline CompositeBuffer
+ValidBufferOrEmpty(BufferType&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ return BufferHeader::IsValid(CompressedData, OutRawHash, OutRawSize) ? CompositeBuffer(std::forward<BufferType>(CompressedData))
+ : CompositeBuffer();
+}
+
+CompositeBuffer
+CopyCompressedRange(const BufferHeader& Header, const CompositeBuffer& CompressedData, uint64_t RawOffset, uint64_t RawSize)
+{
+ if (Header.TotalRawSize < RawOffset + RawSize)
+ {
+ return CompositeBuffer();
+ }
+
+ if (Header.Method == CompressionMethod::None)
+ {
+ UniqueBuffer NewCompressedData = UniqueBuffer::Alloc(RawSize);
+ CompressedData.CopyTo(NewCompressedData.GetMutableView(), sizeof(Header) + RawOffset);
+
+ BufferHeader NewHeader = Header;
+ NewHeader.Crc32 = 0;
+ NewHeader.TotalRawSize = RawSize;
+ NewHeader.TotalCompressedSize = NewHeader.TotalRawSize + sizeof(BufferHeader);
+ NewHeader.RawHash = BLAKE3();
+
+ UniqueBuffer HeaderData = UniqueBuffer::Alloc(sizeof(BufferHeader));
+ NewHeader.Write(HeaderData);
+
+ return CompositeBuffer(HeaderData.MoveToShared(), NewCompressedData.MoveToShared());
+ }
+ else
+ {
+ UniqueBuffer BlockSizeBuffer;
+ MemoryView BlockSizeView =
+ CompressedData.ViewOrCopyRange(sizeof(BufferHeader), Header.BlockCount * sizeof(uint32_t), BlockSizeBuffer);
+ std::span<uint32_t const> CompressedBlockSizes(reinterpret_cast<const uint32_t*>(BlockSizeView.GetData()), Header.BlockCount);
+
+ const uint64_t BlockSize = uint64_t(1) << Header.BlockSizeExponent;
+ const uint64_t LastBlockSize = BlockSize - ((Header.BlockCount * BlockSize) - Header.TotalRawSize);
+ const size_t FirstBlock = uint64_t(RawOffset / BlockSize);
+ const size_t LastBlock = uint64_t((RawOffset + RawSize - 1) / BlockSize);
+ uint64_t CompressedOffset = sizeof(BufferHeader) + uint64_t(Header.BlockCount) * sizeof(uint32_t);
+
+ const uint64_t NewBlockCount = LastBlock - FirstBlock + 1;
+ const uint64_t NewMetaSize = NewBlockCount * sizeof(uint32_t);
+ uint64_t NewCompressedSize = 0;
+ uint64_t NewTotalRawSize = 0;
+ std::vector<uint32_t> NewCompressedBlockSizes;
+
+ NewCompressedBlockSizes.reserve(NewBlockCount);
+ for (size_t BlockIndex = FirstBlock; BlockIndex <= LastBlock; ++BlockIndex)
+ {
+ const uint64_t UncompressedBlockSize = (BlockIndex == Header.BlockCount - 1) ? LastBlockSize : BlockSize;
+ NewTotalRawSize += UncompressedBlockSize;
+
+ const uint32_t CompressedBlockSize = CompressedBlockSizes[BlockIndex];
+ NewCompressedBlockSizes.push_back(CompressedBlockSize);
+ NewCompressedSize += ByteSwap(CompressedBlockSize);
+ }
+
+ const uint64_t NewTotalCompressedSize = sizeof(BufferHeader) + NewBlockCount * sizeof(uint32_t) + NewCompressedSize;
+ UniqueBuffer NewCompressedData = UniqueBuffer::Alloc(NewTotalCompressedSize);
+ MutableMemoryView NewCompressedBlocks = NewCompressedData.GetMutableView() + sizeof(BufferHeader) + NewMetaSize;
+
+ // Seek to first compressed block
+ for (size_t BlockIndex = 0; BlockIndex < FirstBlock; ++BlockIndex)
+ {
+ const uint64_t CompressedBlockSize = ByteSwap(CompressedBlockSizes[BlockIndex]);
+ CompressedOffset += CompressedBlockSize;
+ }
+
+ // Copy blocks
+ UniqueBuffer CompressedBlockCopy;
+ const MemoryView CompressedRange = CompressedData.ViewOrCopyRange(CompressedOffset, NewCompressedSize, CompressedBlockCopy);
+ NewCompressedBlocks.CopyFrom(CompressedRange);
+
+ // Copy block sizes
+ NewCompressedData.GetMutableView().Mid(sizeof(BufferHeader), NewMetaSize).CopyFrom(MakeMemoryView(NewCompressedBlockSizes));
+
+ BufferHeader NewHeader;
+ NewHeader.Crc32 = 0;
+ NewHeader.Method = Header.Method;
+ NewHeader.Compressor = Header.Compressor;
+ NewHeader.CompressionLevel = Header.CompressionLevel;
+ NewHeader.BlockSizeExponent = Header.BlockSizeExponent;
+ NewHeader.BlockCount = static_cast<uint32_t>(NewBlockCount);
+ NewHeader.TotalRawSize = NewTotalRawSize;
+ NewHeader.TotalCompressedSize = NewTotalCompressedSize;
+ NewHeader.RawHash = BLAKE3();
+ NewHeader.Write(NewCompressedData.GetMutableView().Left(sizeof(BufferHeader) + NewMetaSize));
+
+ return CompositeBuffer(NewCompressedData.MoveToShared());
+ }
+}
+
+} // namespace zen::detail
+
+namespace zen {
+
+const CompressedBuffer CompressedBuffer::Null;
+
+CompressedBuffer
+CompressedBuffer::Compress(const CompositeBuffer& RawData,
+ OodleCompressor Compressor,
+ OodleCompressionLevel CompressionLevel,
+ uint64_t BlockSize)
+{
+ using namespace detail;
+
+ if (BlockSize == 0)
+ {
+ BlockSize = DefaultBlockSize;
+ }
+
+ CompressedBuffer Local;
+ if (CompressionLevel == OodleCompressionLevel::None)
+ {
+ Local.CompressedData = NoneEncoder().Compress(RawData, BlockSize);
+ }
+ else
+ {
+ Local.CompressedData = OodleEncoder(Compressor, CompressionLevel).Compress(RawData, BlockSize);
+ }
+ return Local;
+}
+
+CompressedBuffer
+CompressedBuffer::Compress(const SharedBuffer& RawData,
+ OodleCompressor Compressor,
+ OodleCompressionLevel CompressionLevel,
+ uint64_t BlockSize)
+{
+ return Compress(CompositeBuffer(RawData), Compressor, CompressionLevel, BlockSize);
+}
+
+CompressedBuffer
+CompressedBuffer::FromCompressed(const CompositeBuffer& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ CompressedBuffer Local;
+ Local.CompressedData = detail::ValidBufferOrEmpty(InCompressedData, OutRawHash, OutRawSize);
+ return Local;
+}
+
+CompressedBuffer
+CompressedBuffer::FromCompressed(CompositeBuffer&& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ CompressedBuffer Local;
+ Local.CompressedData = detail::ValidBufferOrEmpty(std::move(InCompressedData), OutRawHash, OutRawSize);
+ return Local;
+}
+
+CompressedBuffer
+CompressedBuffer::FromCompressed(const SharedBuffer& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ CompressedBuffer Local;
+ Local.CompressedData = detail::ValidBufferOrEmpty(InCompressedData, OutRawHash, OutRawSize);
+ return Local;
+}
+
+CompressedBuffer
+CompressedBuffer::FromCompressed(SharedBuffer&& InCompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ CompressedBuffer Local;
+ Local.CompressedData = detail::ValidBufferOrEmpty(std::move(InCompressedData), OutRawHash, OutRawSize);
+ return Local;
+}
+
+CompressedBuffer
+CompressedBuffer::FromCompressedNoValidate(IoBuffer&& InCompressedData)
+{
+ if (InCompressedData.GetSize() <= sizeof(detail::BufferHeader))
+ {
+ return CompressedBuffer();
+ }
+ CompressedBuffer Local;
+ Local.CompressedData = CompositeBuffer(SharedBuffer(std::move(InCompressedData)));
+ return Local;
+}
+
+CompressedBuffer
+CompressedBuffer::FromCompressedNoValidate(CompositeBuffer&& InCompressedData)
+{
+ if (InCompressedData.GetSize() <= sizeof(detail::BufferHeader))
+ {
+ return CompressedBuffer();
+ }
+ CompressedBuffer Local;
+ Local.CompressedData = std::move(InCompressedData);
+ return Local;
+}
+
+bool
+CompressedBuffer::ValidateCompressedHeader(IoBuffer&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ return detail::BufferHeader::IsValid(SharedBuffer(std::move(CompressedData)), OutRawHash, OutRawSize);
+}
+
+bool
+CompressedBuffer::ValidateCompressedHeader(const IoBuffer& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize)
+{
+ return detail::BufferHeader::IsValid(SharedBuffer(CompressedData), OutRawHash, OutRawSize);
+}
+
+uint64_t
+CompressedBuffer::DecodeRawSize() const
+{
+ return CompressedData ? detail::BufferHeader::Read(CompressedData).TotalRawSize : 0;
+}
+
+IoHash
+CompressedBuffer::DecodeRawHash() const
+{
+ return CompressedData ? IoHash::FromBLAKE3(detail::BufferHeader::Read(CompressedData).RawHash) : IoHash();
+}
+
+CompressedBuffer
+CompressedBuffer::CopyRange(uint64_t RawOffset, uint64_t RawSize) const
+{
+ using namespace detail;
+ const BufferHeader Header = BufferHeader::Read(CompressedData);
+ const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset;
+
+ CompressedBuffer Range;
+ Range.CompressedData = CopyCompressedRange(Header, CompressedData, RawOffset, TotalRawSize);
+
+ return Range;
+}
+
+bool
+CompressedBuffer::TryDecompressTo(MutableMemoryView RawView, uint64_t RawOffset) const
+{
+ using namespace detail;
+ if (CompressedData)
+ {
+ const BufferHeader Header = BufferHeader::Read(CompressedData);
+ if (Header.Magic == BufferHeader::ExpectedMagic)
+ {
+ if (const BaseDecoder* const Decoder = GetDecoder(Header.Method))
+ {
+ return Decoder->TryDecompressTo(Header, CompressedData, RawView, RawOffset);
+ }
+ }
+ }
+ return false;
+}
+
+SharedBuffer
+CompressedBuffer::Decompress(uint64_t RawOffset, uint64_t RawSize) const
+{
+ using namespace detail;
+ if (CompressedData && RawSize > 0)
+ {
+ const BufferHeader Header = BufferHeader::Read(CompressedData);
+ if (Header.Magic == BufferHeader::ExpectedMagic)
+ {
+ if (const BaseDecoder* const Decoder = GetDecoder(Header.Method))
+ {
+ const uint64_t TotalRawSize = RawSize < ~uint64_t(0) ? RawSize : Header.TotalRawSize - RawOffset;
+ UniqueBuffer RawData = UniqueBuffer::Alloc(TotalRawSize);
+ if (Decoder->TryDecompressTo(Header, CompressedData, RawData, RawOffset))
+ {
+ return RawData.MoveToShared();
+ }
+ }
+ }
+ }
+ return SharedBuffer();
+}
+
+CompositeBuffer
+CompressedBuffer::DecompressToComposite() const
+{
+ using namespace detail;
+ if (CompressedData)
+ {
+ const BufferHeader Header = BufferHeader::Read(CompressedData);
+ if (Header.Magic == BufferHeader::ExpectedMagic)
+ {
+ if (const BaseDecoder* const Decoder = GetDecoder(Header.Method))
+ {
+ return Decoder->Decompress(Header, CompressedData);
+ }
+ }
+ }
+ return CompositeBuffer();
+}
+
+bool
+CompressedBuffer::TryGetCompressParameters(OodleCompressor& OutCompressor,
+ OodleCompressionLevel& OutCompressionLevel,
+ uint64_t& OutBlockSize) const
+{
+ using namespace detail;
+ if (CompressedData)
+ {
+ switch (const BufferHeader Header = BufferHeader::Read(CompressedData); Header.Method)
+ {
+ case CompressionMethod::None:
+ OutCompressor = OodleCompressor::NotSet;
+ OutCompressionLevel = OodleCompressionLevel::None;
+ OutBlockSize = 0;
+ return true;
+ case CompressionMethod::Oodle:
+ OutCompressor = OodleCompressor(Header.Compressor);
+ OutCompressionLevel = OodleCompressionLevel(Header.CompressionLevel);
+ OutBlockSize = uint64_t(1) << Header.BlockSizeExponent;
+ return true;
+ default:
+ break;
+ }
+ }
+ return false;
+}
+
+/**
+ ______________________ _____________________________
+ \__ ___/\_ _____// _____/\__ ___/ _____/
+ | | | __)_ \_____ \ | | \_____ \
+ | | | \/ \ | | / \
+ |____| /_______ /_______ / |____| /_______ /
+ \/ \/ \/
+ */
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("CompressedBuffer")
+{
+ uint8_t Zeroes[1024]{};
+ uint8_t Ones[1024];
+ memset(Ones, 1, sizeof Ones);
+
+ {
+ CompressedBuffer Buffer = CompressedBuffer::Compress(CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes))),
+ OodleCompressor::NotSet,
+ OodleCompressionLevel::None);
+
+ CHECK(Buffer.DecodeRawSize() == sizeof(Zeroes));
+ CHECK(Buffer.GetCompressedSize() == (sizeof(Zeroes) + sizeof(detail::BufferHeader)));
+
+ CompositeBuffer Compressed = Buffer.GetCompressed();
+ IoHash DecodedHash;
+ uint64_t DecodedRawSize;
+ CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize);
+
+ CHECK(BufferD.IsNull() == false);
+
+ CompositeBuffer Decomp = BufferD.DecompressToComposite();
+
+ CHECK(Decomp.GetSize() == DecodedRawSize);
+ CHECK(IoHash::HashBuffer(Decomp) == DecodedHash);
+ }
+
+ {
+ CompressedBuffer Buffer = CompressedBuffer::Compress(
+ CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes)), SharedBuffer::MakeView(MakeMemoryView(Ones))),
+ OodleCompressor::NotSet,
+ OodleCompressionLevel::None);
+
+ CHECK(Buffer.DecodeRawSize() == (sizeof(Zeroes) + sizeof(Ones)));
+ CHECK(Buffer.GetCompressedSize() == (sizeof(Zeroes) + sizeof(Ones) + sizeof(detail::BufferHeader)));
+
+ CompositeBuffer Compressed = Buffer.GetCompressed();
+ IoHash DecodedHash;
+ uint64_t DecodedRawSize;
+ CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize);
+
+ CHECK(BufferD.IsNull() == false);
+
+ CompositeBuffer Decomp = BufferD.DecompressToComposite();
+
+ CHECK(Decomp.GetSize() == DecodedRawSize);
+ CHECK(IoHash::HashBuffer(Decomp) == DecodedHash);
+ }
+
+ {
+ CompressedBuffer Buffer = CompressedBuffer::Compress(CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes))));
+
+ CHECK(Buffer.DecodeRawSize() == sizeof(Zeroes));
+ CHECK(Buffer.GetCompressedSize() < sizeof(Zeroes));
+
+ CompositeBuffer Compressed = Buffer.GetCompressed();
+ IoHash DecodedHash;
+ uint64_t DecodedRawSize;
+ CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize);
+
+ CHECK(BufferD.IsNull() == false);
+
+ CompositeBuffer Decomp = BufferD.DecompressToComposite();
+
+ CHECK(Decomp.GetSize() == DecodedRawSize);
+ CHECK(IoHash::HashBuffer(Decomp) == DecodedHash);
+ }
+
+ {
+ CompressedBuffer Buffer = CompressedBuffer::Compress(
+ CompositeBuffer(SharedBuffer::MakeView(MakeMemoryView(Zeroes)), SharedBuffer::MakeView(MakeMemoryView(Ones))));
+
+ CHECK(Buffer.DecodeRawSize() == (sizeof(Zeroes) + sizeof(Ones)));
+ CHECK(Buffer.GetCompressedSize() < (sizeof(Zeroes) + sizeof(Ones)));
+
+ CompositeBuffer Compressed = Buffer.GetCompressed();
+ IoHash DecodedHash;
+ uint64_t DecodedRawSize;
+ CompressedBuffer BufferD = CompressedBuffer::FromCompressed(Compressed, DecodedHash, DecodedRawSize);
+
+ CHECK(BufferD.IsNull() == false);
+
+ CompositeBuffer Decomp = BufferD.DecompressToComposite();
+
+ CHECK(Decomp.GetSize() == DecodedRawSize);
+ CHECK(IoHash::HashBuffer(Decomp) == DecodedHash);
+ }
+
+ auto GenerateData = [](uint64_t N) -> std::vector<uint64_t> {
+ std::vector<uint64_t> Data;
+ Data.resize(N);
+ for (size_t Idx = 0; Idx < Data.size(); ++Idx)
+ {
+ Data[Idx] = Idx;
+ }
+ return Data;
+ };
+
+ auto ValidateData = [](std::span<uint64_t const> Values, std::span<uint64_t const> ExpectedValues, uint64_t Offset) {
+ for (size_t Idx = Offset; uint64_t Value : Values)
+ {
+ const uint64_t ExpectedValue = ExpectedValues[Idx++];
+ CHECK(Value == ExpectedValue);
+ }
+ };
+
+ SUBCASE("decompress with offset and size")
+ {
+ auto UncompressAndValidate = [&ValidateData](CompressedBuffer Compressed,
+ uint64_t OffsetCount,
+ uint64_t Count,
+ const std::vector<uint64_t>& ExpectedValues) {
+ SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t));
+ CHECK(Uncompressed.GetSize() == Count * sizeof(uint64_t));
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ };
+
+ const uint64_t BlockSize = 64 * sizeof(uint64_t);
+ const uint64_t N = 5000;
+ std::vector<uint64_t> ExpectedValues = GenerateData(N);
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)),
+ OodleCompressor::Mermaid,
+ OodleCompressionLevel::Optimal4,
+ BlockSize);
+ UncompressAndValidate(Compressed, 0, N, ExpectedValues);
+ UncompressAndValidate(Compressed, 1, N - 1, ExpectedValues);
+ UncompressAndValidate(Compressed, N - 1, 1, ExpectedValues);
+ UncompressAndValidate(Compressed, 0, 1, ExpectedValues);
+ UncompressAndValidate(Compressed, 2, 4, ExpectedValues);
+ UncompressAndValidate(Compressed, 0, 512, ExpectedValues);
+ UncompressAndValidate(Compressed, 3, 514, ExpectedValues);
+ UncompressAndValidate(Compressed, 256, 512, ExpectedValues);
+ UncompressAndValidate(Compressed, 512, 512, ExpectedValues);
+ }
+
+ SUBCASE("decompress with offset only")
+ {
+ const uint64_t BlockSize = 64 * sizeof(uint64_t);
+ const uint64_t N = 1000;
+ std::vector<uint64_t> ExpectedValues = GenerateData(N);
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)),
+ OodleCompressor::Mermaid,
+ OodleCompressionLevel::Optimal4,
+ BlockSize);
+ const uint64_t OffsetCount = 150;
+ SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t));
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ SUBCASE("decompress buffer with one block")
+ {
+ const uint64_t BlockSize = 256 * sizeof(uint64_t);
+ const uint64_t N = 100;
+ std::vector<uint64_t> ExpectedValues = GenerateData(N);
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)),
+ OodleCompressor::Mermaid,
+ OodleCompressionLevel::Optimal4,
+ BlockSize);
+ const uint64_t OffsetCount = 2;
+ const uint64_t Count = 50;
+ SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t));
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ SUBCASE("decompress uncompressed buffer")
+ {
+ const uint64_t N = 4242;
+ std::vector<uint64_t> ExpectedValues = GenerateData(N);
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)),
+ OodleCompressor::NotSet,
+ OodleCompressionLevel::None);
+ {
+ const uint64_t OffsetCount = 0;
+ const uint64_t Count = N;
+ SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t));
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ {
+ const uint64_t OffsetCount = 21;
+ const uint64_t Count = 999;
+ SharedBuffer Uncompressed = Compressed.Decompress(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t));
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+ }
+
+ SUBCASE("copy range")
+ {
+ const uint64_t BlockSize = 64 * sizeof(uint64_t);
+ const uint64_t N = 1000;
+ std::vector<uint64_t> ExpectedValues = GenerateData(N);
+
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)),
+ OodleCompressor::Mermaid,
+ OodleCompressionLevel::Optimal4,
+ BlockSize);
+
+ {
+ const uint64_t OffsetCount = 0;
+ const uint64_t Count = N;
+ SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress();
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ CHECK(Values.size() == Count);
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ {
+ const uint64_t OffsetCount = 64;
+ const uint64_t Count = N - 64;
+ SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress();
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ CHECK(Values.size() == Count);
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ {
+ const uint64_t OffsetCount = 64 * 2 + 32;
+ const uint64_t Count = N - OffsetCount;
+ const uint64_t RawOffset = OffsetCount * sizeof(uint64_t);
+ const uint64_t RawSize = Count * sizeof(uint64_t);
+ uint64_t FirstBlockOffset = RawOffset % BlockSize;
+
+ SharedBuffer Uncompressed = Compressed.CopyRange(RawOffset, RawSize).Decompress();
+ std::span<uint64_t const> AllValues((const uint64_t*)Uncompressed.GetData(), RawSize / sizeof(uint64_t));
+ std::span<uint64_t const> Values((const uint64_t*)(((const uint8_t*)(Uncompressed.GetData()) + FirstBlockOffset)),
+ RawSize / sizeof(uint64_t));
+ CHECK(Values.size() == Count);
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ {
+ const uint64_t OffsetCount = 64 * 2 + 63;
+ const uint64_t Count = N - OffsetCount - 5;
+ const uint64_t RawOffset = OffsetCount * sizeof(uint64_t);
+ const uint64_t RawSize = Count * sizeof(uint64_t);
+ uint64_t FirstBlockOffset = RawOffset % BlockSize;
+
+ SharedBuffer Uncompressed = Compressed.CopyRange(RawOffset, RawSize).Decompress();
+ std::span<uint64_t const> AllValues((const uint64_t*)Uncompressed.GetData(), RawSize / sizeof(uint64_t));
+ std::span<uint64_t const> Values((const uint64_t*)(((const uint8_t*)(Uncompressed.GetData()) + FirstBlockOffset)),
+ RawSize / sizeof(uint64_t));
+ CHECK(Values.size() == Count);
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+ }
+
+ SUBCASE("copy uncompressed range")
+ {
+ const uint64_t N = 1000;
+ std::vector<uint64_t> ExpectedValues = GenerateData(N);
+
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(MakeMemoryView(ExpectedValues)),
+ OodleCompressor::NotSet,
+ OodleCompressionLevel::None);
+
+ {
+ const uint64_t OffsetCount = 0;
+ const uint64_t Count = N;
+ SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress();
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ CHECK(Values.size() == Count);
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ {
+ const uint64_t OffsetCount = 1;
+ const uint64_t Count = N - OffsetCount;
+ SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress();
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ CHECK(Values.size() == Count);
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+
+ {
+ const uint64_t OffsetCount = 42;
+ const uint64_t Count = 100;
+ SharedBuffer Uncompressed = Compressed.CopyRange(OffsetCount * sizeof(uint64_t), Count * sizeof(uint64_t)).Decompress();
+ std::span<uint64_t const> Values((const uint64_t*)Uncompressed.GetData(), Uncompressed.GetSize() / sizeof(uint64_t));
+ CHECK(Values.size() == Count);
+ ValidateData(Values, ExpectedValues, OffsetCount);
+ }
+ }
+}
+
+void
+compress_forcelink()
+{
+}
+#endif
+
+} // namespace zen
diff --git a/src/zencore/crc32.cpp b/src/zencore/crc32.cpp
new file mode 100644
index 000000000..d4a3cac57
--- /dev/null
+++ b/src/zencore/crc32.cpp
@@ -0,0 +1,545 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/crc32.h"
+
+namespace CRC32 {
+
+static const uint32_t CRCTable_DEPRECATED[256] = {
+ 0x00000000, 0x04C11DB7, 0x09823B6E, 0x0D4326D9, 0x130476DC, 0x17C56B6B, 0x1A864DB2, 0x1E475005, 0x2608EDB8, 0x22C9F00F, 0x2F8AD6D6,
+ 0x2B4BCB61, 0x350C9B64, 0x31CD86D3, 0x3C8EA00A, 0x384FBDBD, 0x4C11DB70, 0x48D0C6C7, 0x4593E01E, 0x4152FDA9, 0x5F15ADAC, 0x5BD4B01B,
+ 0x569796C2, 0x52568B75, 0x6A1936C8, 0x6ED82B7F, 0x639B0DA6, 0x675A1011, 0x791D4014, 0x7DDC5DA3, 0x709F7B7A, 0x745E66CD, 0x9823B6E0,
+ 0x9CE2AB57, 0x91A18D8E, 0x95609039, 0x8B27C03C, 0x8FE6DD8B, 0x82A5FB52, 0x8664E6E5, 0xBE2B5B58, 0xBAEA46EF, 0xB7A96036, 0xB3687D81,
+ 0xAD2F2D84, 0xA9EE3033, 0xA4AD16EA, 0xA06C0B5D, 0xD4326D90, 0xD0F37027, 0xDDB056FE, 0xD9714B49, 0xC7361B4C, 0xC3F706FB, 0xCEB42022,
+ 0xCA753D95, 0xF23A8028, 0xF6FB9D9F, 0xFBB8BB46, 0xFF79A6F1, 0xE13EF6F4, 0xE5FFEB43, 0xE8BCCD9A, 0xEC7DD02D, 0x34867077, 0x30476DC0,
+ 0x3D044B19, 0x39C556AE, 0x278206AB, 0x23431B1C, 0x2E003DC5, 0x2AC12072, 0x128E9DCF, 0x164F8078, 0x1B0CA6A1, 0x1FCDBB16, 0x018AEB13,
+ 0x054BF6A4, 0x0808D07D, 0x0CC9CDCA, 0x7897AB07, 0x7C56B6B0, 0x71159069, 0x75D48DDE, 0x6B93DDDB, 0x6F52C06C, 0x6211E6B5, 0x66D0FB02,
+ 0x5E9F46BF, 0x5A5E5B08, 0x571D7DD1, 0x53DC6066, 0x4D9B3063, 0x495A2DD4, 0x44190B0D, 0x40D816BA, 0xACA5C697, 0xA864DB20, 0xA527FDF9,
+ 0xA1E6E04E, 0xBFA1B04B, 0xBB60ADFC, 0xB6238B25, 0xB2E29692, 0x8AAD2B2F, 0x8E6C3698, 0x832F1041, 0x87EE0DF6, 0x99A95DF3, 0x9D684044,
+ 0x902B669D, 0x94EA7B2A, 0xE0B41DE7, 0xE4750050, 0xE9362689, 0xEDF73B3E, 0xF3B06B3B, 0xF771768C, 0xFA325055, 0xFEF34DE2, 0xC6BCF05F,
+ 0xC27DEDE8, 0xCF3ECB31, 0xCBFFD686, 0xD5B88683, 0xD1799B34, 0xDC3ABDED, 0xD8FBA05A, 0x690CE0EE, 0x6DCDFD59, 0x608EDB80, 0x644FC637,
+ 0x7A089632, 0x7EC98B85, 0x738AAD5C, 0x774BB0EB, 0x4F040D56, 0x4BC510E1, 0x46863638, 0x42472B8F, 0x5C007B8A, 0x58C1663D, 0x558240E4,
+ 0x51435D53, 0x251D3B9E, 0x21DC2629, 0x2C9F00F0, 0x285E1D47, 0x36194D42, 0x32D850F5, 0x3F9B762C, 0x3B5A6B9B, 0x0315D626, 0x07D4CB91,
+ 0x0A97ED48, 0x0E56F0FF, 0x1011A0FA, 0x14D0BD4D, 0x19939B94, 0x1D528623, 0xF12F560E, 0xF5EE4BB9, 0xF8AD6D60, 0xFC6C70D7, 0xE22B20D2,
+ 0xE6EA3D65, 0xEBA91BBC, 0xEF68060B, 0xD727BBB6, 0xD3E6A601, 0xDEA580D8, 0xDA649D6F, 0xC423CD6A, 0xC0E2D0DD, 0xCDA1F604, 0xC960EBB3,
+ 0xBD3E8D7E, 0xB9FF90C9, 0xB4BCB610, 0xB07DABA7, 0xAE3AFBA2, 0xAAFBE615, 0xA7B8C0CC, 0xA379DD7B, 0x9B3660C6, 0x9FF77D71, 0x92B45BA8,
+ 0x9675461F, 0x8832161A, 0x8CF30BAD, 0x81B02D74, 0x857130C3, 0x5D8A9099, 0x594B8D2E, 0x5408ABF7, 0x50C9B640, 0x4E8EE645, 0x4A4FFBF2,
+ 0x470CDD2B, 0x43CDC09C, 0x7B827D21, 0x7F436096, 0x7200464F, 0x76C15BF8, 0x68860BFD, 0x6C47164A, 0x61043093, 0x65C52D24, 0x119B4BE9,
+ 0x155A565E, 0x18197087, 0x1CD86D30, 0x029F3D35, 0x065E2082, 0x0B1D065B, 0x0FDC1BEC, 0x3793A651, 0x3352BBE6, 0x3E119D3F, 0x3AD08088,
+ 0x2497D08D, 0x2056CD3A, 0x2D15EBE3, 0x29D4F654, 0xC5A92679, 0xC1683BCE, 0xCC2B1D17, 0xC8EA00A0, 0xD6AD50A5, 0xD26C4D12, 0xDF2F6BCB,
+ 0xDBEE767C, 0xE3A1CBC1, 0xE760D676, 0xEA23F0AF, 0xEEE2ED18, 0xF0A5BD1D, 0xF464A0AA, 0xF9278673, 0xFDE69BC4, 0x89B8FD09, 0x8D79E0BE,
+ 0x803AC667, 0x84FBDBD0, 0x9ABC8BD5, 0x9E7D9662, 0x933EB0BB, 0x97FFAD0C, 0xAFB010B1, 0xAB710D06, 0xA6322BDF, 0xA2F33668, 0xBCB4666D,
+ 0xB8757BDA, 0xB5365D03, 0xB1F740B4};
+
+static const uint32_t CRCTablesSB8_DEPRECATED[8][256] = {
+ {0x00000000, 0xb71dc104, 0x6e3b8209, 0xd926430d, 0xdc760413, 0x6b6bc517, 0xb24d861a, 0x0550471e, 0xb8ed0826, 0x0ff0c922, 0xd6d68a2f,
+ 0x61cb4b2b, 0x649b0c35, 0xd386cd31, 0x0aa08e3c, 0xbdbd4f38, 0x70db114c, 0xc7c6d048, 0x1ee09345, 0xa9fd5241, 0xacad155f, 0x1bb0d45b,
+ 0xc2969756, 0x758b5652, 0xc836196a, 0x7f2bd86e, 0xa60d9b63, 0x11105a67, 0x14401d79, 0xa35ddc7d, 0x7a7b9f70, 0xcd665e74, 0xe0b62398,
+ 0x57abe29c, 0x8e8da191, 0x39906095, 0x3cc0278b, 0x8bdde68f, 0x52fba582, 0xe5e66486, 0x585b2bbe, 0xef46eaba, 0x3660a9b7, 0x817d68b3,
+ 0x842d2fad, 0x3330eea9, 0xea16ada4, 0x5d0b6ca0, 0x906d32d4, 0x2770f3d0, 0xfe56b0dd, 0x494b71d9, 0x4c1b36c7, 0xfb06f7c3, 0x2220b4ce,
+ 0x953d75ca, 0x28803af2, 0x9f9dfbf6, 0x46bbb8fb, 0xf1a679ff, 0xf4f63ee1, 0x43ebffe5, 0x9acdbce8, 0x2dd07dec, 0x77708634, 0xc06d4730,
+ 0x194b043d, 0xae56c539, 0xab068227, 0x1c1b4323, 0xc53d002e, 0x7220c12a, 0xcf9d8e12, 0x78804f16, 0xa1a60c1b, 0x16bbcd1f, 0x13eb8a01,
+ 0xa4f64b05, 0x7dd00808, 0xcacdc90c, 0x07ab9778, 0xb0b6567c, 0x69901571, 0xde8dd475, 0xdbdd936b, 0x6cc0526f, 0xb5e61162, 0x02fbd066,
+ 0xbf469f5e, 0x085b5e5a, 0xd17d1d57, 0x6660dc53, 0x63309b4d, 0xd42d5a49, 0x0d0b1944, 0xba16d840, 0x97c6a5ac, 0x20db64a8, 0xf9fd27a5,
+ 0x4ee0e6a1, 0x4bb0a1bf, 0xfcad60bb, 0x258b23b6, 0x9296e2b2, 0x2f2bad8a, 0x98366c8e, 0x41102f83, 0xf60dee87, 0xf35da999, 0x4440689d,
+ 0x9d662b90, 0x2a7bea94, 0xe71db4e0, 0x500075e4, 0x892636e9, 0x3e3bf7ed, 0x3b6bb0f3, 0x8c7671f7, 0x555032fa, 0xe24df3fe, 0x5ff0bcc6,
+ 0xe8ed7dc2, 0x31cb3ecf, 0x86d6ffcb, 0x8386b8d5, 0x349b79d1, 0xedbd3adc, 0x5aa0fbd8, 0xeee00c69, 0x59fdcd6d, 0x80db8e60, 0x37c64f64,
+ 0x3296087a, 0x858bc97e, 0x5cad8a73, 0xebb04b77, 0x560d044f, 0xe110c54b, 0x38368646, 0x8f2b4742, 0x8a7b005c, 0x3d66c158, 0xe4408255,
+ 0x535d4351, 0x9e3b1d25, 0x2926dc21, 0xf0009f2c, 0x471d5e28, 0x424d1936, 0xf550d832, 0x2c769b3f, 0x9b6b5a3b, 0x26d61503, 0x91cbd407,
+ 0x48ed970a, 0xfff0560e, 0xfaa01110, 0x4dbdd014, 0x949b9319, 0x2386521d, 0x0e562ff1, 0xb94beef5, 0x606dadf8, 0xd7706cfc, 0xd2202be2,
+ 0x653deae6, 0xbc1ba9eb, 0x0b0668ef, 0xb6bb27d7, 0x01a6e6d3, 0xd880a5de, 0x6f9d64da, 0x6acd23c4, 0xddd0e2c0, 0x04f6a1cd, 0xb3eb60c9,
+ 0x7e8d3ebd, 0xc990ffb9, 0x10b6bcb4, 0xa7ab7db0, 0xa2fb3aae, 0x15e6fbaa, 0xccc0b8a7, 0x7bdd79a3, 0xc660369b, 0x717df79f, 0xa85bb492,
+ 0x1f467596, 0x1a163288, 0xad0bf38c, 0x742db081, 0xc3307185, 0x99908a5d, 0x2e8d4b59, 0xf7ab0854, 0x40b6c950, 0x45e68e4e, 0xf2fb4f4a,
+ 0x2bdd0c47, 0x9cc0cd43, 0x217d827b, 0x9660437f, 0x4f460072, 0xf85bc176, 0xfd0b8668, 0x4a16476c, 0x93300461, 0x242dc565, 0xe94b9b11,
+ 0x5e565a15, 0x87701918, 0x306dd81c, 0x353d9f02, 0x82205e06, 0x5b061d0b, 0xec1bdc0f, 0x51a69337, 0xe6bb5233, 0x3f9d113e, 0x8880d03a,
+ 0x8dd09724, 0x3acd5620, 0xe3eb152d, 0x54f6d429, 0x7926a9c5, 0xce3b68c1, 0x171d2bcc, 0xa000eac8, 0xa550add6, 0x124d6cd2, 0xcb6b2fdf,
+ 0x7c76eedb, 0xc1cba1e3, 0x76d660e7, 0xaff023ea, 0x18ede2ee, 0x1dbda5f0, 0xaaa064f4, 0x738627f9, 0xc49be6fd, 0x09fdb889, 0xbee0798d,
+ 0x67c63a80, 0xd0dbfb84, 0xd58bbc9a, 0x62967d9e, 0xbbb03e93, 0x0cadff97, 0xb110b0af, 0x060d71ab, 0xdf2b32a6, 0x6836f3a2, 0x6d66b4bc,
+ 0xda7b75b8, 0x035d36b5, 0xb440f7b1},
+ {0x00000000, 0xdcc119d2, 0x0f9ef2a0, 0xd35feb72, 0xa9212445, 0x75e03d97, 0xa6bfd6e5, 0x7a7ecf37, 0x5243488a, 0x8e825158, 0x5dddba2a,
+ 0x811ca3f8, 0xfb626ccf, 0x27a3751d, 0xf4fc9e6f, 0x283d87bd, 0x139b5110, 0xcf5a48c2, 0x1c05a3b0, 0xc0c4ba62, 0xbaba7555, 0x667b6c87,
+ 0xb52487f5, 0x69e59e27, 0x41d8199a, 0x9d190048, 0x4e46eb3a, 0x9287f2e8, 0xe8f93ddf, 0x3438240d, 0xe767cf7f, 0x3ba6d6ad, 0x2636a320,
+ 0xfaf7baf2, 0x29a85180, 0xf5694852, 0x8f178765, 0x53d69eb7, 0x808975c5, 0x5c486c17, 0x7475ebaa, 0xa8b4f278, 0x7beb190a, 0xa72a00d8,
+ 0xdd54cfef, 0x0195d63d, 0xd2ca3d4f, 0x0e0b249d, 0x35adf230, 0xe96cebe2, 0x3a330090, 0xe6f21942, 0x9c8cd675, 0x404dcfa7, 0x931224d5,
+ 0x4fd33d07, 0x67eebaba, 0xbb2fa368, 0x6870481a, 0xb4b151c8, 0xcecf9eff, 0x120e872d, 0xc1516c5f, 0x1d90758d, 0x4c6c4641, 0x90ad5f93,
+ 0x43f2b4e1, 0x9f33ad33, 0xe54d6204, 0x398c7bd6, 0xead390a4, 0x36128976, 0x1e2f0ecb, 0xc2ee1719, 0x11b1fc6b, 0xcd70e5b9, 0xb70e2a8e,
+ 0x6bcf335c, 0xb890d82e, 0x6451c1fc, 0x5ff71751, 0x83360e83, 0x5069e5f1, 0x8ca8fc23, 0xf6d63314, 0x2a172ac6, 0xf948c1b4, 0x2589d866,
+ 0x0db45fdb, 0xd1754609, 0x022aad7b, 0xdeebb4a9, 0xa4957b9e, 0x7854624c, 0xab0b893e, 0x77ca90ec, 0x6a5ae561, 0xb69bfcb3, 0x65c417c1,
+ 0xb9050e13, 0xc37bc124, 0x1fbad8f6, 0xcce53384, 0x10242a56, 0x3819adeb, 0xe4d8b439, 0x37875f4b, 0xeb464699, 0x913889ae, 0x4df9907c,
+ 0x9ea67b0e, 0x426762dc, 0x79c1b471, 0xa500ada3, 0x765f46d1, 0xaa9e5f03, 0xd0e09034, 0x0c2189e6, 0xdf7e6294, 0x03bf7b46, 0x2b82fcfb,
+ 0xf743e529, 0x241c0e5b, 0xf8dd1789, 0x82a3d8be, 0x5e62c16c, 0x8d3d2a1e, 0x51fc33cc, 0x98d88c82, 0x44199550, 0x97467e22, 0x4b8767f0,
+ 0x31f9a8c7, 0xed38b115, 0x3e675a67, 0xe2a643b5, 0xca9bc408, 0x165addda, 0xc50536a8, 0x19c42f7a, 0x63bae04d, 0xbf7bf99f, 0x6c2412ed,
+ 0xb0e50b3f, 0x8b43dd92, 0x5782c440, 0x84dd2f32, 0x581c36e0, 0x2262f9d7, 0xfea3e005, 0x2dfc0b77, 0xf13d12a5, 0xd9009518, 0x05c18cca,
+ 0xd69e67b8, 0x0a5f7e6a, 0x7021b15d, 0xace0a88f, 0x7fbf43fd, 0xa37e5a2f, 0xbeee2fa2, 0x622f3670, 0xb170dd02, 0x6db1c4d0, 0x17cf0be7,
+ 0xcb0e1235, 0x1851f947, 0xc490e095, 0xecad6728, 0x306c7efa, 0xe3339588, 0x3ff28c5a, 0x458c436d, 0x994d5abf, 0x4a12b1cd, 0x96d3a81f,
+ 0xad757eb2, 0x71b46760, 0xa2eb8c12, 0x7e2a95c0, 0x04545af7, 0xd8954325, 0x0bcaa857, 0xd70bb185, 0xff363638, 0x23f72fea, 0xf0a8c498,
+ 0x2c69dd4a, 0x5617127d, 0x8ad60baf, 0x5989e0dd, 0x8548f90f, 0xd4b4cac3, 0x0875d311, 0xdb2a3863, 0x07eb21b1, 0x7d95ee86, 0xa154f754,
+ 0x720b1c26, 0xaeca05f4, 0x86f78249, 0x5a369b9b, 0x896970e9, 0x55a8693b, 0x2fd6a60c, 0xf317bfde, 0x204854ac, 0xfc894d7e, 0xc72f9bd3,
+ 0x1bee8201, 0xc8b16973, 0x147070a1, 0x6e0ebf96, 0xb2cfa644, 0x61904d36, 0xbd5154e4, 0x956cd359, 0x49adca8b, 0x9af221f9, 0x4633382b,
+ 0x3c4df71c, 0xe08ceece, 0x33d305bc, 0xef121c6e, 0xf28269e3, 0x2e437031, 0xfd1c9b43, 0x21dd8291, 0x5ba34da6, 0x87625474, 0x543dbf06,
+ 0x88fca6d4, 0xa0c12169, 0x7c0038bb, 0xaf5fd3c9, 0x739eca1b, 0x09e0052c, 0xd5211cfe, 0x067ef78c, 0xdabfee5e, 0xe11938f3, 0x3dd82121,
+ 0xee87ca53, 0x3246d381, 0x48381cb6, 0x94f90564, 0x47a6ee16, 0x9b67f7c4, 0xb35a7079, 0x6f9b69ab, 0xbcc482d9, 0x60059b0b, 0x1a7b543c,
+ 0xc6ba4dee, 0x15e5a69c, 0xc924bf4e},
+ {0x00000000, 0x87acd801, 0x0e59b103, 0x89f56902, 0x1cb26207, 0x9b1eba06, 0x12ebd304, 0x95470b05, 0x3864c50e, 0xbfc81d0f, 0x363d740d,
+ 0xb191ac0c, 0x24d6a709, 0xa37a7f08, 0x2a8f160a, 0xad23ce0b, 0x70c88a1d, 0xf764521c, 0x7e913b1e, 0xf93de31f, 0x6c7ae81a, 0xebd6301b,
+ 0x62235919, 0xe58f8118, 0x48ac4f13, 0xcf009712, 0x46f5fe10, 0xc1592611, 0x541e2d14, 0xd3b2f515, 0x5a479c17, 0xddeb4416, 0xe090153b,
+ 0x673ccd3a, 0xeec9a438, 0x69657c39, 0xfc22773c, 0x7b8eaf3d, 0xf27bc63f, 0x75d71e3e, 0xd8f4d035, 0x5f580834, 0xd6ad6136, 0x5101b937,
+ 0xc446b232, 0x43ea6a33, 0xca1f0331, 0x4db3db30, 0x90589f26, 0x17f44727, 0x9e012e25, 0x19adf624, 0x8ceafd21, 0x0b462520, 0x82b34c22,
+ 0x051f9423, 0xa83c5a28, 0x2f908229, 0xa665eb2b, 0x21c9332a, 0xb48e382f, 0x3322e02e, 0xbad7892c, 0x3d7b512d, 0xc0212b76, 0x478df377,
+ 0xce789a75, 0x49d44274, 0xdc934971, 0x5b3f9170, 0xd2caf872, 0x55662073, 0xf845ee78, 0x7fe93679, 0xf61c5f7b, 0x71b0877a, 0xe4f78c7f,
+ 0x635b547e, 0xeaae3d7c, 0x6d02e57d, 0xb0e9a16b, 0x3745796a, 0xbeb01068, 0x391cc869, 0xac5bc36c, 0x2bf71b6d, 0xa202726f, 0x25aeaa6e,
+ 0x888d6465, 0x0f21bc64, 0x86d4d566, 0x01780d67, 0x943f0662, 0x1393de63, 0x9a66b761, 0x1dca6f60, 0x20b13e4d, 0xa71de64c, 0x2ee88f4e,
+ 0xa944574f, 0x3c035c4a, 0xbbaf844b, 0x325aed49, 0xb5f63548, 0x18d5fb43, 0x9f792342, 0x168c4a40, 0x91209241, 0x04679944, 0x83cb4145,
+ 0x0a3e2847, 0x8d92f046, 0x5079b450, 0xd7d56c51, 0x5e200553, 0xd98cdd52, 0x4ccbd657, 0xcb670e56, 0x42926754, 0xc53ebf55, 0x681d715e,
+ 0xefb1a95f, 0x6644c05d, 0xe1e8185c, 0x74af1359, 0xf303cb58, 0x7af6a25a, 0xfd5a7a5b, 0x804356ec, 0x07ef8eed, 0x8e1ae7ef, 0x09b63fee,
+ 0x9cf134eb, 0x1b5decea, 0x92a885e8, 0x15045de9, 0xb82793e2, 0x3f8b4be3, 0xb67e22e1, 0x31d2fae0, 0xa495f1e5, 0x233929e4, 0xaacc40e6,
+ 0x2d6098e7, 0xf08bdcf1, 0x772704f0, 0xfed26df2, 0x797eb5f3, 0xec39bef6, 0x6b9566f7, 0xe2600ff5, 0x65ccd7f4, 0xc8ef19ff, 0x4f43c1fe,
+ 0xc6b6a8fc, 0x411a70fd, 0xd45d7bf8, 0x53f1a3f9, 0xda04cafb, 0x5da812fa, 0x60d343d7, 0xe77f9bd6, 0x6e8af2d4, 0xe9262ad5, 0x7c6121d0,
+ 0xfbcdf9d1, 0x723890d3, 0xf59448d2, 0x58b786d9, 0xdf1b5ed8, 0x56ee37da, 0xd142efdb, 0x4405e4de, 0xc3a93cdf, 0x4a5c55dd, 0xcdf08ddc,
+ 0x101bc9ca, 0x97b711cb, 0x1e4278c9, 0x99eea0c8, 0x0ca9abcd, 0x8b0573cc, 0x02f01ace, 0x855cc2cf, 0x287f0cc4, 0xafd3d4c5, 0x2626bdc7,
+ 0xa18a65c6, 0x34cd6ec3, 0xb361b6c2, 0x3a94dfc0, 0xbd3807c1, 0x40627d9a, 0xc7cea59b, 0x4e3bcc99, 0xc9971498, 0x5cd01f9d, 0xdb7cc79c,
+ 0x5289ae9e, 0xd525769f, 0x7806b894, 0xffaa6095, 0x765f0997, 0xf1f3d196, 0x64b4da93, 0xe3180292, 0x6aed6b90, 0xed41b391, 0x30aaf787,
+ 0xb7062f86, 0x3ef34684, 0xb95f9e85, 0x2c189580, 0xabb44d81, 0x22412483, 0xa5edfc82, 0x08ce3289, 0x8f62ea88, 0x0697838a, 0x813b5b8b,
+ 0x147c508e, 0x93d0888f, 0x1a25e18d, 0x9d89398c, 0xa0f268a1, 0x275eb0a0, 0xaeabd9a2, 0x290701a3, 0xbc400aa6, 0x3becd2a7, 0xb219bba5,
+ 0x35b563a4, 0x9896adaf, 0x1f3a75ae, 0x96cf1cac, 0x1163c4ad, 0x8424cfa8, 0x038817a9, 0x8a7d7eab, 0x0dd1a6aa, 0xd03ae2bc, 0x57963abd,
+ 0xde6353bf, 0x59cf8bbe, 0xcc8880bb, 0x4b2458ba, 0xc2d131b8, 0x457de9b9, 0xe85e27b2, 0x6ff2ffb3, 0xe60796b1, 0x61ab4eb0, 0xf4ec45b5,
+ 0x73409db4, 0xfab5f4b6, 0x7d192cb7},
+ {0x00000000, 0xb79a6ddc, 0xd9281abc, 0x6eb27760, 0x054cf57c, 0xb2d698a0, 0xdc64efc0, 0x6bfe821c, 0x0a98eaf9, 0xbd028725, 0xd3b0f045,
+ 0x642a9d99, 0x0fd41f85, 0xb84e7259, 0xd6fc0539, 0x616668e5, 0xa32d14f7, 0x14b7792b, 0x7a050e4b, 0xcd9f6397, 0xa661e18b, 0x11fb8c57,
+ 0x7f49fb37, 0xc8d396eb, 0xa9b5fe0e, 0x1e2f93d2, 0x709de4b2, 0xc707896e, 0xacf90b72, 0x1b6366ae, 0x75d111ce, 0xc24b7c12, 0xf146e9ea,
+ 0x46dc8436, 0x286ef356, 0x9ff49e8a, 0xf40a1c96, 0x4390714a, 0x2d22062a, 0x9ab86bf6, 0xfbde0313, 0x4c446ecf, 0x22f619af, 0x956c7473,
+ 0xfe92f66f, 0x49089bb3, 0x27baecd3, 0x9020810f, 0x526bfd1d, 0xe5f190c1, 0x8b43e7a1, 0x3cd98a7d, 0x57270861, 0xe0bd65bd, 0x8e0f12dd,
+ 0x39957f01, 0x58f317e4, 0xef697a38, 0x81db0d58, 0x36416084, 0x5dbfe298, 0xea258f44, 0x8497f824, 0x330d95f8, 0x559013d1, 0xe20a7e0d,
+ 0x8cb8096d, 0x3b2264b1, 0x50dce6ad, 0xe7468b71, 0x89f4fc11, 0x3e6e91cd, 0x5f08f928, 0xe89294f4, 0x8620e394, 0x31ba8e48, 0x5a440c54,
+ 0xedde6188, 0x836c16e8, 0x34f67b34, 0xf6bd0726, 0x41276afa, 0x2f951d9a, 0x980f7046, 0xf3f1f25a, 0x446b9f86, 0x2ad9e8e6, 0x9d43853a,
+ 0xfc25eddf, 0x4bbf8003, 0x250df763, 0x92979abf, 0xf96918a3, 0x4ef3757f, 0x2041021f, 0x97db6fc3, 0xa4d6fa3b, 0x134c97e7, 0x7dfee087,
+ 0xca648d5b, 0xa19a0f47, 0x1600629b, 0x78b215fb, 0xcf287827, 0xae4e10c2, 0x19d47d1e, 0x77660a7e, 0xc0fc67a2, 0xab02e5be, 0x1c988862,
+ 0x722aff02, 0xc5b092de, 0x07fbeecc, 0xb0618310, 0xded3f470, 0x694999ac, 0x02b71bb0, 0xb52d766c, 0xdb9f010c, 0x6c056cd0, 0x0d630435,
+ 0xbaf969e9, 0xd44b1e89, 0x63d17355, 0x082ff149, 0xbfb59c95, 0xd107ebf5, 0x669d8629, 0x1d3de6a6, 0xaaa78b7a, 0xc415fc1a, 0x738f91c6,
+ 0x187113da, 0xafeb7e06, 0xc1590966, 0x76c364ba, 0x17a50c5f, 0xa03f6183, 0xce8d16e3, 0x79177b3f, 0x12e9f923, 0xa57394ff, 0xcbc1e39f,
+ 0x7c5b8e43, 0xbe10f251, 0x098a9f8d, 0x6738e8ed, 0xd0a28531, 0xbb5c072d, 0x0cc66af1, 0x62741d91, 0xd5ee704d, 0xb48818a8, 0x03127574,
+ 0x6da00214, 0xda3a6fc8, 0xb1c4edd4, 0x065e8008, 0x68ecf768, 0xdf769ab4, 0xec7b0f4c, 0x5be16290, 0x355315f0, 0x82c9782c, 0xe937fa30,
+ 0x5ead97ec, 0x301fe08c, 0x87858d50, 0xe6e3e5b5, 0x51798869, 0x3fcbff09, 0x885192d5, 0xe3af10c9, 0x54357d15, 0x3a870a75, 0x8d1d67a9,
+ 0x4f561bbb, 0xf8cc7667, 0x967e0107, 0x21e46cdb, 0x4a1aeec7, 0xfd80831b, 0x9332f47b, 0x24a899a7, 0x45cef142, 0xf2549c9e, 0x9ce6ebfe,
+ 0x2b7c8622, 0x4082043e, 0xf71869e2, 0x99aa1e82, 0x2e30735e, 0x48adf577, 0xff3798ab, 0x9185efcb, 0x261f8217, 0x4de1000b, 0xfa7b6dd7,
+ 0x94c91ab7, 0x2353776b, 0x42351f8e, 0xf5af7252, 0x9b1d0532, 0x2c8768ee, 0x4779eaf2, 0xf0e3872e, 0x9e51f04e, 0x29cb9d92, 0xeb80e180,
+ 0x5c1a8c5c, 0x32a8fb3c, 0x853296e0, 0xeecc14fc, 0x59567920, 0x37e40e40, 0x807e639c, 0xe1180b79, 0x568266a5, 0x383011c5, 0x8faa7c19,
+ 0xe454fe05, 0x53ce93d9, 0x3d7ce4b9, 0x8ae68965, 0xb9eb1c9d, 0x0e717141, 0x60c30621, 0xd7596bfd, 0xbca7e9e1, 0x0b3d843d, 0x658ff35d,
+ 0xd2159e81, 0xb373f664, 0x04e99bb8, 0x6a5becd8, 0xddc18104, 0xb63f0318, 0x01a56ec4, 0x6f1719a4, 0xd88d7478, 0x1ac6086a, 0xad5c65b6,
+ 0xc3ee12d6, 0x74747f0a, 0x1f8afd16, 0xa81090ca, 0xc6a2e7aa, 0x71388a76, 0x105ee293, 0xa7c48f4f, 0xc976f82f, 0x7eec95f3, 0x151217ef,
+ 0xa2887a33, 0xcc3a0d53, 0x7ba0608f},
+ {0x00000000, 0x8d670d49, 0x1acf1a92, 0x97a817db, 0x8383f420, 0x0ee4f969, 0x994ceeb2, 0x142be3fb, 0x0607e941, 0x8b60e408, 0x1cc8f3d3,
+ 0x91affe9a, 0x85841d61, 0x08e31028, 0x9f4b07f3, 0x122c0aba, 0x0c0ed283, 0x8169dfca, 0x16c1c811, 0x9ba6c558, 0x8f8d26a3, 0x02ea2bea,
+ 0x95423c31, 0x18253178, 0x0a093bc2, 0x876e368b, 0x10c62150, 0x9da12c19, 0x898acfe2, 0x04edc2ab, 0x9345d570, 0x1e22d839, 0xaf016503,
+ 0x2266684a, 0xb5ce7f91, 0x38a972d8, 0x2c829123, 0xa1e59c6a, 0x364d8bb1, 0xbb2a86f8, 0xa9068c42, 0x2461810b, 0xb3c996d0, 0x3eae9b99,
+ 0x2a857862, 0xa7e2752b, 0x304a62f0, 0xbd2d6fb9, 0xa30fb780, 0x2e68bac9, 0xb9c0ad12, 0x34a7a05b, 0x208c43a0, 0xadeb4ee9, 0x3a435932,
+ 0xb724547b, 0xa5085ec1, 0x286f5388, 0xbfc74453, 0x32a0491a, 0x268baae1, 0xabeca7a8, 0x3c44b073, 0xb123bd3a, 0x5e03ca06, 0xd364c74f,
+ 0x44ccd094, 0xc9abdddd, 0xdd803e26, 0x50e7336f, 0xc74f24b4, 0x4a2829fd, 0x58042347, 0xd5632e0e, 0x42cb39d5, 0xcfac349c, 0xdb87d767,
+ 0x56e0da2e, 0xc148cdf5, 0x4c2fc0bc, 0x520d1885, 0xdf6a15cc, 0x48c20217, 0xc5a50f5e, 0xd18eeca5, 0x5ce9e1ec, 0xcb41f637, 0x4626fb7e,
+ 0x540af1c4, 0xd96dfc8d, 0x4ec5eb56, 0xc3a2e61f, 0xd78905e4, 0x5aee08ad, 0xcd461f76, 0x4021123f, 0xf102af05, 0x7c65a24c, 0xebcdb597,
+ 0x66aab8de, 0x72815b25, 0xffe6566c, 0x684e41b7, 0xe5294cfe, 0xf7054644, 0x7a624b0d, 0xedca5cd6, 0x60ad519f, 0x7486b264, 0xf9e1bf2d,
+ 0x6e49a8f6, 0xe32ea5bf, 0xfd0c7d86, 0x706b70cf, 0xe7c36714, 0x6aa46a5d, 0x7e8f89a6, 0xf3e884ef, 0x64409334, 0xe9279e7d, 0xfb0b94c7,
+ 0x766c998e, 0xe1c48e55, 0x6ca3831c, 0x788860e7, 0xf5ef6dae, 0x62477a75, 0xef20773c, 0xbc06940d, 0x31619944, 0xa6c98e9f, 0x2bae83d6,
+ 0x3f85602d, 0xb2e26d64, 0x254a7abf, 0xa82d77f6, 0xba017d4c, 0x37667005, 0xa0ce67de, 0x2da96a97, 0x3982896c, 0xb4e58425, 0x234d93fe,
+ 0xae2a9eb7, 0xb008468e, 0x3d6f4bc7, 0xaac75c1c, 0x27a05155, 0x338bb2ae, 0xbeecbfe7, 0x2944a83c, 0xa423a575, 0xb60fafcf, 0x3b68a286,
+ 0xacc0b55d, 0x21a7b814, 0x358c5bef, 0xb8eb56a6, 0x2f43417d, 0xa2244c34, 0x1307f10e, 0x9e60fc47, 0x09c8eb9c, 0x84afe6d5, 0x9084052e,
+ 0x1de30867, 0x8a4b1fbc, 0x072c12f5, 0x1500184f, 0x98671506, 0x0fcf02dd, 0x82a80f94, 0x9683ec6f, 0x1be4e126, 0x8c4cf6fd, 0x012bfbb4,
+ 0x1f09238d, 0x926e2ec4, 0x05c6391f, 0x88a13456, 0x9c8ad7ad, 0x11eddae4, 0x8645cd3f, 0x0b22c076, 0x190ecacc, 0x9469c785, 0x03c1d05e,
+ 0x8ea6dd17, 0x9a8d3eec, 0x17ea33a5, 0x8042247e, 0x0d252937, 0xe2055e0b, 0x6f625342, 0xf8ca4499, 0x75ad49d0, 0x6186aa2b, 0xece1a762,
+ 0x7b49b0b9, 0xf62ebdf0, 0xe402b74a, 0x6965ba03, 0xfecdadd8, 0x73aaa091, 0x6781436a, 0xeae64e23, 0x7d4e59f8, 0xf02954b1, 0xee0b8c88,
+ 0x636c81c1, 0xf4c4961a, 0x79a39b53, 0x6d8878a8, 0xe0ef75e1, 0x7747623a, 0xfa206f73, 0xe80c65c9, 0x656b6880, 0xf2c37f5b, 0x7fa47212,
+ 0x6b8f91e9, 0xe6e89ca0, 0x71408b7b, 0xfc278632, 0x4d043b08, 0xc0633641, 0x57cb219a, 0xdaac2cd3, 0xce87cf28, 0x43e0c261, 0xd448d5ba,
+ 0x592fd8f3, 0x4b03d249, 0xc664df00, 0x51ccc8db, 0xdcabc592, 0xc8802669, 0x45e72b20, 0xd24f3cfb, 0x5f2831b2, 0x410ae98b, 0xcc6de4c2,
+ 0x5bc5f319, 0xd6a2fe50, 0xc2891dab, 0x4fee10e2, 0xd8460739, 0x55210a70, 0x470d00ca, 0xca6a0d83, 0x5dc21a58, 0xd0a51711, 0xc48ef4ea,
+ 0x49e9f9a3, 0xde41ee78, 0x5326e331},
+ {0x00000000, 0x780d281b, 0xf01a5036, 0x8817782d, 0xe035a06c, 0x98388877, 0x102ff05a, 0x6822d841, 0xc06b40d9, 0xb86668c2, 0x307110ef,
+ 0x487c38f4, 0x205ee0b5, 0x5853c8ae, 0xd044b083, 0xa8499898, 0x37ca41b6, 0x4fc769ad, 0xc7d01180, 0xbfdd399b, 0xd7ffe1da, 0xaff2c9c1,
+ 0x27e5b1ec, 0x5fe899f7, 0xf7a1016f, 0x8fac2974, 0x07bb5159, 0x7fb67942, 0x1794a103, 0x6f998918, 0xe78ef135, 0x9f83d92e, 0xd9894268,
+ 0xa1846a73, 0x2993125e, 0x519e3a45, 0x39bce204, 0x41b1ca1f, 0xc9a6b232, 0xb1ab9a29, 0x19e202b1, 0x61ef2aaa, 0xe9f85287, 0x91f57a9c,
+ 0xf9d7a2dd, 0x81da8ac6, 0x09cdf2eb, 0x71c0daf0, 0xee4303de, 0x964e2bc5, 0x1e5953e8, 0x66547bf3, 0x0e76a3b2, 0x767b8ba9, 0xfe6cf384,
+ 0x8661db9f, 0x2e284307, 0x56256b1c, 0xde321331, 0xa63f3b2a, 0xce1de36b, 0xb610cb70, 0x3e07b35d, 0x460a9b46, 0xb21385d0, 0xca1eadcb,
+ 0x4209d5e6, 0x3a04fdfd, 0x522625bc, 0x2a2b0da7, 0xa23c758a, 0xda315d91, 0x7278c509, 0x0a75ed12, 0x8262953f, 0xfa6fbd24, 0x924d6565,
+ 0xea404d7e, 0x62573553, 0x1a5a1d48, 0x85d9c466, 0xfdd4ec7d, 0x75c39450, 0x0dcebc4b, 0x65ec640a, 0x1de14c11, 0x95f6343c, 0xedfb1c27,
+ 0x45b284bf, 0x3dbfaca4, 0xb5a8d489, 0xcda5fc92, 0xa58724d3, 0xdd8a0cc8, 0x559d74e5, 0x2d905cfe, 0x6b9ac7b8, 0x1397efa3, 0x9b80978e,
+ 0xe38dbf95, 0x8baf67d4, 0xf3a24fcf, 0x7bb537e2, 0x03b81ff9, 0xabf18761, 0xd3fcaf7a, 0x5bebd757, 0x23e6ff4c, 0x4bc4270d, 0x33c90f16,
+ 0xbbde773b, 0xc3d35f20, 0x5c50860e, 0x245dae15, 0xac4ad638, 0xd447fe23, 0xbc652662, 0xc4680e79, 0x4c7f7654, 0x34725e4f, 0x9c3bc6d7,
+ 0xe436eecc, 0x6c2196e1, 0x142cbefa, 0x7c0e66bb, 0x04034ea0, 0x8c14368d, 0xf4191e96, 0xd33acba5, 0xab37e3be, 0x23209b93, 0x5b2db388,
+ 0x330f6bc9, 0x4b0243d2, 0xc3153bff, 0xbb1813e4, 0x13518b7c, 0x6b5ca367, 0xe34bdb4a, 0x9b46f351, 0xf3642b10, 0x8b69030b, 0x037e7b26,
+ 0x7b73533d, 0xe4f08a13, 0x9cfda208, 0x14eada25, 0x6ce7f23e, 0x04c52a7f, 0x7cc80264, 0xf4df7a49, 0x8cd25252, 0x249bcaca, 0x5c96e2d1,
+ 0xd4819afc, 0xac8cb2e7, 0xc4ae6aa6, 0xbca342bd, 0x34b43a90, 0x4cb9128b, 0x0ab389cd, 0x72bea1d6, 0xfaa9d9fb, 0x82a4f1e0, 0xea8629a1,
+ 0x928b01ba, 0x1a9c7997, 0x6291518c, 0xcad8c914, 0xb2d5e10f, 0x3ac29922, 0x42cfb139, 0x2aed6978, 0x52e04163, 0xdaf7394e, 0xa2fa1155,
+ 0x3d79c87b, 0x4574e060, 0xcd63984d, 0xb56eb056, 0xdd4c6817, 0xa541400c, 0x2d563821, 0x555b103a, 0xfd1288a2, 0x851fa0b9, 0x0d08d894,
+ 0x7505f08f, 0x1d2728ce, 0x652a00d5, 0xed3d78f8, 0x953050e3, 0x61294e75, 0x1924666e, 0x91331e43, 0xe93e3658, 0x811cee19, 0xf911c602,
+ 0x7106be2f, 0x090b9634, 0xa1420eac, 0xd94f26b7, 0x51585e9a, 0x29557681, 0x4177aec0, 0x397a86db, 0xb16dfef6, 0xc960d6ed, 0x56e30fc3,
+ 0x2eee27d8, 0xa6f95ff5, 0xdef477ee, 0xb6d6afaf, 0xcedb87b4, 0x46ccff99, 0x3ec1d782, 0x96884f1a, 0xee856701, 0x66921f2c, 0x1e9f3737,
+ 0x76bdef76, 0x0eb0c76d, 0x86a7bf40, 0xfeaa975b, 0xb8a00c1d, 0xc0ad2406, 0x48ba5c2b, 0x30b77430, 0x5895ac71, 0x2098846a, 0xa88ffc47,
+ 0xd082d45c, 0x78cb4cc4, 0x00c664df, 0x88d11cf2, 0xf0dc34e9, 0x98feeca8, 0xe0f3c4b3, 0x68e4bc9e, 0x10e99485, 0x8f6a4dab, 0xf76765b0,
+ 0x7f701d9d, 0x077d3586, 0x6f5fedc7, 0x1752c5dc, 0x9f45bdf1, 0xe74895ea, 0x4f010d72, 0x370c2569, 0xbf1b5d44, 0xc716755f, 0xaf34ad1e,
+ 0xd7398505, 0x5f2efd28, 0x2723d533},
+ {0x00000000, 0x1168574f, 0x22d0ae9e, 0x33b8f9d1, 0xf3bd9c39, 0xe2d5cb76, 0xd16d32a7, 0xc00565e8, 0xe67b3973, 0xf7136e3c, 0xc4ab97ed,
+ 0xd5c3c0a2, 0x15c6a54a, 0x04aef205, 0x37160bd4, 0x267e5c9b, 0xccf772e6, 0xdd9f25a9, 0xee27dc78, 0xff4f8b37, 0x3f4aeedf, 0x2e22b990,
+ 0x1d9a4041, 0x0cf2170e, 0x2a8c4b95, 0x3be41cda, 0x085ce50b, 0x1934b244, 0xd931d7ac, 0xc85980e3, 0xfbe17932, 0xea892e7d, 0x2ff224c8,
+ 0x3e9a7387, 0x0d228a56, 0x1c4add19, 0xdc4fb8f1, 0xcd27efbe, 0xfe9f166f, 0xeff74120, 0xc9891dbb, 0xd8e14af4, 0xeb59b325, 0xfa31e46a,
+ 0x3a348182, 0x2b5cd6cd, 0x18e42f1c, 0x098c7853, 0xe305562e, 0xf26d0161, 0xc1d5f8b0, 0xd0bdafff, 0x10b8ca17, 0x01d09d58, 0x32686489,
+ 0x230033c6, 0x057e6f5d, 0x14163812, 0x27aec1c3, 0x36c6968c, 0xf6c3f364, 0xe7aba42b, 0xd4135dfa, 0xc57b0ab5, 0xe9f98894, 0xf891dfdb,
+ 0xcb29260a, 0xda417145, 0x1a4414ad, 0x0b2c43e2, 0x3894ba33, 0x29fced7c, 0x0f82b1e7, 0x1eeae6a8, 0x2d521f79, 0x3c3a4836, 0xfc3f2dde,
+ 0xed577a91, 0xdeef8340, 0xcf87d40f, 0x250efa72, 0x3466ad3d, 0x07de54ec, 0x16b603a3, 0xd6b3664b, 0xc7db3104, 0xf463c8d5, 0xe50b9f9a,
+ 0xc375c301, 0xd21d944e, 0xe1a56d9f, 0xf0cd3ad0, 0x30c85f38, 0x21a00877, 0x1218f1a6, 0x0370a6e9, 0xc60bac5c, 0xd763fb13, 0xe4db02c2,
+ 0xf5b3558d, 0x35b63065, 0x24de672a, 0x17669efb, 0x060ec9b4, 0x2070952f, 0x3118c260, 0x02a03bb1, 0x13c86cfe, 0xd3cd0916, 0xc2a55e59,
+ 0xf11da788, 0xe075f0c7, 0x0afcdeba, 0x1b9489f5, 0x282c7024, 0x3944276b, 0xf9414283, 0xe82915cc, 0xdb91ec1d, 0xcaf9bb52, 0xec87e7c9,
+ 0xfdefb086, 0xce574957, 0xdf3f1e18, 0x1f3a7bf0, 0x0e522cbf, 0x3dead56e, 0x2c828221, 0x65eed02d, 0x74868762, 0x473e7eb3, 0x565629fc,
+ 0x96534c14, 0x873b1b5b, 0xb483e28a, 0xa5ebb5c5, 0x8395e95e, 0x92fdbe11, 0xa14547c0, 0xb02d108f, 0x70287567, 0x61402228, 0x52f8dbf9,
+ 0x43908cb6, 0xa919a2cb, 0xb871f584, 0x8bc90c55, 0x9aa15b1a, 0x5aa43ef2, 0x4bcc69bd, 0x7874906c, 0x691cc723, 0x4f629bb8, 0x5e0accf7,
+ 0x6db23526, 0x7cda6269, 0xbcdf0781, 0xadb750ce, 0x9e0fa91f, 0x8f67fe50, 0x4a1cf4e5, 0x5b74a3aa, 0x68cc5a7b, 0x79a40d34, 0xb9a168dc,
+ 0xa8c93f93, 0x9b71c642, 0x8a19910d, 0xac67cd96, 0xbd0f9ad9, 0x8eb76308, 0x9fdf3447, 0x5fda51af, 0x4eb206e0, 0x7d0aff31, 0x6c62a87e,
+ 0x86eb8603, 0x9783d14c, 0xa43b289d, 0xb5537fd2, 0x75561a3a, 0x643e4d75, 0x5786b4a4, 0x46eee3eb, 0x6090bf70, 0x71f8e83f, 0x424011ee,
+ 0x532846a1, 0x932d2349, 0x82457406, 0xb1fd8dd7, 0xa095da98, 0x8c1758b9, 0x9d7f0ff6, 0xaec7f627, 0xbfafa168, 0x7faac480, 0x6ec293cf,
+ 0x5d7a6a1e, 0x4c123d51, 0x6a6c61ca, 0x7b043685, 0x48bccf54, 0x59d4981b, 0x99d1fdf3, 0x88b9aabc, 0xbb01536d, 0xaa690422, 0x40e02a5f,
+ 0x51887d10, 0x623084c1, 0x7358d38e, 0xb35db666, 0xa235e129, 0x918d18f8, 0x80e54fb7, 0xa69b132c, 0xb7f34463, 0x844bbdb2, 0x9523eafd,
+ 0x55268f15, 0x444ed85a, 0x77f6218b, 0x669e76c4, 0xa3e57c71, 0xb28d2b3e, 0x8135d2ef, 0x905d85a0, 0x5058e048, 0x4130b707, 0x72884ed6,
+ 0x63e01999, 0x459e4502, 0x54f6124d, 0x674eeb9c, 0x7626bcd3, 0xb623d93b, 0xa74b8e74, 0x94f377a5, 0x859b20ea, 0x6f120e97, 0x7e7a59d8,
+ 0x4dc2a009, 0x5caaf746, 0x9caf92ae, 0x8dc7c5e1, 0xbe7f3c30, 0xaf176b7f, 0x896937e4, 0x980160ab, 0xabb9997a, 0xbad1ce35, 0x7ad4abdd,
+ 0x6bbcfc92, 0x58040543, 0x496c520c},
+ {0x00000000, 0xcadca15b, 0x94b943b7, 0x5e65e2ec, 0x9f6e466a, 0x55b2e731, 0x0bd705dd, 0xc10ba486, 0x3edd8cd4, 0xf4012d8f, 0xaa64cf63,
+ 0x60b86e38, 0xa1b3cabe, 0x6b6f6be5, 0x350a8909, 0xffd62852, 0xcba7d8ad, 0x017b79f6, 0x5f1e9b1a, 0x95c23a41, 0x54c99ec7, 0x9e153f9c,
+ 0xc070dd70, 0x0aac7c2b, 0xf57a5479, 0x3fa6f522, 0x61c317ce, 0xab1fb695, 0x6a141213, 0xa0c8b348, 0xfead51a4, 0x3471f0ff, 0x2152705f,
+ 0xeb8ed104, 0xb5eb33e8, 0x7f3792b3, 0xbe3c3635, 0x74e0976e, 0x2a857582, 0xe059d4d9, 0x1f8ffc8b, 0xd5535dd0, 0x8b36bf3c, 0x41ea1e67,
+ 0x80e1bae1, 0x4a3d1bba, 0x1458f956, 0xde84580d, 0xeaf5a8f2, 0x202909a9, 0x7e4ceb45, 0xb4904a1e, 0x759bee98, 0xbf474fc3, 0xe122ad2f,
+ 0x2bfe0c74, 0xd4282426, 0x1ef4857d, 0x40916791, 0x8a4dc6ca, 0x4b46624c, 0x819ac317, 0xdfff21fb, 0x152380a0, 0x42a4e0be, 0x887841e5,
+ 0xd61da309, 0x1cc10252, 0xddcaa6d4, 0x1716078f, 0x4973e563, 0x83af4438, 0x7c796c6a, 0xb6a5cd31, 0xe8c02fdd, 0x221c8e86, 0xe3172a00,
+ 0x29cb8b5b, 0x77ae69b7, 0xbd72c8ec, 0x89033813, 0x43df9948, 0x1dba7ba4, 0xd766daff, 0x166d7e79, 0xdcb1df22, 0x82d43dce, 0x48089c95,
+ 0xb7deb4c7, 0x7d02159c, 0x2367f770, 0xe9bb562b, 0x28b0f2ad, 0xe26c53f6, 0xbc09b11a, 0x76d51041, 0x63f690e1, 0xa92a31ba, 0xf74fd356,
+ 0x3d93720d, 0xfc98d68b, 0x364477d0, 0x6821953c, 0xa2fd3467, 0x5d2b1c35, 0x97f7bd6e, 0xc9925f82, 0x034efed9, 0xc2455a5f, 0x0899fb04,
+ 0x56fc19e8, 0x9c20b8b3, 0xa851484c, 0x628de917, 0x3ce80bfb, 0xf634aaa0, 0x373f0e26, 0xfde3af7d, 0xa3864d91, 0x695aecca, 0x968cc498,
+ 0x5c5065c3, 0x0235872f, 0xc8e92674, 0x09e282f2, 0xc33e23a9, 0x9d5bc145, 0x5787601e, 0x33550079, 0xf989a122, 0xa7ec43ce, 0x6d30e295,
+ 0xac3b4613, 0x66e7e748, 0x388205a4, 0xf25ea4ff, 0x0d888cad, 0xc7542df6, 0x9931cf1a, 0x53ed6e41, 0x92e6cac7, 0x583a6b9c, 0x065f8970,
+ 0xcc83282b, 0xf8f2d8d4, 0x322e798f, 0x6c4b9b63, 0xa6973a38, 0x679c9ebe, 0xad403fe5, 0xf325dd09, 0x39f97c52, 0xc62f5400, 0x0cf3f55b,
+ 0x529617b7, 0x984ab6ec, 0x5941126a, 0x939db331, 0xcdf851dd, 0x0724f086, 0x12077026, 0xd8dbd17d, 0x86be3391, 0x4c6292ca, 0x8d69364c,
+ 0x47b59717, 0x19d075fb, 0xd30cd4a0, 0x2cdafcf2, 0xe6065da9, 0xb863bf45, 0x72bf1e1e, 0xb3b4ba98, 0x79681bc3, 0x270df92f, 0xedd15874,
+ 0xd9a0a88b, 0x137c09d0, 0x4d19eb3c, 0x87c54a67, 0x46ceeee1, 0x8c124fba, 0xd277ad56, 0x18ab0c0d, 0xe77d245f, 0x2da18504, 0x73c467e8,
+ 0xb918c6b3, 0x78136235, 0xb2cfc36e, 0xecaa2182, 0x267680d9, 0x71f1e0c7, 0xbb2d419c, 0xe548a370, 0x2f94022b, 0xee9fa6ad, 0x244307f6,
+ 0x7a26e51a, 0xb0fa4441, 0x4f2c6c13, 0x85f0cd48, 0xdb952fa4, 0x11498eff, 0xd0422a79, 0x1a9e8b22, 0x44fb69ce, 0x8e27c895, 0xba56386a,
+ 0x708a9931, 0x2eef7bdd, 0xe433da86, 0x25387e00, 0xefe4df5b, 0xb1813db7, 0x7b5d9cec, 0x848bb4be, 0x4e5715e5, 0x1032f709, 0xdaee5652,
+ 0x1be5f2d4, 0xd139538f, 0x8f5cb163, 0x45801038, 0x50a39098, 0x9a7f31c3, 0xc41ad32f, 0x0ec67274, 0xcfcdd6f2, 0x051177a9, 0x5b749545,
+ 0x91a8341e, 0x6e7e1c4c, 0xa4a2bd17, 0xfac75ffb, 0x301bfea0, 0xf1105a26, 0x3bccfb7d, 0x65a91991, 0xaf75b8ca, 0x9b044835, 0x51d8e96e,
+ 0x0fbd0b82, 0xc561aad9, 0x046a0e5f, 0xceb6af04, 0x90d34de8, 0x5a0fecb3, 0xa5d9c4e1, 0x6f0565ba, 0x31608756, 0xfbbc260d, 0x3ab7828b,
+ 0xf06b23d0, 0xae0ec13c, 0x64d26067}};
+
+static const uint32_t CRCTablesSB8[8][256] = {
+ {0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e,
+ 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb,
+ 0xf4d4b551, 0x83d385c7, 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, 0x3b6e20c8,
+ 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940,
+ 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599,
+ 0xb8bda50f, 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, 0x76dc4190, 0x01db7106,
+ 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb,
+ 0x086d3d2d, 0x91646c97, 0xe6635c01, 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457,
+ 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, 0x4db26158, 0x3ab551ce, 0xa3bc0074,
+ 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5,
+ 0xaa0a4c5f, 0xdd0d7cc9, 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, 0x5edef90e,
+ 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a,
+ 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, 0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27,
+ 0x7d079eb1, 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, 0xfed41b76, 0x89d32be0,
+ 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1,
+ 0xa6bc5767, 0x3fb506dd, 0x48b2364b, 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79,
+ 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92,
+ 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f,
+ 0x72076785, 0x05005713, 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, 0x86d3d2d4,
+ 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c,
+ 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d,
+ 0x3e6e77db, 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, 0xbdbdf21c, 0xcabac28a,
+ 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37,
+ 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d},
+ {0x00000000, 0x191b3141, 0x32366282, 0x2b2d53c3, 0x646cc504, 0x7d77f445, 0x565aa786, 0x4f4196c7, 0xc8d98a08, 0xd1c2bb49, 0xfaefe88a,
+ 0xe3f4d9cb, 0xacb54f0c, 0xb5ae7e4d, 0x9e832d8e, 0x87981ccf, 0x4ac21251, 0x53d92310, 0x78f470d3, 0x61ef4192, 0x2eaed755, 0x37b5e614,
+ 0x1c98b5d7, 0x05838496, 0x821b9859, 0x9b00a918, 0xb02dfadb, 0xa936cb9a, 0xe6775d5d, 0xff6c6c1c, 0xd4413fdf, 0xcd5a0e9e, 0x958424a2,
+ 0x8c9f15e3, 0xa7b24620, 0xbea97761, 0xf1e8e1a6, 0xe8f3d0e7, 0xc3de8324, 0xdac5b265, 0x5d5daeaa, 0x44469feb, 0x6f6bcc28, 0x7670fd69,
+ 0x39316bae, 0x202a5aef, 0x0b07092c, 0x121c386d, 0xdf4636f3, 0xc65d07b2, 0xed705471, 0xf46b6530, 0xbb2af3f7, 0xa231c2b6, 0x891c9175,
+ 0x9007a034, 0x179fbcfb, 0x0e848dba, 0x25a9de79, 0x3cb2ef38, 0x73f379ff, 0x6ae848be, 0x41c51b7d, 0x58de2a3c, 0xf0794f05, 0xe9627e44,
+ 0xc24f2d87, 0xdb541cc6, 0x94158a01, 0x8d0ebb40, 0xa623e883, 0xbf38d9c2, 0x38a0c50d, 0x21bbf44c, 0x0a96a78f, 0x138d96ce, 0x5ccc0009,
+ 0x45d73148, 0x6efa628b, 0x77e153ca, 0xbabb5d54, 0xa3a06c15, 0x888d3fd6, 0x91960e97, 0xded79850, 0xc7cca911, 0xece1fad2, 0xf5facb93,
+ 0x7262d75c, 0x6b79e61d, 0x4054b5de, 0x594f849f, 0x160e1258, 0x0f152319, 0x243870da, 0x3d23419b, 0x65fd6ba7, 0x7ce65ae6, 0x57cb0925,
+ 0x4ed03864, 0x0191aea3, 0x188a9fe2, 0x33a7cc21, 0x2abcfd60, 0xad24e1af, 0xb43fd0ee, 0x9f12832d, 0x8609b26c, 0xc94824ab, 0xd05315ea,
+ 0xfb7e4629, 0xe2657768, 0x2f3f79f6, 0x362448b7, 0x1d091b74, 0x04122a35, 0x4b53bcf2, 0x52488db3, 0x7965de70, 0x607eef31, 0xe7e6f3fe,
+ 0xfefdc2bf, 0xd5d0917c, 0xcccba03d, 0x838a36fa, 0x9a9107bb, 0xb1bc5478, 0xa8a76539, 0x3b83984b, 0x2298a90a, 0x09b5fac9, 0x10aecb88,
+ 0x5fef5d4f, 0x46f46c0e, 0x6dd93fcd, 0x74c20e8c, 0xf35a1243, 0xea412302, 0xc16c70c1, 0xd8774180, 0x9736d747, 0x8e2de606, 0xa500b5c5,
+ 0xbc1b8484, 0x71418a1a, 0x685abb5b, 0x4377e898, 0x5a6cd9d9, 0x152d4f1e, 0x0c367e5f, 0x271b2d9c, 0x3e001cdd, 0xb9980012, 0xa0833153,
+ 0x8bae6290, 0x92b553d1, 0xddf4c516, 0xc4eff457, 0xefc2a794, 0xf6d996d5, 0xae07bce9, 0xb71c8da8, 0x9c31de6b, 0x852aef2a, 0xca6b79ed,
+ 0xd37048ac, 0xf85d1b6f, 0xe1462a2e, 0x66de36e1, 0x7fc507a0, 0x54e85463, 0x4df36522, 0x02b2f3e5, 0x1ba9c2a4, 0x30849167, 0x299fa026,
+ 0xe4c5aeb8, 0xfdde9ff9, 0xd6f3cc3a, 0xcfe8fd7b, 0x80a96bbc, 0x99b25afd, 0xb29f093e, 0xab84387f, 0x2c1c24b0, 0x350715f1, 0x1e2a4632,
+ 0x07317773, 0x4870e1b4, 0x516bd0f5, 0x7a468336, 0x635db277, 0xcbfad74e, 0xd2e1e60f, 0xf9ccb5cc, 0xe0d7848d, 0xaf96124a, 0xb68d230b,
+ 0x9da070c8, 0x84bb4189, 0x03235d46, 0x1a386c07, 0x31153fc4, 0x280e0e85, 0x674f9842, 0x7e54a903, 0x5579fac0, 0x4c62cb81, 0x8138c51f,
+ 0x9823f45e, 0xb30ea79d, 0xaa1596dc, 0xe554001b, 0xfc4f315a, 0xd7626299, 0xce7953d8, 0x49e14f17, 0x50fa7e56, 0x7bd72d95, 0x62cc1cd4,
+ 0x2d8d8a13, 0x3496bb52, 0x1fbbe891, 0x06a0d9d0, 0x5e7ef3ec, 0x4765c2ad, 0x6c48916e, 0x7553a02f, 0x3a1236e8, 0x230907a9, 0x0824546a,
+ 0x113f652b, 0x96a779e4, 0x8fbc48a5, 0xa4911b66, 0xbd8a2a27, 0xf2cbbce0, 0xebd08da1, 0xc0fdde62, 0xd9e6ef23, 0x14bce1bd, 0x0da7d0fc,
+ 0x268a833f, 0x3f91b27e, 0x70d024b9, 0x69cb15f8, 0x42e6463b, 0x5bfd777a, 0xdc656bb5, 0xc57e5af4, 0xee530937, 0xf7483876, 0xb809aeb1,
+ 0xa1129ff0, 0x8a3fcc33, 0x9324fd72},
+ {0x00000000, 0x01c26a37, 0x0384d46e, 0x0246be59, 0x0709a8dc, 0x06cbc2eb, 0x048d7cb2, 0x054f1685, 0x0e1351b8, 0x0fd13b8f, 0x0d9785d6,
+ 0x0c55efe1, 0x091af964, 0x08d89353, 0x0a9e2d0a, 0x0b5c473d, 0x1c26a370, 0x1de4c947, 0x1fa2771e, 0x1e601d29, 0x1b2f0bac, 0x1aed619b,
+ 0x18abdfc2, 0x1969b5f5, 0x1235f2c8, 0x13f798ff, 0x11b126a6, 0x10734c91, 0x153c5a14, 0x14fe3023, 0x16b88e7a, 0x177ae44d, 0x384d46e0,
+ 0x398f2cd7, 0x3bc9928e, 0x3a0bf8b9, 0x3f44ee3c, 0x3e86840b, 0x3cc03a52, 0x3d025065, 0x365e1758, 0x379c7d6f, 0x35dac336, 0x3418a901,
+ 0x3157bf84, 0x3095d5b3, 0x32d36bea, 0x331101dd, 0x246be590, 0x25a98fa7, 0x27ef31fe, 0x262d5bc9, 0x23624d4c, 0x22a0277b, 0x20e69922,
+ 0x2124f315, 0x2a78b428, 0x2bbade1f, 0x29fc6046, 0x283e0a71, 0x2d711cf4, 0x2cb376c3, 0x2ef5c89a, 0x2f37a2ad, 0x709a8dc0, 0x7158e7f7,
+ 0x731e59ae, 0x72dc3399, 0x7793251c, 0x76514f2b, 0x7417f172, 0x75d59b45, 0x7e89dc78, 0x7f4bb64f, 0x7d0d0816, 0x7ccf6221, 0x798074a4,
+ 0x78421e93, 0x7a04a0ca, 0x7bc6cafd, 0x6cbc2eb0, 0x6d7e4487, 0x6f38fade, 0x6efa90e9, 0x6bb5866c, 0x6a77ec5b, 0x68315202, 0x69f33835,
+ 0x62af7f08, 0x636d153f, 0x612bab66, 0x60e9c151, 0x65a6d7d4, 0x6464bde3, 0x662203ba, 0x67e0698d, 0x48d7cb20, 0x4915a117, 0x4b531f4e,
+ 0x4a917579, 0x4fde63fc, 0x4e1c09cb, 0x4c5ab792, 0x4d98dda5, 0x46c49a98, 0x4706f0af, 0x45404ef6, 0x448224c1, 0x41cd3244, 0x400f5873,
+ 0x4249e62a, 0x438b8c1d, 0x54f16850, 0x55330267, 0x5775bc3e, 0x56b7d609, 0x53f8c08c, 0x523aaabb, 0x507c14e2, 0x51be7ed5, 0x5ae239e8,
+ 0x5b2053df, 0x5966ed86, 0x58a487b1, 0x5deb9134, 0x5c29fb03, 0x5e6f455a, 0x5fad2f6d, 0xe1351b80, 0xe0f771b7, 0xe2b1cfee, 0xe373a5d9,
+ 0xe63cb35c, 0xe7fed96b, 0xe5b86732, 0xe47a0d05, 0xef264a38, 0xeee4200f, 0xeca29e56, 0xed60f461, 0xe82fe2e4, 0xe9ed88d3, 0xebab368a,
+ 0xea695cbd, 0xfd13b8f0, 0xfcd1d2c7, 0xfe976c9e, 0xff5506a9, 0xfa1a102c, 0xfbd87a1b, 0xf99ec442, 0xf85cae75, 0xf300e948, 0xf2c2837f,
+ 0xf0843d26, 0xf1465711, 0xf4094194, 0xf5cb2ba3, 0xf78d95fa, 0xf64fffcd, 0xd9785d60, 0xd8ba3757, 0xdafc890e, 0xdb3ee339, 0xde71f5bc,
+ 0xdfb39f8b, 0xddf521d2, 0xdc374be5, 0xd76b0cd8, 0xd6a966ef, 0xd4efd8b6, 0xd52db281, 0xd062a404, 0xd1a0ce33, 0xd3e6706a, 0xd2241a5d,
+ 0xc55efe10, 0xc49c9427, 0xc6da2a7e, 0xc7184049, 0xc25756cc, 0xc3953cfb, 0xc1d382a2, 0xc011e895, 0xcb4dafa8, 0xca8fc59f, 0xc8c97bc6,
+ 0xc90b11f1, 0xcc440774, 0xcd866d43, 0xcfc0d31a, 0xce02b92d, 0x91af9640, 0x906dfc77, 0x922b422e, 0x93e92819, 0x96a63e9c, 0x976454ab,
+ 0x9522eaf2, 0x94e080c5, 0x9fbcc7f8, 0x9e7eadcf, 0x9c381396, 0x9dfa79a1, 0x98b56f24, 0x99770513, 0x9b31bb4a, 0x9af3d17d, 0x8d893530,
+ 0x8c4b5f07, 0x8e0de15e, 0x8fcf8b69, 0x8a809dec, 0x8b42f7db, 0x89044982, 0x88c623b5, 0x839a6488, 0x82580ebf, 0x801eb0e6, 0x81dcdad1,
+ 0x8493cc54, 0x8551a663, 0x8717183a, 0x86d5720d, 0xa9e2d0a0, 0xa820ba97, 0xaa6604ce, 0xaba46ef9, 0xaeeb787c, 0xaf29124b, 0xad6fac12,
+ 0xacadc625, 0xa7f18118, 0xa633eb2f, 0xa4755576, 0xa5b73f41, 0xa0f829c4, 0xa13a43f3, 0xa37cfdaa, 0xa2be979d, 0xb5c473d0, 0xb40619e7,
+ 0xb640a7be, 0xb782cd89, 0xb2cddb0c, 0xb30fb13b, 0xb1490f62, 0xb08b6555, 0xbbd72268, 0xba15485f, 0xb853f606, 0xb9919c31, 0xbcde8ab4,
+ 0xbd1ce083, 0xbf5a5eda, 0xbe9834ed},
+ {0x00000000, 0xb8bc6765, 0xaa09c88b, 0x12b5afee, 0x8f629757, 0x37def032, 0x256b5fdc, 0x9dd738b9, 0xc5b428ef, 0x7d084f8a, 0x6fbde064,
+ 0xd7018701, 0x4ad6bfb8, 0xf26ad8dd, 0xe0df7733, 0x58631056, 0x5019579f, 0xe8a530fa, 0xfa109f14, 0x42acf871, 0xdf7bc0c8, 0x67c7a7ad,
+ 0x75720843, 0xcdce6f26, 0x95ad7f70, 0x2d111815, 0x3fa4b7fb, 0x8718d09e, 0x1acfe827, 0xa2738f42, 0xb0c620ac, 0x087a47c9, 0xa032af3e,
+ 0x188ec85b, 0x0a3b67b5, 0xb28700d0, 0x2f503869, 0x97ec5f0c, 0x8559f0e2, 0x3de59787, 0x658687d1, 0xdd3ae0b4, 0xcf8f4f5a, 0x7733283f,
+ 0xeae41086, 0x525877e3, 0x40edd80d, 0xf851bf68, 0xf02bf8a1, 0x48979fc4, 0x5a22302a, 0xe29e574f, 0x7f496ff6, 0xc7f50893, 0xd540a77d,
+ 0x6dfcc018, 0x359fd04e, 0x8d23b72b, 0x9f9618c5, 0x272a7fa0, 0xbafd4719, 0x0241207c, 0x10f48f92, 0xa848e8f7, 0x9b14583d, 0x23a83f58,
+ 0x311d90b6, 0x89a1f7d3, 0x1476cf6a, 0xaccaa80f, 0xbe7f07e1, 0x06c36084, 0x5ea070d2, 0xe61c17b7, 0xf4a9b859, 0x4c15df3c, 0xd1c2e785,
+ 0x697e80e0, 0x7bcb2f0e, 0xc377486b, 0xcb0d0fa2, 0x73b168c7, 0x6104c729, 0xd9b8a04c, 0x446f98f5, 0xfcd3ff90, 0xee66507e, 0x56da371b,
+ 0x0eb9274d, 0xb6054028, 0xa4b0efc6, 0x1c0c88a3, 0x81dbb01a, 0x3967d77f, 0x2bd27891, 0x936e1ff4, 0x3b26f703, 0x839a9066, 0x912f3f88,
+ 0x299358ed, 0xb4446054, 0x0cf80731, 0x1e4da8df, 0xa6f1cfba, 0xfe92dfec, 0x462eb889, 0x549b1767, 0xec277002, 0x71f048bb, 0xc94c2fde,
+ 0xdbf98030, 0x6345e755, 0x6b3fa09c, 0xd383c7f9, 0xc1366817, 0x798a0f72, 0xe45d37cb, 0x5ce150ae, 0x4e54ff40, 0xf6e89825, 0xae8b8873,
+ 0x1637ef16, 0x048240f8, 0xbc3e279d, 0x21e91f24, 0x99557841, 0x8be0d7af, 0x335cb0ca, 0xed59b63b, 0x55e5d15e, 0x47507eb0, 0xffec19d5,
+ 0x623b216c, 0xda874609, 0xc832e9e7, 0x708e8e82, 0x28ed9ed4, 0x9051f9b1, 0x82e4565f, 0x3a58313a, 0xa78f0983, 0x1f336ee6, 0x0d86c108,
+ 0xb53aa66d, 0xbd40e1a4, 0x05fc86c1, 0x1749292f, 0xaff54e4a, 0x322276f3, 0x8a9e1196, 0x982bbe78, 0x2097d91d, 0x78f4c94b, 0xc048ae2e,
+ 0xd2fd01c0, 0x6a4166a5, 0xf7965e1c, 0x4f2a3979, 0x5d9f9697, 0xe523f1f2, 0x4d6b1905, 0xf5d77e60, 0xe762d18e, 0x5fdeb6eb, 0xc2098e52,
+ 0x7ab5e937, 0x680046d9, 0xd0bc21bc, 0x88df31ea, 0x3063568f, 0x22d6f961, 0x9a6a9e04, 0x07bda6bd, 0xbf01c1d8, 0xadb46e36, 0x15080953,
+ 0x1d724e9a, 0xa5ce29ff, 0xb77b8611, 0x0fc7e174, 0x9210d9cd, 0x2aacbea8, 0x38191146, 0x80a57623, 0xd8c66675, 0x607a0110, 0x72cfaefe,
+ 0xca73c99b, 0x57a4f122, 0xef189647, 0xfdad39a9, 0x45115ecc, 0x764dee06, 0xcef18963, 0xdc44268d, 0x64f841e8, 0xf92f7951, 0x41931e34,
+ 0x5326b1da, 0xeb9ad6bf, 0xb3f9c6e9, 0x0b45a18c, 0x19f00e62, 0xa14c6907, 0x3c9b51be, 0x842736db, 0x96929935, 0x2e2efe50, 0x2654b999,
+ 0x9ee8defc, 0x8c5d7112, 0x34e11677, 0xa9362ece, 0x118a49ab, 0x033fe645, 0xbb838120, 0xe3e09176, 0x5b5cf613, 0x49e959fd, 0xf1553e98,
+ 0x6c820621, 0xd43e6144, 0xc68bceaa, 0x7e37a9cf, 0xd67f4138, 0x6ec3265d, 0x7c7689b3, 0xc4caeed6, 0x591dd66f, 0xe1a1b10a, 0xf3141ee4,
+ 0x4ba87981, 0x13cb69d7, 0xab770eb2, 0xb9c2a15c, 0x017ec639, 0x9ca9fe80, 0x241599e5, 0x36a0360b, 0x8e1c516e, 0x866616a7, 0x3eda71c2,
+ 0x2c6fde2c, 0x94d3b949, 0x090481f0, 0xb1b8e695, 0xa30d497b, 0x1bb12e1e, 0x43d23e48, 0xfb6e592d, 0xe9dbf6c3, 0x516791a6, 0xccb0a91f,
+ 0x740cce7a, 0x66b96194, 0xde0506f1},
+ {0x00000000, 0x3d6029b0, 0x7ac05360, 0x47a07ad0, 0xf580a6c0, 0xc8e08f70, 0x8f40f5a0, 0xb220dc10, 0x30704bc1, 0x0d106271, 0x4ab018a1,
+ 0x77d03111, 0xc5f0ed01, 0xf890c4b1, 0xbf30be61, 0x825097d1, 0x60e09782, 0x5d80be32, 0x1a20c4e2, 0x2740ed52, 0x95603142, 0xa80018f2,
+ 0xefa06222, 0xd2c04b92, 0x5090dc43, 0x6df0f5f3, 0x2a508f23, 0x1730a693, 0xa5107a83, 0x98705333, 0xdfd029e3, 0xe2b00053, 0xc1c12f04,
+ 0xfca106b4, 0xbb017c64, 0x866155d4, 0x344189c4, 0x0921a074, 0x4e81daa4, 0x73e1f314, 0xf1b164c5, 0xccd14d75, 0x8b7137a5, 0xb6111e15,
+ 0x0431c205, 0x3951ebb5, 0x7ef19165, 0x4391b8d5, 0xa121b886, 0x9c419136, 0xdbe1ebe6, 0xe681c256, 0x54a11e46, 0x69c137f6, 0x2e614d26,
+ 0x13016496, 0x9151f347, 0xac31daf7, 0xeb91a027, 0xd6f18997, 0x64d15587, 0x59b17c37, 0x1e1106e7, 0x23712f57, 0x58f35849, 0x659371f9,
+ 0x22330b29, 0x1f532299, 0xad73fe89, 0x9013d739, 0xd7b3ade9, 0xead38459, 0x68831388, 0x55e33a38, 0x124340e8, 0x2f236958, 0x9d03b548,
+ 0xa0639cf8, 0xe7c3e628, 0xdaa3cf98, 0x3813cfcb, 0x0573e67b, 0x42d39cab, 0x7fb3b51b, 0xcd93690b, 0xf0f340bb, 0xb7533a6b, 0x8a3313db,
+ 0x0863840a, 0x3503adba, 0x72a3d76a, 0x4fc3feda, 0xfde322ca, 0xc0830b7a, 0x872371aa, 0xba43581a, 0x9932774d, 0xa4525efd, 0xe3f2242d,
+ 0xde920d9d, 0x6cb2d18d, 0x51d2f83d, 0x167282ed, 0x2b12ab5d, 0xa9423c8c, 0x9422153c, 0xd3826fec, 0xeee2465c, 0x5cc29a4c, 0x61a2b3fc,
+ 0x2602c92c, 0x1b62e09c, 0xf9d2e0cf, 0xc4b2c97f, 0x8312b3af, 0xbe729a1f, 0x0c52460f, 0x31326fbf, 0x7692156f, 0x4bf23cdf, 0xc9a2ab0e,
+ 0xf4c282be, 0xb362f86e, 0x8e02d1de, 0x3c220dce, 0x0142247e, 0x46e25eae, 0x7b82771e, 0xb1e6b092, 0x8c869922, 0xcb26e3f2, 0xf646ca42,
+ 0x44661652, 0x79063fe2, 0x3ea64532, 0x03c66c82, 0x8196fb53, 0xbcf6d2e3, 0xfb56a833, 0xc6368183, 0x74165d93, 0x49767423, 0x0ed60ef3,
+ 0x33b62743, 0xd1062710, 0xec660ea0, 0xabc67470, 0x96a65dc0, 0x248681d0, 0x19e6a860, 0x5e46d2b0, 0x6326fb00, 0xe1766cd1, 0xdc164561,
+ 0x9bb63fb1, 0xa6d61601, 0x14f6ca11, 0x2996e3a1, 0x6e369971, 0x5356b0c1, 0x70279f96, 0x4d47b626, 0x0ae7ccf6, 0x3787e546, 0x85a73956,
+ 0xb8c710e6, 0xff676a36, 0xc2074386, 0x4057d457, 0x7d37fde7, 0x3a978737, 0x07f7ae87, 0xb5d77297, 0x88b75b27, 0xcf1721f7, 0xf2770847,
+ 0x10c70814, 0x2da721a4, 0x6a075b74, 0x576772c4, 0xe547aed4, 0xd8278764, 0x9f87fdb4, 0xa2e7d404, 0x20b743d5, 0x1dd76a65, 0x5a7710b5,
+ 0x67173905, 0xd537e515, 0xe857cca5, 0xaff7b675, 0x92979fc5, 0xe915e8db, 0xd475c16b, 0x93d5bbbb, 0xaeb5920b, 0x1c954e1b, 0x21f567ab,
+ 0x66551d7b, 0x5b3534cb, 0xd965a31a, 0xe4058aaa, 0xa3a5f07a, 0x9ec5d9ca, 0x2ce505da, 0x11852c6a, 0x562556ba, 0x6b457f0a, 0x89f57f59,
+ 0xb49556e9, 0xf3352c39, 0xce550589, 0x7c75d999, 0x4115f029, 0x06b58af9, 0x3bd5a349, 0xb9853498, 0x84e51d28, 0xc34567f8, 0xfe254e48,
+ 0x4c059258, 0x7165bbe8, 0x36c5c138, 0x0ba5e888, 0x28d4c7df, 0x15b4ee6f, 0x521494bf, 0x6f74bd0f, 0xdd54611f, 0xe03448af, 0xa794327f,
+ 0x9af41bcf, 0x18a48c1e, 0x25c4a5ae, 0x6264df7e, 0x5f04f6ce, 0xed242ade, 0xd044036e, 0x97e479be, 0xaa84500e, 0x4834505d, 0x755479ed,
+ 0x32f4033d, 0x0f942a8d, 0xbdb4f69d, 0x80d4df2d, 0xc774a5fd, 0xfa148c4d, 0x78441b9c, 0x4524322c, 0x028448fc, 0x3fe4614c, 0x8dc4bd5c,
+ 0xb0a494ec, 0xf704ee3c, 0xca64c78c},
+ {0x00000000, 0xcb5cd3a5, 0x4dc8a10b, 0x869472ae, 0x9b914216, 0x50cd91b3, 0xd659e31d, 0x1d0530b8, 0xec53826d, 0x270f51c8, 0xa19b2366,
+ 0x6ac7f0c3, 0x77c2c07b, 0xbc9e13de, 0x3a0a6170, 0xf156b2d5, 0x03d6029b, 0xc88ad13e, 0x4e1ea390, 0x85427035, 0x9847408d, 0x531b9328,
+ 0xd58fe186, 0x1ed33223, 0xef8580f6, 0x24d95353, 0xa24d21fd, 0x6911f258, 0x7414c2e0, 0xbf481145, 0x39dc63eb, 0xf280b04e, 0x07ac0536,
+ 0xccf0d693, 0x4a64a43d, 0x81387798, 0x9c3d4720, 0x57619485, 0xd1f5e62b, 0x1aa9358e, 0xebff875b, 0x20a354fe, 0xa6372650, 0x6d6bf5f5,
+ 0x706ec54d, 0xbb3216e8, 0x3da66446, 0xf6fab7e3, 0x047a07ad, 0xcf26d408, 0x49b2a6a6, 0x82ee7503, 0x9feb45bb, 0x54b7961e, 0xd223e4b0,
+ 0x197f3715, 0xe82985c0, 0x23755665, 0xa5e124cb, 0x6ebdf76e, 0x73b8c7d6, 0xb8e41473, 0x3e7066dd, 0xf52cb578, 0x0f580a6c, 0xc404d9c9,
+ 0x4290ab67, 0x89cc78c2, 0x94c9487a, 0x5f959bdf, 0xd901e971, 0x125d3ad4, 0xe30b8801, 0x28575ba4, 0xaec3290a, 0x659ffaaf, 0x789aca17,
+ 0xb3c619b2, 0x35526b1c, 0xfe0eb8b9, 0x0c8e08f7, 0xc7d2db52, 0x4146a9fc, 0x8a1a7a59, 0x971f4ae1, 0x5c439944, 0xdad7ebea, 0x118b384f,
+ 0xe0dd8a9a, 0x2b81593f, 0xad152b91, 0x6649f834, 0x7b4cc88c, 0xb0101b29, 0x36846987, 0xfdd8ba22, 0x08f40f5a, 0xc3a8dcff, 0x453cae51,
+ 0x8e607df4, 0x93654d4c, 0x58399ee9, 0xdeadec47, 0x15f13fe2, 0xe4a78d37, 0x2ffb5e92, 0xa96f2c3c, 0x6233ff99, 0x7f36cf21, 0xb46a1c84,
+ 0x32fe6e2a, 0xf9a2bd8f, 0x0b220dc1, 0xc07ede64, 0x46eaacca, 0x8db67f6f, 0x90b34fd7, 0x5bef9c72, 0xdd7beedc, 0x16273d79, 0xe7718fac,
+ 0x2c2d5c09, 0xaab92ea7, 0x61e5fd02, 0x7ce0cdba, 0xb7bc1e1f, 0x31286cb1, 0xfa74bf14, 0x1eb014d8, 0xd5ecc77d, 0x5378b5d3, 0x98246676,
+ 0x852156ce, 0x4e7d856b, 0xc8e9f7c5, 0x03b52460, 0xf2e396b5, 0x39bf4510, 0xbf2b37be, 0x7477e41b, 0x6972d4a3, 0xa22e0706, 0x24ba75a8,
+ 0xefe6a60d, 0x1d661643, 0xd63ac5e6, 0x50aeb748, 0x9bf264ed, 0x86f75455, 0x4dab87f0, 0xcb3ff55e, 0x006326fb, 0xf135942e, 0x3a69478b,
+ 0xbcfd3525, 0x77a1e680, 0x6aa4d638, 0xa1f8059d, 0x276c7733, 0xec30a496, 0x191c11ee, 0xd240c24b, 0x54d4b0e5, 0x9f886340, 0x828d53f8,
+ 0x49d1805d, 0xcf45f2f3, 0x04192156, 0xf54f9383, 0x3e134026, 0xb8873288, 0x73dbe12d, 0x6eded195, 0xa5820230, 0x2316709e, 0xe84aa33b,
+ 0x1aca1375, 0xd196c0d0, 0x5702b27e, 0x9c5e61db, 0x815b5163, 0x4a0782c6, 0xcc93f068, 0x07cf23cd, 0xf6999118, 0x3dc542bd, 0xbb513013,
+ 0x700de3b6, 0x6d08d30e, 0xa65400ab, 0x20c07205, 0xeb9ca1a0, 0x11e81eb4, 0xdab4cd11, 0x5c20bfbf, 0x977c6c1a, 0x8a795ca2, 0x41258f07,
+ 0xc7b1fda9, 0x0ced2e0c, 0xfdbb9cd9, 0x36e74f7c, 0xb0733dd2, 0x7b2fee77, 0x662adecf, 0xad760d6a, 0x2be27fc4, 0xe0beac61, 0x123e1c2f,
+ 0xd962cf8a, 0x5ff6bd24, 0x94aa6e81, 0x89af5e39, 0x42f38d9c, 0xc467ff32, 0x0f3b2c97, 0xfe6d9e42, 0x35314de7, 0xb3a53f49, 0x78f9ecec,
+ 0x65fcdc54, 0xaea00ff1, 0x28347d5f, 0xe368aefa, 0x16441b82, 0xdd18c827, 0x5b8cba89, 0x90d0692c, 0x8dd55994, 0x46898a31, 0xc01df89f,
+ 0x0b412b3a, 0xfa1799ef, 0x314b4a4a, 0xb7df38e4, 0x7c83eb41, 0x6186dbf9, 0xaada085c, 0x2c4e7af2, 0xe712a957, 0x15921919, 0xdececabc,
+ 0x585ab812, 0x93066bb7, 0x8e035b0f, 0x455f88aa, 0xc3cbfa04, 0x089729a1, 0xf9c19b74, 0x329d48d1, 0xb4093a7f, 0x7f55e9da, 0x6250d962,
+ 0xa90c0ac7, 0x2f987869, 0xe4c4abcc},
+ {0x00000000, 0xa6770bb4, 0x979f1129, 0x31e81a9d, 0xf44f2413, 0x52382fa7, 0x63d0353a, 0xc5a73e8e, 0x33ef4e67, 0x959845d3, 0xa4705f4e,
+ 0x020754fa, 0xc7a06a74, 0x61d761c0, 0x503f7b5d, 0xf64870e9, 0x67de9cce, 0xc1a9977a, 0xf0418de7, 0x56368653, 0x9391b8dd, 0x35e6b369,
+ 0x040ea9f4, 0xa279a240, 0x5431d2a9, 0xf246d91d, 0xc3aec380, 0x65d9c834, 0xa07ef6ba, 0x0609fd0e, 0x37e1e793, 0x9196ec27, 0xcfbd399c,
+ 0x69ca3228, 0x582228b5, 0xfe552301, 0x3bf21d8f, 0x9d85163b, 0xac6d0ca6, 0x0a1a0712, 0xfc5277fb, 0x5a257c4f, 0x6bcd66d2, 0xcdba6d66,
+ 0x081d53e8, 0xae6a585c, 0x9f8242c1, 0x39f54975, 0xa863a552, 0x0e14aee6, 0x3ffcb47b, 0x998bbfcf, 0x5c2c8141, 0xfa5b8af5, 0xcbb39068,
+ 0x6dc49bdc, 0x9b8ceb35, 0x3dfbe081, 0x0c13fa1c, 0xaa64f1a8, 0x6fc3cf26, 0xc9b4c492, 0xf85cde0f, 0x5e2bd5bb, 0x440b7579, 0xe27c7ecd,
+ 0xd3946450, 0x75e36fe4, 0xb044516a, 0x16335ade, 0x27db4043, 0x81ac4bf7, 0x77e43b1e, 0xd19330aa, 0xe07b2a37, 0x460c2183, 0x83ab1f0d,
+ 0x25dc14b9, 0x14340e24, 0xb2430590, 0x23d5e9b7, 0x85a2e203, 0xb44af89e, 0x123df32a, 0xd79acda4, 0x71edc610, 0x4005dc8d, 0xe672d739,
+ 0x103aa7d0, 0xb64dac64, 0x87a5b6f9, 0x21d2bd4d, 0xe47583c3, 0x42028877, 0x73ea92ea, 0xd59d995e, 0x8bb64ce5, 0x2dc14751, 0x1c295dcc,
+ 0xba5e5678, 0x7ff968f6, 0xd98e6342, 0xe86679df, 0x4e11726b, 0xb8590282, 0x1e2e0936, 0x2fc613ab, 0x89b1181f, 0x4c162691, 0xea612d25,
+ 0xdb8937b8, 0x7dfe3c0c, 0xec68d02b, 0x4a1fdb9f, 0x7bf7c102, 0xdd80cab6, 0x1827f438, 0xbe50ff8c, 0x8fb8e511, 0x29cfeea5, 0xdf879e4c,
+ 0x79f095f8, 0x48188f65, 0xee6f84d1, 0x2bc8ba5f, 0x8dbfb1eb, 0xbc57ab76, 0x1a20a0c2, 0x8816eaf2, 0x2e61e146, 0x1f89fbdb, 0xb9fef06f,
+ 0x7c59cee1, 0xda2ec555, 0xebc6dfc8, 0x4db1d47c, 0xbbf9a495, 0x1d8eaf21, 0x2c66b5bc, 0x8a11be08, 0x4fb68086, 0xe9c18b32, 0xd82991af,
+ 0x7e5e9a1b, 0xefc8763c, 0x49bf7d88, 0x78576715, 0xde206ca1, 0x1b87522f, 0xbdf0599b, 0x8c184306, 0x2a6f48b2, 0xdc27385b, 0x7a5033ef,
+ 0x4bb82972, 0xedcf22c6, 0x28681c48, 0x8e1f17fc, 0xbff70d61, 0x198006d5, 0x47abd36e, 0xe1dcd8da, 0xd034c247, 0x7643c9f3, 0xb3e4f77d,
+ 0x1593fcc9, 0x247be654, 0x820cede0, 0x74449d09, 0xd23396bd, 0xe3db8c20, 0x45ac8794, 0x800bb91a, 0x267cb2ae, 0x1794a833, 0xb1e3a387,
+ 0x20754fa0, 0x86024414, 0xb7ea5e89, 0x119d553d, 0xd43a6bb3, 0x724d6007, 0x43a57a9a, 0xe5d2712e, 0x139a01c7, 0xb5ed0a73, 0x840510ee,
+ 0x22721b5a, 0xe7d525d4, 0x41a22e60, 0x704a34fd, 0xd63d3f49, 0xcc1d9f8b, 0x6a6a943f, 0x5b828ea2, 0xfdf58516, 0x3852bb98, 0x9e25b02c,
+ 0xafcdaab1, 0x09baa105, 0xfff2d1ec, 0x5985da58, 0x686dc0c5, 0xce1acb71, 0x0bbdf5ff, 0xadcafe4b, 0x9c22e4d6, 0x3a55ef62, 0xabc30345,
+ 0x0db408f1, 0x3c5c126c, 0x9a2b19d8, 0x5f8c2756, 0xf9fb2ce2, 0xc813367f, 0x6e643dcb, 0x982c4d22, 0x3e5b4696, 0x0fb35c0b, 0xa9c457bf,
+ 0x6c636931, 0xca146285, 0xfbfc7818, 0x5d8b73ac, 0x03a0a617, 0xa5d7ada3, 0x943fb73e, 0x3248bc8a, 0xf7ef8204, 0x519889b0, 0x6070932d,
+ 0xc6079899, 0x304fe870, 0x9638e3c4, 0xa7d0f959, 0x01a7f2ed, 0xc400cc63, 0x6277c7d7, 0x539fdd4a, 0xf5e8d6fe, 0x647e3ad9, 0xc209316d,
+ 0xf3e12bf0, 0x55962044, 0x90311eca, 0x3646157e, 0x07ae0fe3, 0xa1d90457, 0x579174be, 0xf1e67f0a, 0xc00e6597, 0x66796e23, 0xa3de50ad,
+ 0x05a95b19, 0x34414184, 0x92364a30},
+ {0x00000000, 0xccaa009e, 0x4225077d, 0x8e8f07e3, 0x844a0efa, 0x48e00e64, 0xc66f0987, 0x0ac50919, 0xd3e51bb5, 0x1f4f1b2b, 0x91c01cc8,
+ 0x5d6a1c56, 0x57af154f, 0x9b0515d1, 0x158a1232, 0xd92012ac, 0x7cbb312b, 0xb01131b5, 0x3e9e3656, 0xf23436c8, 0xf8f13fd1, 0x345b3f4f,
+ 0xbad438ac, 0x767e3832, 0xaf5e2a9e, 0x63f42a00, 0xed7b2de3, 0x21d12d7d, 0x2b142464, 0xe7be24fa, 0x69312319, 0xa59b2387, 0xf9766256,
+ 0x35dc62c8, 0xbb53652b, 0x77f965b5, 0x7d3c6cac, 0xb1966c32, 0x3f196bd1, 0xf3b36b4f, 0x2a9379e3, 0xe639797d, 0x68b67e9e, 0xa41c7e00,
+ 0xaed97719, 0x62737787, 0xecfc7064, 0x205670fa, 0x85cd537d, 0x496753e3, 0xc7e85400, 0x0b42549e, 0x01875d87, 0xcd2d5d19, 0x43a25afa,
+ 0x8f085a64, 0x562848c8, 0x9a824856, 0x140d4fb5, 0xd8a74f2b, 0xd2624632, 0x1ec846ac, 0x9047414f, 0x5ced41d1, 0x299dc2ed, 0xe537c273,
+ 0x6bb8c590, 0xa712c50e, 0xadd7cc17, 0x617dcc89, 0xeff2cb6a, 0x2358cbf4, 0xfa78d958, 0x36d2d9c6, 0xb85dde25, 0x74f7debb, 0x7e32d7a2,
+ 0xb298d73c, 0x3c17d0df, 0xf0bdd041, 0x5526f3c6, 0x998cf358, 0x1703f4bb, 0xdba9f425, 0xd16cfd3c, 0x1dc6fda2, 0x9349fa41, 0x5fe3fadf,
+ 0x86c3e873, 0x4a69e8ed, 0xc4e6ef0e, 0x084cef90, 0x0289e689, 0xce23e617, 0x40ace1f4, 0x8c06e16a, 0xd0eba0bb, 0x1c41a025, 0x92cea7c6,
+ 0x5e64a758, 0x54a1ae41, 0x980baedf, 0x1684a93c, 0xda2ea9a2, 0x030ebb0e, 0xcfa4bb90, 0x412bbc73, 0x8d81bced, 0x8744b5f4, 0x4beeb56a,
+ 0xc561b289, 0x09cbb217, 0xac509190, 0x60fa910e, 0xee7596ed, 0x22df9673, 0x281a9f6a, 0xe4b09ff4, 0x6a3f9817, 0xa6959889, 0x7fb58a25,
+ 0xb31f8abb, 0x3d908d58, 0xf13a8dc6, 0xfbff84df, 0x37558441, 0xb9da83a2, 0x7570833c, 0x533b85da, 0x9f918544, 0x111e82a7, 0xddb48239,
+ 0xd7718b20, 0x1bdb8bbe, 0x95548c5d, 0x59fe8cc3, 0x80de9e6f, 0x4c749ef1, 0xc2fb9912, 0x0e51998c, 0x04949095, 0xc83e900b, 0x46b197e8,
+ 0x8a1b9776, 0x2f80b4f1, 0xe32ab46f, 0x6da5b38c, 0xa10fb312, 0xabcaba0b, 0x6760ba95, 0xe9efbd76, 0x2545bde8, 0xfc65af44, 0x30cfafda,
+ 0xbe40a839, 0x72eaa8a7, 0x782fa1be, 0xb485a120, 0x3a0aa6c3, 0xf6a0a65d, 0xaa4de78c, 0x66e7e712, 0xe868e0f1, 0x24c2e06f, 0x2e07e976,
+ 0xe2ade9e8, 0x6c22ee0b, 0xa088ee95, 0x79a8fc39, 0xb502fca7, 0x3b8dfb44, 0xf727fbda, 0xfde2f2c3, 0x3148f25d, 0xbfc7f5be, 0x736df520,
+ 0xd6f6d6a7, 0x1a5cd639, 0x94d3d1da, 0x5879d144, 0x52bcd85d, 0x9e16d8c3, 0x1099df20, 0xdc33dfbe, 0x0513cd12, 0xc9b9cd8c, 0x4736ca6f,
+ 0x8b9ccaf1, 0x8159c3e8, 0x4df3c376, 0xc37cc495, 0x0fd6c40b, 0x7aa64737, 0xb60c47a9, 0x3883404a, 0xf42940d4, 0xfeec49cd, 0x32464953,
+ 0xbcc94eb0, 0x70634e2e, 0xa9435c82, 0x65e95c1c, 0xeb665bff, 0x27cc5b61, 0x2d095278, 0xe1a352e6, 0x6f2c5505, 0xa386559b, 0x061d761c,
+ 0xcab77682, 0x44387161, 0x889271ff, 0x825778e6, 0x4efd7878, 0xc0727f9b, 0x0cd87f05, 0xd5f86da9, 0x19526d37, 0x97dd6ad4, 0x5b776a4a,
+ 0x51b26353, 0x9d1863cd, 0x1397642e, 0xdf3d64b0, 0x83d02561, 0x4f7a25ff, 0xc1f5221c, 0x0d5f2282, 0x079a2b9b, 0xcb302b05, 0x45bf2ce6,
+ 0x89152c78, 0x50353ed4, 0x9c9f3e4a, 0x121039a9, 0xdeba3937, 0xd47f302e, 0x18d530b0, 0x965a3753, 0x5af037cd, 0xff6b144a, 0x33c114d4,
+ 0xbd4e1337, 0x71e413a9, 0x7b211ab0, 0xb78b1a2e, 0x39041dcd, 0xf5ae1d53, 0x2c8e0fff, 0xe0240f61, 0x6eab0882, 0xa201081c, 0xa8c40105,
+ 0x646e019b, 0xeae10678, 0x264b06e6}};
+
+#define BYTESWAP_ORDER32(x) (((x) >> 24) + (((x) >> 8) & 0xff00) + (((x) << 8) & 0xff0000) + ((x) << 24))
+#define UE_PTRDIFF_TO_UINT32(argument) static_cast<uint32_t>(argument)
+
+template<typename T>
+constexpr T
+Align(T Val, uint64_t Alignment)
+{
+ return (T)(((uint64_t)Val + Alignment - 1) & ~(Alignment - 1));
+}
+
+} // namespace CRC32
+
+namespace zen {
+
+uint32_t
+StrCrc_Deprecated(const char* Data)
+{
+ using namespace CRC32;
+
+ uint32_t CRC = 0xFFFFFFFF;
+ while (*Data)
+ {
+ char16_t C = *Data++;
+ int32_t CL = (C & 255);
+ CRC = (CRC << 8) ^ CRCTable_DEPRECATED[(CRC >> 24) ^ CL];
+ int32_t CH = (C >> 8) & 255;
+ CRC = (CRC << 8) ^ CRCTable_DEPRECATED[(CRC >> 24) ^ CH];
+ }
+ return ~CRC;
+}
+
+uint32_t
+MemCrc32(const void* InData, size_t Length, uint32_t CRC /*=0 */)
+{
+ using namespace CRC32;
+
+ // Based on the Slicing-by-8 implementation found here:
+ // http://slicing-by-8.sourceforge.net/
+
+ CRC = ~CRC;
+
+ const uint8_t* __restrict Data = (uint8_t*)InData;
+
+ // First we need to align to 32-bits
+ uint32_t InitBytes = UE_PTRDIFF_TO_UINT32(Align(Data, 4) - Data);
+
+ if (Length > InitBytes)
+ {
+ Length -= InitBytes;
+
+ for (; InitBytes; --InitBytes)
+ {
+ CRC = (CRC >> 8) ^ CRCTablesSB8[0][(CRC & 0xFF) ^ *Data++];
+ }
+
+ auto Data4 = (const uint32_t*)Data;
+ for (size_t Repeat = Length / 8; Repeat; --Repeat)
+ {
+ uint32_t V1 = *Data4++ ^ CRC;
+ uint32_t V2 = *Data4++;
+ CRC = CRCTablesSB8[7][V1 & 0xFF] ^ CRCTablesSB8[6][(V1 >> 8) & 0xFF] ^ CRCTablesSB8[5][(V1 >> 16) & 0xFF] ^
+ CRCTablesSB8[4][V1 >> 24] ^ CRCTablesSB8[3][V2 & 0xFF] ^ CRCTablesSB8[2][(V2 >> 8) & 0xFF] ^
+ CRCTablesSB8[1][(V2 >> 16) & 0xFF] ^ CRCTablesSB8[0][V2 >> 24];
+ }
+ Data = (const uint8_t*)Data4;
+
+ Length %= 8;
+ }
+
+ for (; Length; --Length)
+ {
+ CRC = (CRC >> 8) ^ CRCTablesSB8[0][(CRC & 0xFF) ^ *Data++];
+ }
+
+ return ~CRC;
+}
+
+uint32_t
+MemCrc32_Deprecated(const void* InData, size_t Length, uint32_t CRC)
+{
+ using namespace CRC32;
+
+ // Based on the Slicing-by-8 implementation found here:
+ // http://slicing-by-8.sourceforge.net/
+
+ CRC = ~BYTESWAP_ORDER32(CRC);
+
+ const uint8_t* __restrict Data = (uint8_t*)InData;
+
+ // First we need to align to 32-bits
+ uint32_t InitBytes = UE_PTRDIFF_TO_UINT32(Align(Data, 4) - Data);
+
+ if (Length > InitBytes)
+ {
+ Length -= InitBytes;
+
+ for (; InitBytes; --InitBytes)
+ {
+ CRC = (CRC >> 8) ^ CRCTablesSB8_DEPRECATED[0][(CRC & 0xFF) ^ *Data++];
+ }
+
+ auto Data4 = (const uint32_t*)Data;
+ for (size_t Repeat = Length / 8; Repeat; --Repeat)
+ {
+ uint32_t V1 = *Data4++ ^ CRC;
+ uint32_t V2 = *Data4++;
+ CRC = CRCTablesSB8_DEPRECATED[7][V1 & 0xFF] ^ CRCTablesSB8_DEPRECATED[6][(V1 >> 8) & 0xFF] ^
+ CRCTablesSB8_DEPRECATED[5][(V1 >> 16) & 0xFF] ^ CRCTablesSB8_DEPRECATED[4][V1 >> 24] ^
+ CRCTablesSB8_DEPRECATED[3][V2 & 0xFF] ^ CRCTablesSB8_DEPRECATED[2][(V2 >> 8) & 0xFF] ^
+ CRCTablesSB8_DEPRECATED[1][(V2 >> 16) & 0xFF] ^ CRCTablesSB8_DEPRECATED[0][V2 >> 24];
+ }
+ Data = (const uint8_t*)Data4;
+
+ Length %= 8;
+ }
+
+ for (; Length; --Length)
+ {
+ CRC = (CRC >> 8) ^ CRCTablesSB8_DEPRECATED[0][(CRC & 0xFF) ^ *Data++];
+ }
+
+ return BYTESWAP_ORDER32(~CRC);
+}
+
+} // namespace zen
diff --git a/src/zencore/crypto.cpp b/src/zencore/crypto.cpp
new file mode 100644
index 000000000..448fd36fa
--- /dev/null
+++ b/src/zencore/crypto.cpp
@@ -0,0 +1,208 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/crypto.h>
+#include <zencore/intmath.h>
+#include <zencore/testing.h>
+
+#include <string>
+#include <string_view>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <openssl/conf.h>
+#include <openssl/err.h>
+#include <openssl/evp.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "crypt32.lib")
+# pragma comment(lib, "ws2_32.lib")
+#endif
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace crypto {
+
+ class EvpContext
+ {
+ public:
+ EvpContext() : m_Ctx(EVP_CIPHER_CTX_new()) {}
+ ~EvpContext() { EVP_CIPHER_CTX_free(m_Ctx); }
+
+ operator EVP_CIPHER_CTX*() { return m_Ctx; }
+
+ private:
+ EVP_CIPHER_CTX* m_Ctx;
+ };
+
+ enum class TransformMode : uint32_t
+ {
+ Decrypt,
+ Encrypt
+ };
+
+ MemoryView Transform(const EVP_CIPHER* Cipher,
+ TransformMode Mode,
+ MemoryView Key,
+ MemoryView IV,
+ MemoryView In,
+ MutableMemoryView Out,
+ std::optional<std::string>& Reason)
+ {
+ ZEN_ASSERT(Cipher != nullptr);
+
+ EvpContext Ctx;
+
+ int Err = EVP_CipherInit_ex(Ctx,
+ Cipher,
+ nullptr,
+ reinterpret_cast<const unsigned char*>(Key.GetData()),
+ reinterpret_cast<const unsigned char*>(IV.GetData()),
+ static_cast<int>(Mode));
+
+ if (Err != 1)
+ {
+ Reason = fmt::format("failed to initialize cipher, error code '{}'", Err);
+
+ return MemoryView();
+ }
+
+ int EncryptedBytes = 0;
+ int TotalEncryptedBytes = 0;
+
+ Err = EVP_CipherUpdate(Ctx,
+ reinterpret_cast<unsigned char*>(Out.GetData()),
+ &EncryptedBytes,
+ reinterpret_cast<const unsigned char*>(In.GetData()),
+ static_cast<int>(In.GetSize()));
+
+ if (Err != 1)
+ {
+ Reason = fmt::format("update crypto transform failed, error code '{}'", Err);
+
+ return MemoryView();
+ }
+
+ TotalEncryptedBytes = EncryptedBytes;
+ MutableMemoryView Remaining = Out.RightChop(EncryptedBytes);
+
+ EncryptedBytes = static_cast<int>(Remaining.GetSize());
+
+ Err = EVP_CipherFinal(Ctx, reinterpret_cast<unsigned char*>(Remaining.GetData()), &EncryptedBytes);
+
+ if (Err != 1)
+ {
+ Reason = fmt::format("finalize crypto transform failed, error code '{}'", Err);
+
+ return MemoryView();
+ }
+
+ TotalEncryptedBytes += EncryptedBytes;
+
+ return Out.Left(TotalEncryptedBytes);
+ }
+
+ bool ValidateKeyAndIV(const AesKey256Bit& Key, const AesIV128Bit& IV, std::optional<std::string>& Reason)
+ {
+ if (Key.IsValid() == false)
+ {
+ Reason = "invalid key"sv;
+
+ return false;
+ }
+
+ if (IV.IsValid() == false)
+ {
+ Reason = "invalid initialization vector"sv;
+
+ return false;
+ }
+
+ return true;
+ }
+
+} // namespace crypto
+
+MemoryView
+Aes::Encrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason)
+{
+ if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false)
+ {
+ return MemoryView();
+ }
+
+ return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Encrypt, Key.GetView(), IV.GetView(), In, Out, Reason);
+}
+
+MemoryView
+Aes::Decrypt(const AesKey256Bit& Key, const AesIV128Bit& IV, MemoryView In, MutableMemoryView Out, std::optional<std::string>& Reason)
+{
+ if (crypto::ValidateKeyAndIV(Key, IV, Reason) == false)
+ {
+ return MemoryView();
+ }
+
+ return crypto::Transform(EVP_aes_256_cbc(), crypto::TransformMode::Decrypt, Key.GetView(), IV.GetView(), In, Out, Reason);
+}
+
+#if ZEN_WITH_TESTS
+
+void
+crypto_forcelink()
+{
+}
+
+TEST_CASE("crypto.bits")
+{
+ using CryptoBits256Bit = CryptoBits<256>;
+
+ CryptoBits256Bit Bits;
+
+ CHECK(Bits.IsNull());
+ CHECK(Bits.IsValid() == false);
+
+ CHECK(Bits.GetBitCount() == 256);
+ CHECK(Bits.GetSize() == 32);
+
+ Bits = CryptoBits256Bit::FromString("Addff"sv);
+ CHECK(Bits.IsValid() == false);
+
+ Bits = CryptoBits256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
+ CHECK(Bits.IsValid());
+
+ auto SmallerBits = CryptoBits<128>::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
+ CHECK(SmallerBits.IsValid() == false);
+}
+
+TEST_CASE("crypto.aes")
+{
+ SUBCASE("basic")
+ {
+ const uint8_t InitVector[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+ const AesKey256Bit Key = AesKey256Bit::FromString("abcdefghijklmnopqrstuvxyz0123456"sv);
+ const AesIV128Bit IV = AesIV128Bit::FromMemoryView(MakeMemoryView(InitVector));
+
+ std::string_view PlainText = "The quick brown fox jumps over the lazy dog"sv;
+
+ std::vector<uint8_t> EncryptionBuffer;
+ std::vector<uint8_t> DecryptionBuffer;
+ std::optional<std::string> Reason;
+
+ EncryptionBuffer.resize(PlainText.size() + Aes::BlockSize);
+ DecryptionBuffer.resize(PlainText.size() + Aes::BlockSize);
+
+ MemoryView EncryptedView = Aes::Encrypt(Key, IV, MakeMemoryView(PlainText), MakeMutableMemoryView(EncryptionBuffer), Reason);
+ MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedView, MakeMutableMemoryView(DecryptionBuffer), Reason);
+
+ std::string_view EncryptedDecryptedText =
+ std::string_view(reinterpret_cast<const char*>(DecryptedView.GetData()), DecryptedView.GetSize());
+
+ CHECK(EncryptedDecryptedText == PlainText);
+ }
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/except.cpp b/src/zencore/except.cpp
new file mode 100644
index 000000000..2749d1984
--- /dev/null
+++ b/src/zencore/except.cpp
@@ -0,0 +1,93 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <fmt/format.h>
+#include <zencore/except.h>
+
+namespace zen {
+
+#if ZEN_PLATFORM_WINDOWS
+
+class WindowsException : public std::exception
+{
+public:
+ WindowsException(std::string_view Message)
+ {
+ m_hResult = HRESULT_FROM_WIN32(GetLastError());
+ m_Message = Message;
+ }
+
+ WindowsException(HRESULT hRes, std::string_view Message)
+ {
+ m_hResult = hRes;
+ m_Message = Message;
+ }
+
+ WindowsException(HRESULT hRes, const char* Message, const char* Detail)
+ {
+ m_hResult = hRes;
+
+ ExtendableStringBuilder<128> msg;
+ msg.Append(Message);
+ msg.Append(" (detail: '");
+ msg.Append(Detail);
+ msg.Append("')");
+
+ m_Message = msg.c_str();
+ }
+
+ virtual const char* what() const override { return m_Message.c_str(); }
+
+private:
+ std::string m_Message;
+ HRESULT m_hResult;
+};
+
+void
+ThrowSystemException([[maybe_unused]] HRESULT hRes, [[maybe_unused]] std::string_view Message)
+{
+ if (HRESULT_FACILITY(hRes) == FACILITY_WIN32)
+ {
+ throw std::system_error(std::error_code(hRes & 0xffff, std::system_category()), std::string(Message));
+ }
+ else
+ {
+ throw WindowsException(hRes, Message);
+ }
+}
+
+#endif // ZEN_PLATFORM_WINDOWS
+
+void
+ThrowSystemError(uint32_t ErrorCode, std::string_view Message)
+{
+ throw std::system_error(std::error_code(ErrorCode, std::system_category()), std::string(Message));
+}
+
+std::string
+GetLastErrorAsString()
+{
+ return GetSystemErrorAsString(zen::GetLastError());
+}
+
+std::string
+GetSystemErrorAsString(uint32_t ErrorCode)
+{
+ return std::error_code(ErrorCode, std::system_category()).message();
+}
+
+#if defined(__cpp_lib_source_location)
+void
+ThrowLastErrorImpl(std::string_view Message, const std::source_location& Location)
+{
+ throw std::system_error(std::error_code(zen::GetLastError(), std::system_category()),
+ fmt::format("{}({}): {}", Location.file_name(), Location.line(), Message));
+}
+#else
+void
+ThrowLastError(std::string_view Message)
+{
+ throw std::system_error(std::error_code(zen::GetLastError(), std::system_category()), std::string(Message));
+}
+#endif
+
+} // namespace zen
diff --git a/src/zencore/filesystem.cpp b/src/zencore/filesystem.cpp
new file mode 100644
index 000000000..a17773024
--- /dev/null
+++ b/src/zencore/filesystem.cpp
@@ -0,0 +1,1304 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/filesystem.h>
+
+#include <zencore/except.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+# include <atlbase.h>
+# include <atlfile.h>
+# include <winioctl.h>
+# include <winnt.h>
+#endif
+
+#if ZEN_PLATFORM_LINUX
+# include <dirent.h>
+# include <fcntl.h>
+# include <sys/resource.h>
+# include <sys/stat.h>
+# include <unistd.h>
+#endif
+
+#if ZEN_PLATFORM_MAC
+# include <dirent.h>
+# include <fcntl.h>
+# include <libproc.h>
+# include <sys/resource.h>
+# include <sys/stat.h>
+# include <sys/syslimits.h>
+# include <unistd.h>
+#endif
+
+#include <filesystem>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+using namespace std::literals;
+
+#if ZEN_PLATFORM_WINDOWS
+
+static bool
+DeleteReparsePoint(const wchar_t* Path, DWORD dwReparseTag)
+{
+ CHandle hDir(CreateFileW(Path,
+ GENERIC_WRITE,
+ FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+ nullptr,
+ OPEN_EXISTING,
+ FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT,
+ nullptr));
+
+ if (hDir != INVALID_HANDLE_VALUE)
+ {
+ REPARSE_GUID_DATA_BUFFER Rgdb = {};
+ Rgdb.ReparseTag = dwReparseTag;
+
+ DWORD dwBytes;
+ const BOOL bOK =
+ DeviceIoControl(hDir, FSCTL_DELETE_REPARSE_POINT, &Rgdb, REPARSE_GUID_DATA_BUFFER_HEADER_SIZE, nullptr, 0, &dwBytes, nullptr);
+
+ return bOK == TRUE;
+ }
+
+ return false;
+}
+
+bool
+CreateDirectories(const wchar_t* Dir)
+{
+ // This may be suboptimal, in that it appears to try and create directories
+ // from the root on up instead of from some directory which is known to
+ // be present
+ //
+ // We should implement a smarter version at some point since this can be
+ // pretty expensive in aggregate
+
+ return std::filesystem::create_directories(Dir);
+}
+
+// Erase all files and directories in a given directory, leaving an empty directory
+// behind
+
+static bool
+WipeDirectory(const wchar_t* DirPath)
+{
+ ExtendableWideStringBuilder<128> Pattern;
+ Pattern.Append(DirPath);
+ Pattern.Append(L"\\*");
+
+ WIN32_FIND_DATAW FindData;
+ HANDLE hFind = FindFirstFileW(Pattern.c_str(), &FindData);
+
+ if (hFind != nullptr)
+ {
+ do
+ {
+ bool IsRegular = true;
+
+ if (FindData.cFileName[0] == L'.')
+ {
+ if (FindData.cFileName[1] == L'.')
+ {
+ if (FindData.cFileName[2] == L'\0')
+ {
+ IsRegular = false;
+ }
+ }
+ else if (FindData.cFileName[1] == L'\0')
+ {
+ IsRegular = false;
+ }
+ }
+
+ if (IsRegular)
+ {
+ ExtendableWideStringBuilder<128> Path;
+ Path.Append(DirPath);
+ Path.Append(L'\\');
+ Path.Append(FindData.cFileName);
+
+ // if (fd.dwFileAttributes & FILE_ATTRIBUTE_RECALL_ON_OPEN)
+ // deleteReparsePoint(path.c_str(), fd.dwReserved0);
+
+ if (FindData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)
+ {
+ if (FindData.dwFileAttributes & FILE_ATTRIBUTE_RECALL_ON_OPEN)
+ {
+ DeleteReparsePoint(Path.c_str(), FindData.dwReserved0);
+ }
+
+ if (FindData.dwFileAttributes & FILE_ATTRIBUTE_RECALL_ON_DATA_ACCESS)
+ {
+ DeleteReparsePoint(Path.c_str(), FindData.dwReserved0);
+ }
+
+ bool Success = DeleteDirectories(Path.c_str());
+
+ if (!Success)
+ {
+ if (FindData.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT)
+ {
+ DeleteReparsePoint(Path.c_str(), FindData.dwReserved0);
+ }
+ }
+ }
+ else
+ {
+ DeleteFileW(Path.c_str());
+ }
+ }
+ } while (FindNextFileW(hFind, &FindData) == TRUE);
+
+ FindClose(hFind);
+ }
+
+ return true;
+}
+
+bool
+DeleteDirectories(const wchar_t* DirPath)
+{
+ return WipeDirectory(DirPath) && RemoveDirectoryW(DirPath) == TRUE;
+}
+
+bool
+CleanDirectory(const wchar_t* DirPath)
+{
+ if (std::filesystem::exists(DirPath))
+ {
+ return WipeDirectory(DirPath);
+ }
+
+ return CreateDirectories(DirPath);
+}
+
+#endif // ZEN_PLATFORM_WINDOWS
+
+bool
+CreateDirectories(const std::filesystem::path& Dir)
+{
+ while (!std::filesystem::is_directory(Dir))
+ {
+ if (Dir.has_parent_path())
+ {
+ CreateDirectories(Dir.parent_path());
+ }
+ std::error_code ErrorCode;
+ std::filesystem::create_directory(Dir, ErrorCode);
+ if (ErrorCode)
+ {
+ throw std::system_error(ErrorCode, fmt::format("Failed to create directories for '{}'", Dir.string()));
+ }
+ return true;
+ }
+ return false;
+}
+
+bool
+DeleteDirectories(const std::filesystem::path& Dir)
+{
+#if ZEN_PLATFORM_WINDOWS
+ return DeleteDirectories(Dir.c_str());
+#else
+ std::error_code ErrorCode;
+ return std::filesystem::remove_all(Dir, ErrorCode);
+#endif
+}
+
+bool
+CleanDirectory(const std::filesystem::path& Dir)
+{
+#if ZEN_PLATFORM_WINDOWS
+ return CleanDirectory(Dir.c_str());
+#else
+ if (std::filesystem::exists(Dir))
+ {
+ bool Success = true;
+
+ std::error_code ErrorCode;
+ for (const auto& Item : std::filesystem::directory_iterator(Dir))
+ {
+ Success &= std::filesystem::remove_all(Item, ErrorCode);
+ }
+
+ return Success;
+ }
+
+ return CreateDirectories(Dir);
+#endif
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+SupportsBlockRefCounting(std::filesystem::path Path)
+{
+#if ZEN_PLATFORM_WINDOWS
+ ATL::CHandle Handle(CreateFileW(Path.c_str(),
+ GENERIC_READ,
+ FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+ nullptr,
+ OPEN_EXISTING,
+ FILE_FLAG_BACKUP_SEMANTICS,
+ nullptr));
+
+ if (Handle == INVALID_HANDLE_VALUE)
+ {
+ Handle.Detach();
+ return false;
+ }
+
+ ULONG FileSystemFlags = 0;
+ if (!GetVolumeInformationByHandleW(Handle, nullptr, 0, nullptr, nullptr, /* lpFileSystemFlags */ &FileSystemFlags, nullptr, 0))
+ {
+ return false;
+ }
+
+ if (!(FileSystemFlags & FILE_SUPPORTS_BLOCK_REFCOUNTING))
+ {
+ return false;
+ }
+
+ return true;
+#else
+ ZEN_UNUSED(Path);
+ return false;
+#endif // ZEN_PLATFORM_WINDOWS
+}
+
+bool
+CloneFile(std::filesystem::path FromPath, std::filesystem::path ToPath)
+{
+#if ZEN_PLATFORM_WINDOWS
+ ATL::CHandle FromFile(CreateFileW(FromPath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr));
+ if (FromFile == INVALID_HANDLE_VALUE)
+ {
+ FromFile.Detach();
+ return false;
+ }
+
+ ULONG FileSystemFlags;
+ if (!GetVolumeInformationByHandleW(FromFile, nullptr, 0, nullptr, nullptr, /* lpFileSystemFlags */ &FileSystemFlags, nullptr, 0))
+ {
+ return false;
+ }
+ if (!(FileSystemFlags & FILE_SUPPORTS_BLOCK_REFCOUNTING))
+ {
+ SetLastError(ERROR_NOT_CAPABLE);
+ return false;
+ }
+
+ FILE_END_OF_FILE_INFO FileSize;
+ if (!GetFileSizeEx(FromFile, &FileSize.EndOfFile))
+ {
+ return false;
+ }
+
+ FILE_BASIC_INFO BasicInfo;
+ if (!GetFileInformationByHandleEx(FromFile, FileBasicInfo, &BasicInfo, sizeof BasicInfo))
+ {
+ return false;
+ }
+
+ DWORD dwBytesReturned = 0;
+ FSCTL_GET_INTEGRITY_INFORMATION_BUFFER GetIntegrityInfoBuffer;
+ if (!DeviceIoControl(FromFile,
+ FSCTL_GET_INTEGRITY_INFORMATION,
+ nullptr,
+ 0,
+ &GetIntegrityInfoBuffer,
+ sizeof GetIntegrityInfoBuffer,
+ &dwBytesReturned,
+ nullptr))
+ {
+ return false;
+ }
+
+ SetFileAttributesW(ToPath.c_str(), FILE_ATTRIBUTE_NORMAL);
+
+ ATL::CHandle TargetFile(CreateFileW(ToPath.c_str(),
+ GENERIC_READ | GENERIC_WRITE | DELETE,
+ /* no sharing */ FILE_SHARE_READ,
+ nullptr,
+ OPEN_ALWAYS,
+ 0,
+ /* hTemplateFile */ FromFile));
+
+ if (TargetFile == INVALID_HANDLE_VALUE)
+ {
+ TargetFile.Detach();
+ return false;
+ }
+
+ // Delete target file when handle is closed (we only reset this if the copy succeeds)
+ FILE_DISPOSITION_INFO FileDisposition = {TRUE};
+ if (!SetFileInformationByHandle(TargetFile, FileDispositionInfo, &FileDisposition, sizeof FileDisposition))
+ {
+ TargetFile.Close();
+ DeleteFileW(ToPath.c_str());
+ return false;
+ }
+
+ // Make file sparse so we don't end up allocating space when we change the file size
+ if (!DeviceIoControl(TargetFile, FSCTL_SET_SPARSE, nullptr, 0, nullptr, 0, &dwBytesReturned, nullptr))
+ {
+ return false;
+ }
+
+ // Copy integrity checking information
+ FSCTL_SET_INTEGRITY_INFORMATION_BUFFER IntegritySet = {GetIntegrityInfoBuffer.ChecksumAlgorithm,
+ GetIntegrityInfoBuffer.Reserved,
+ GetIntegrityInfoBuffer.Flags};
+ if (!DeviceIoControl(TargetFile, FSCTL_SET_INTEGRITY_INFORMATION, &IntegritySet, sizeof IntegritySet, nullptr, 0, nullptr, nullptr))
+ {
+ return false;
+ }
+
+ // Resize file - note that the file is sparse at this point so no additional data will be written
+ if (!SetFileInformationByHandle(TargetFile, FileEndOfFileInfo, &FileSize, sizeof FileSize))
+ {
+ return false;
+ }
+
+ constexpr auto RoundToClusterSize = [](LONG64 FileSize, ULONG ClusterSize) -> LONG64 {
+ return (FileSize + ClusterSize - 1) / ClusterSize * ClusterSize;
+ };
+ static_assert(RoundToClusterSize(5678, 4 * 1024) == 8 * 1024);
+
+ // Loop for cloning file contents. This is necessary as the API has a 32-bit size
+ // limit for some reason
+
+ const LONG64 SplitThreshold = (1LL << 32) - GetIntegrityInfoBuffer.ClusterSizeInBytes;
+
+ DUPLICATE_EXTENTS_DATA DuplicateExtentsData{.FileHandle = FromFile};
+
+ for (LONG64 CurrentByteOffset = 0,
+ RemainingBytes = RoundToClusterSize(FileSize.EndOfFile.QuadPart, GetIntegrityInfoBuffer.ClusterSizeInBytes);
+ RemainingBytes > 0;
+ CurrentByteOffset += SplitThreshold, RemainingBytes -= SplitThreshold)
+ {
+ DuplicateExtentsData.SourceFileOffset.QuadPart = CurrentByteOffset;
+ DuplicateExtentsData.TargetFileOffset.QuadPart = CurrentByteOffset;
+ DuplicateExtentsData.ByteCount.QuadPart = std::min(SplitThreshold, RemainingBytes);
+
+ if (!DeviceIoControl(TargetFile,
+ FSCTL_DUPLICATE_EXTENTS_TO_FILE,
+ &DuplicateExtentsData,
+ sizeof DuplicateExtentsData,
+ nullptr,
+ 0,
+ &dwBytesReturned,
+ nullptr))
+ {
+ return false;
+ }
+ }
+
+ // Make the file not sparse again now that we have populated the contents
+ if (!(BasicInfo.FileAttributes & FILE_ATTRIBUTE_SPARSE_FILE))
+ {
+ FILE_SET_SPARSE_BUFFER SetSparse = {FALSE};
+
+ if (!DeviceIoControl(TargetFile, FSCTL_SET_SPARSE, &SetSparse, sizeof SetSparse, nullptr, 0, &dwBytesReturned, nullptr))
+ {
+ return false;
+ }
+ }
+
+ // Update timestamps (but don't lie about the creation time)
+ BasicInfo.CreationTime.QuadPart = 0;
+ if (!SetFileInformationByHandle(TargetFile, FileBasicInfo, &BasicInfo, sizeof BasicInfo))
+ {
+ return false;
+ }
+
+ if (!FlushFileBuffers(TargetFile))
+ {
+ return false;
+ }
+
+ // Finally now everything is done - make sure the file is not deleted on close!
+
+ FileDisposition = {FALSE};
+
+ const bool AllOk = (TRUE == SetFileInformationByHandle(TargetFile, FileDispositionInfo, &FileDisposition, sizeof FileDisposition));
+
+ return AllOk;
+#elif ZEN_PLATFORM_LINUX
+# if 0
+ struct ScopedFd
+ {
+ ~ScopedFd() { close(Fd); }
+ int Fd;
+ };
+
+ // The 'from' file
+ int FromFd = open(FromPath.c_str(), O_RDONLY|O_CLOEXEC);
+ if (FromFd < 0)
+ {
+ return false;
+ }
+ ScopedFd $From = { FromFd };
+
+ // The 'to' file
+ int ToFd = open(ToPath.c_str(), O_WRONLY|O_CREAT|O_EXCL|O_CLOEXEC, 0666);
+ if (ToFd < 0)
+ {
+ return false;
+ }
+ fchmod(ToFd, 0666);
+ ScopedFd $To = { FromFd };
+
+ ioctl(ToFd, FICLONE, FromFd);
+
+ return false;
+# endif // 0
+ ZEN_UNUSED(FromPath, ToPath);
+ ZEN_ERROR("CloneFile() is not implemented on this platform");
+ return false;
+#elif ZEN_PLATFORM_MAC
+ /* clonefile() syscall if APFS */
+ ZEN_UNUSED(FromPath, ToPath);
+ ZEN_ERROR("CloneFile() is not implemented on this platform");
+ return false;
+#endif // ZEN_PLATFORM_WINDOWS
+}
+
+bool
+CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options)
+{
+ bool Success = false;
+
+ if (Options.EnableClone)
+ {
+ Success = CloneFile(FromPath.native(), ToPath.native());
+
+ if (Success)
+ {
+ return true;
+ }
+ }
+
+ if (Options.MustClone)
+ {
+ return false;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ BOOL CancelFlag = FALSE;
+ Success = !!::CopyFileExW(FromPath.c_str(),
+ ToPath.c_str(),
+ /* lpProgressRoutine */ nullptr,
+ /* lpData */ nullptr,
+ &CancelFlag,
+ /* dwCopyFlags */ 0);
+#else
+ struct ScopedFd
+ {
+ ~ScopedFd() { close(Fd); }
+ int Fd;
+ };
+
+ // From file
+ int FromFd = open(FromPath.c_str(), O_RDONLY | O_CLOEXEC);
+ if (FromFd < 0)
+ {
+ ThrowLastError(fmt::format("failed to open file {}", FromPath));
+ }
+ ScopedFd $From = {FromFd};
+
+ // To file
+ int ToFd = open(ToPath.c_str(), O_WRONLY | O_CREAT | O_EXCL | O_CLOEXEC, 0666);
+ if (ToFd < 0)
+ {
+ ThrowLastError(fmt::format("failed to create file {}", ToPath));
+ }
+ fchmod(ToFd, 0666);
+ ScopedFd $To = {ToFd};
+
+ // Copy impl
+ static const size_t BufferSize = 64 << 10;
+ void* Buffer = malloc(BufferSize);
+ while (true)
+ {
+ int BytesRead = read(FromFd, Buffer, BufferSize);
+ if (BytesRead <= 0)
+ {
+ Success = (BytesRead == 0);
+ break;
+ }
+
+ if (write(ToFd, Buffer, BytesRead) != BufferSize)
+ {
+ Success = false;
+ break;
+ }
+ }
+ free(Buffer);
+#endif // ZEN_PLATFORM_WINDOWS
+
+ if (!Success)
+ {
+ ThrowLastError("file copy failed"sv);
+ }
+
+ return true;
+}
+
+void
+WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t BufferCount)
+{
+#if ZEN_PLATFORM_WINDOWS
+ CAtlFile Outfile;
+ HRESULT hRes = Outfile.Create(Path.c_str(), GENERIC_WRITE, FILE_SHARE_READ, CREATE_ALWAYS);
+ if (hRes == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND))
+ {
+ CreateDirectories(Path.parent_path());
+
+ hRes = Outfile.Create(Path.c_str(), GENERIC_WRITE, FILE_SHARE_READ, CREATE_ALWAYS);
+ }
+
+ if (FAILED(hRes))
+ {
+ ThrowSystemException(hRes, fmt::format("File open failed for '{}'", Path).c_str());
+ }
+
+#else
+ int OpenFlags = O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC;
+ int Fd = open(Path.c_str(), OpenFlags, 0666);
+ if (Fd < 0)
+ {
+ zen::CreateDirectories(Path.parent_path());
+ Fd = open(Path.c_str(), OpenFlags, 0666);
+ }
+
+ if (Fd < 0)
+ {
+ ThrowLastError(fmt::format("File open failed for '{}'", Path));
+ }
+
+ fchmod(Fd, 0666);
+#endif
+
+ // TODO: this should be block-enlightened
+
+ for (size_t i = 0; i < BufferCount; ++i)
+ {
+ uint64_t WriteSize = Data[i]->Size();
+ const void* DataPtr = Data[i]->Data();
+
+ while (WriteSize)
+ {
+ const uint64_t ChunkSize = Min<uint64_t>(WriteSize, uint64_t(2) * 1024 * 1024 * 1024);
+
+#if ZEN_PLATFORM_WINDOWS
+ hRes = Outfile.Write(DataPtr, gsl::narrow_cast<uint32_t>(WriteSize));
+ if (FAILED(hRes))
+ {
+ ThrowSystemException(hRes, fmt::format("File write failed for '{}'", Path).c_str());
+ }
+#else
+ if (write(Fd, DataPtr, WriteSize) != int64_t(WriteSize))
+ {
+ ThrowLastError(fmt::format("File write failed for '{}'", Path));
+ }
+#endif // ZEN_PLATFORM_WINDOWS
+
+ WriteSize -= ChunkSize;
+ DataPtr = reinterpret_cast<const uint8_t*>(DataPtr) + ChunkSize;
+ }
+ }
+
+#if !ZEN_PLATFORM_WINDOWS
+ close(Fd);
+#endif
+}
+
+void
+WriteFile(std::filesystem::path Path, IoBuffer Data)
+{
+ const IoBuffer* const DataPtr = &Data;
+
+ WriteFile(Path, &DataPtr, 1);
+}
+
+IoBuffer
+FileContents::Flatten()
+{
+ if (Data.size() == 1)
+ {
+ return Data[0];
+ }
+ else if (Data.empty())
+ {
+ return {};
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED();
+ }
+}
+
+FileContents
+ReadStdIn()
+{
+ BinaryWriter Writer;
+
+ do
+ {
+ uint8_t ReadBuffer[1024];
+
+ size_t BytesRead = fread(ReadBuffer, 1, sizeof ReadBuffer, stdin);
+ Writer.Write(ReadBuffer, BytesRead);
+ } while (!feof(stdin));
+
+ FileContents Contents;
+ Contents.Data.emplace_back(IoBuffer(IoBuffer::Clone, Writer.GetData(), Writer.GetSize()));
+
+ return Contents;
+}
+
+FileContents
+ReadFile(std::filesystem::path Path)
+{
+ uint64_t FileSizeBytes;
+ void* Handle;
+
+#if ZEN_PLATFORM_WINDOWS
+ ATL::CHandle FromFile(CreateFileW(Path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr));
+ if (FromFile == INVALID_HANDLE_VALUE)
+ {
+ FromFile.Detach();
+ return FileContents{.ErrorCode = std::error_code(::GetLastError(), std::system_category())};
+ }
+
+ FILE_END_OF_FILE_INFO FileSize;
+ if (!GetFileSizeEx(FromFile, &FileSize.EndOfFile))
+ {
+ return FileContents{.ErrorCode = std::error_code(::GetLastError(), std::system_category())};
+ }
+
+ FileSizeBytes = FileSize.EndOfFile.QuadPart;
+ Handle = FromFile.Detach();
+#else
+ int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC);
+ if (Fd < 0)
+ {
+ FileContents Ret;
+ Ret.ErrorCode = std::error_code(zen::GetLastError(), std::system_category());
+ return Ret;
+ }
+
+ static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
+ struct stat Stat;
+ fstat(Fd, &Stat);
+
+ FileSizeBytes = Stat.st_size;
+ Handle = (void*)uintptr_t(Fd);
+#endif
+
+ FileContents Contents;
+ Contents.Data.emplace_back(IoBuffer(IoBuffer::File, Handle, 0, FileSizeBytes));
+ return Contents;
+}
+
+bool
+ScanFile(std::filesystem::path Path, const uint64_t ChunkSize, std::function<void(const void* Data, size_t Size)>&& ProcessFunc)
+{
+#if ZEN_PLATFORM_WINDOWS
+ ATL::CHandle FromFile(CreateFileW(Path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr));
+ if (FromFile == INVALID_HANDLE_VALUE)
+ {
+ FromFile.Detach();
+ return false;
+ }
+
+ std::vector<uint8_t> ReadBuffer(ChunkSize);
+
+ for (;;)
+ {
+ DWORD dwBytesRead = 0;
+ BOOL Success = ::ReadFile(FromFile, ReadBuffer.data(), (DWORD)ReadBuffer.size(), &dwBytesRead, nullptr);
+
+ if (!Success)
+ {
+ throw std::system_error(std::error_code(::GetLastError(), std::system_category()), "file scan failed");
+ }
+
+ if (dwBytesRead == 0)
+ break;
+
+ ProcessFunc(ReadBuffer.data(), dwBytesRead);
+ }
+#else
+ int Fd = open(Path.c_str(), O_RDONLY | O_CLOEXEC);
+ if (Fd < 0)
+ {
+ return false;
+ }
+
+ bool Success = true;
+
+ void* Buffer = malloc(ChunkSize);
+ while (true)
+ {
+ int BytesRead = read(Fd, Buffer, ChunkSize);
+ if (BytesRead < 0)
+ {
+ Success = false;
+ break;
+ }
+
+ if (BytesRead == 0)
+ {
+ break;
+ }
+
+ ProcessFunc(Buffer, BytesRead);
+ }
+
+ free(Buffer);
+ close(Fd);
+
+ if (!Success)
+ {
+ ThrowLastError("file scan failed");
+ }
+#endif // ZEN_PLATFORM_WINDOWS
+
+ return true;
+}
+
+void
+PathToUtf8(const std::filesystem::path& Path, StringBuilderBase& Out)
+{
+#if ZEN_PLATFORM_WINDOWS
+ WideToUtf8(Path.native().c_str(), Out);
+#else
+ Out << Path.c_str();
+#endif
+}
+
+std::string
+PathToUtf8(const std::filesystem::path& Path)
+{
+#if ZEN_PLATFORM_WINDOWS
+ return WideToUtf8(Path.native().c_str());
+#else
+ return Path.string();
+#endif
+}
+
+DiskSpace
+DiskSpaceInfo(std::filesystem::path Directory, std::error_code& Error)
+{
+ using namespace std::filesystem;
+
+ space_info SpaceInfo = space(Directory, Error);
+ if (Error)
+ {
+ return {};
+ }
+
+ return {
+ .Free = uint64_t(SpaceInfo.available),
+ .Total = uint64_t(SpaceInfo.capacity),
+ };
+}
+
+void
+FileSystemTraversal::TraverseFileSystem(const std::filesystem::path& RootDir, TreeVisitor& Visitor)
+{
+#if ZEN_PLATFORM_WINDOWS
+ uint64_t FileInfoBuffer[8 * 1024];
+
+ FILE_INFO_BY_HANDLE_CLASS FibClass = FileIdBothDirectoryRestartInfo;
+ bool Continue = true;
+
+ CAtlFile RootDirHandle;
+ HRESULT hRes =
+ RootDirHandle.Create(RootDir.c_str(), GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_WRITE, OPEN_EXISTING, FILE_FLAG_BACKUP_SEMANTICS);
+
+ if (FAILED(hRes))
+ {
+ ThrowSystemException(hRes, "Failed to open handle to volume root");
+ }
+
+ while (Continue)
+ {
+ BOOL Success = GetFileInformationByHandleEx(RootDirHandle, FibClass, FileInfoBuffer, sizeof FileInfoBuffer);
+ FibClass = FileIdBothDirectoryInfo; // Set up for next iteration
+
+ uint64_t EntryOffset = 0;
+
+ if (!Success)
+ {
+ DWORD LastError = GetLastError();
+
+ if (LastError == ERROR_NO_MORE_FILES)
+ {
+ break;
+ }
+
+ throw std::system_error(std::error_code(LastError, std::system_category()), "file system traversal error");
+ }
+
+ for (;;)
+ {
+ const FILE_ID_BOTH_DIR_INFO* DirInfo =
+ reinterpret_cast<const FILE_ID_BOTH_DIR_INFO*>(reinterpret_cast<const uint8_t*>(FileInfoBuffer) + EntryOffset);
+
+ std::wstring_view FileName(DirInfo->FileName, DirInfo->FileNameLength / sizeof(wchar_t));
+
+ if (DirInfo->FileAttributes & FILE_ATTRIBUTE_DIRECTORY)
+ {
+ if (FileName == L"."sv || FileName == L".."sv)
+ {
+ // Not very interesting
+ }
+ else
+ {
+ const bool ShouldDescend = Visitor.VisitDirectory(RootDir, FileName);
+
+ if (ShouldDescend)
+ {
+ // Note that this recursion combined with the buffer could
+ // blow the stack, we should consider a different strategy
+
+ std::filesystem::path FullPath = RootDir / FileName;
+
+ TraverseFileSystem(FullPath, Visitor);
+ }
+ }
+ }
+ else if (DirInfo->FileAttributes & FILE_ATTRIBUTE_DEVICE)
+ {
+ ZEN_WARN("encountered device node during file system traversal: '{}' found in '{}'", WideToUtf8(FileName), RootDir);
+ }
+ else
+ {
+ Visitor.VisitFile(RootDir, FileName, DirInfo->EndOfFile.QuadPart);
+ }
+
+ const uint64_t NextOffset = DirInfo->NextEntryOffset;
+
+ if (NextOffset == 0)
+ {
+ break;
+ }
+
+ EntryOffset += NextOffset;
+ }
+ }
+#else
+ /* Could also implement this using Linux's getdents() syscall */
+
+ DIR* Dir = opendir(RootDir.c_str());
+ if (Dir == nullptr)
+ {
+ ThrowLastError(fmt::format("Failed to open directory for traversal: {}", RootDir.c_str()));
+ }
+
+ for (struct dirent* Entry; (Entry = readdir(Dir));)
+ {
+ const char* FileName = Entry->d_name;
+
+ struct stat Stat;
+ std::filesystem::path FullPath = RootDir / FileName;
+ stat(FullPath.c_str(), &Stat);
+
+ if (S_ISDIR(Stat.st_mode))
+ {
+ if (strcmp(FileName, ".") == 0 || strcmp(FileName, "..") == 0)
+ {
+ /* nop */
+ }
+ else if (Visitor.VisitDirectory(RootDir, FileName))
+ {
+ TraverseFileSystem(FullPath, Visitor);
+ }
+ }
+ else if (S_ISREG(Stat.st_mode))
+ {
+ Visitor.VisitFile(RootDir, FileName, Stat.st_size);
+ }
+ else
+ {
+ ZEN_WARN("encountered non-regular file during file system traversal ({}): {} found in {}",
+ Stat.st_mode,
+ FileName,
+ RootDir.c_str());
+ }
+ }
+
+ closedir(Dir);
+#endif // ZEN_PLATFORM_WINDOWS
+}
+
+std::filesystem::path
+PathFromHandle(void* NativeHandle)
+{
+#if ZEN_PLATFORM_WINDOWS
+ if (NativeHandle == nullptr || NativeHandle == INVALID_HANDLE_VALUE)
+ {
+ return std::filesystem::path();
+ }
+
+ auto GetFinalPathNameByHandleWRetry = [](HANDLE hFile, LPWSTR lpszFilePath, DWORD cchFilePath, DWORD dwFlags) -> DWORD {
+ while (true)
+ {
+ DWORD Res = GetFinalPathNameByHandleW(hFile, lpszFilePath, cchFilePath, dwFlags);
+ if (Res == 0)
+ {
+ DWORD LastError = zen::GetLastError();
+ // Under heavy concurrent loads we might get access denied on a file handle while trying to get path name.
+ // Retry if that is the case.
+ if (LastError != ERROR_ACCESS_DENIED)
+ {
+ ThrowSystemError(LastError, fmt::format("failed to get path from file handle {}", hFile));
+ }
+ // Retry
+ continue;
+ }
+ ZEN_ASSERT(Res != 1); // We don't accept empty path names
+ return Res;
+ }
+ };
+
+ static const DWORD PathDataSize = 512;
+ wchar_t PathData[PathDataSize];
+ DWORD RequiredLengthIncludingNul = GetFinalPathNameByHandleWRetry(NativeHandle, PathData, PathDataSize, FILE_NAME_OPENED);
+ if (RequiredLengthIncludingNul == 0)
+ {
+ ThrowLastError(fmt::format("failed to get path from file handle {}", NativeHandle));
+ }
+
+ if (RequiredLengthIncludingNul < PathDataSize)
+ {
+ std::wstring FullPath(PathData, gsl::narrow<size_t>(RequiredLengthIncludingNul));
+ return FullPath;
+ }
+
+ std::wstring FullPath;
+ FullPath.resize(RequiredLengthIncludingNul - 1);
+
+ const DWORD FinalLength = GetFinalPathNameByHandleWRetry(NativeHandle, FullPath.data(), RequiredLengthIncludingNul, FILE_NAME_OPENED);
+ ZEN_UNUSED(FinalLength);
+ return FullPath;
+
+#elif ZEN_PLATFORM_LINUX
+ char Link[PATH_MAX];
+ char Path[64];
+
+ sprintf(Path, "/proc/self/fd/%d", int(uintptr_t(NativeHandle)));
+ ssize_t BytesRead = readlink(Path, Link, sizeof(Link) - 1);
+ if (BytesRead <= 0)
+ {
+ return std::filesystem::path();
+ }
+
+ Link[BytesRead] = '\0';
+ return Link;
+#elif ZEN_PLATFORM_MAC
+ int Fd = int(uintptr_t(NativeHandle));
+ char Path[MAXPATHLEN];
+ if (fcntl(Fd, F_GETPATH, Path) < 0)
+ {
+ return std::filesystem::path();
+ }
+
+ return Path;
+#endif // ZEN_PLATFORM_WINDOWS
+}
+
+std::filesystem::path
+GetRunningExecutablePath()
+{
+#if ZEN_PLATFORM_WINDOWS
+ TCHAR ExePath[MAX_PATH];
+ DWORD PathLength = GetModuleFileName(NULL, ExePath, ZEN_ARRAY_COUNT(ExePath));
+
+ return {std::wstring_view(ExePath, PathLength)};
+#elif ZEN_PLATFORM_LINUX
+ char Link[256];
+ ssize_t BytesRead = readlink("/proc/self/exe", Link, sizeof(Link) - 1);
+ if (BytesRead < 0)
+ return {};
+
+ Link[BytesRead] = '\0';
+ return Link;
+#elif ZEN_PLATFORM_MAC
+ char Buffer[PROC_PIDPATHINFO_MAXSIZE];
+
+ int SelfPid = GetCurrentProcessId();
+ if (proc_pidpath(SelfPid, Buffer, sizeof(Buffer)) <= 0)
+ return {};
+
+ return Buffer;
+#endif // ZEN_PLATFORM_WINDOWS
+}
+
+void
+MaximizeOpenFileCount()
+{
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ struct rlimit Limit;
+ int Error = getrlimit(RLIMIT_NOFILE, &Limit);
+ if (Error)
+ {
+ ZEN_WARN("failed getting rlimit RLIMIT_NOFILE, reason '{}'", zen::MakeErrorCode(Error).message());
+ }
+ else
+ {
+ struct rlimit NewLimit = Limit;
+ NewLimit.rlim_cur = NewLimit.rlim_max;
+ ZEN_INFO("changing RLIMIT_NOFILE from rlim_cur = {}, rlim_max {} to rlim_cur = {}, rlim_max {}",
+ Limit.rlim_cur,
+ Limit.rlim_max,
+ NewLimit.rlim_cur,
+ NewLimit.rlim_max);
+
+ Error = setrlimit(RLIMIT_NOFILE, &NewLimit);
+ if (Error != 0)
+ {
+ ZEN_WARN("failed to set RLIMIT_NOFILE limits from rlim_cur = {}, rlim_max {} to rlim_cur = {}, rlim_max {}, reason '{}'",
+ Limit.rlim_cur,
+ Limit.rlim_max,
+ NewLimit.rlim_cur,
+ NewLimit.rlim_max,
+ zen::MakeErrorCode(Error).message());
+ }
+ }
+#endif
+}
+
+void
+GetDirectoryContent(const std::filesystem::path& RootDir, uint8_t Flags, DirectoryContent& OutContent)
+{
+ FileSystemTraversal Traversal;
+ struct Visitor : public FileSystemTraversal::TreeVisitor
+ {
+ Visitor(uint8_t Flags, DirectoryContent& OutContent) : Flags(Flags), Content(OutContent) {}
+
+ virtual void VisitFile([[maybe_unused]] const std::filesystem::path& Parent,
+ [[maybe_unused]] const path_view& File,
+ [[maybe_unused]] uint64_t FileSize) override
+ {
+ if (Flags & DirectoryContent::IncludeFilesFlag)
+ {
+ Content.Files.push_back(Parent / File);
+ }
+ }
+
+ virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent, const path_view& DirectoryName) override
+ {
+ if (Flags & DirectoryContent::IncludeDirsFlag)
+ {
+ Content.Directories.push_back(Parent / DirectoryName);
+ }
+ return (Flags & DirectoryContent::RecursiveFlag) != 0;
+ }
+
+ const uint8_t Flags;
+ DirectoryContent& Content;
+ } Visit(Flags, OutContent);
+
+ Traversal.TraverseFileSystem(RootDir, Visit);
+}
+
+std::string
+GetEnvVariable(std::string_view VariableName)
+{
+ ZEN_ASSERT(!VariableName.empty());
+#if ZEN_PLATFORM_WINDOWS
+
+ CHAR EnvVariableBuffer[1023 + 1];
+ DWORD RESULT = GetEnvironmentVariableA(std::string(VariableName).c_str(), EnvVariableBuffer, sizeof(EnvVariableBuffer));
+ if (RESULT > 0 && RESULT < sizeof(EnvVariableBuffer))
+ {
+ return std::string(EnvVariableBuffer);
+ }
+#endif
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ char* EnvVariable = getenv(std::string(VariableName).c_str());
+ if (EnvVariable)
+ {
+ return std::string(EnvVariable);
+ }
+#endif
+ return "";
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+void
+filesystem_forcelink()
+{
+}
+
+TEST_CASE("filesystem")
+{
+ using namespace std::filesystem;
+
+ // GetExePath -- this is not a great test as it's so dependent on where the this code gets linked in
+ path BinPath = GetRunningExecutablePath();
+ const bool ExpectedExe = PathToUtf8(BinPath.stem().native()).ends_with("-test"sv) || BinPath.stem() == "zenserver";
+ CHECK(ExpectedExe);
+ CHECK(is_regular_file(BinPath));
+
+ // PathFromHandle
+ void* Handle;
+# if ZEN_PLATFORM_WINDOWS
+ Handle = CreateFileW(BinPath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr);
+ CHECK(Handle != INVALID_HANDLE_VALUE);
+# else
+ int Fd = open(BinPath.c_str(), O_RDONLY | O_CLOEXEC);
+ CHECK(Fd >= 0);
+ Handle = (void*)uintptr_t(Fd);
+# endif
+
+ auto FromHandle = PathFromHandle((void*)uintptr_t(Handle));
+ CHECK(equivalent(FromHandle, BinPath));
+
+# if ZEN_PLATFORM_WINDOWS
+ CloseHandle(Handle);
+# else
+ close(int(uintptr_t(Handle)));
+# endif
+
+ // Traversal
+ struct : public FileSystemTraversal::TreeVisitor
+ {
+ virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t) override
+ {
+ bFoundExpected |= std::filesystem::equivalent(Parent / File, Expected);
+ }
+
+ virtual bool VisitDirectory(const std::filesystem::path&, const path_view&) override { return true; }
+
+ bool bFoundExpected = false;
+ std::filesystem::path Expected;
+ } Visitor;
+ Visitor.Expected = BinPath;
+
+ FileSystemTraversal().TraverseFileSystem(BinPath.parent_path().parent_path(), Visitor);
+ CHECK(Visitor.bFoundExpected);
+
+ // Scan/read file
+ FileContents BinRead = ReadFile(BinPath);
+ std::vector<uint8_t> BinScan;
+ ScanFile(BinPath, 16 << 10, [&](const void* Data, size_t Size) {
+ const auto* Ptr = (uint8_t*)Data;
+ BinScan.insert(BinScan.end(), Ptr, Ptr + Size);
+ });
+ CHECK_EQ(BinRead.Data.size(), 1);
+ CHECK_EQ(BinScan.size(), BinRead.Data[0].GetSize());
+}
+
+TEST_CASE("WriteFile")
+{
+ std::filesystem::path TempFile = GetRunningExecutablePath().parent_path();
+ TempFile /= "write_file_test";
+
+ uint64_t Magics[] = {
+ 0x0'a9e'a9e'a9e'a9e'a9e,
+ 0x0'493'493'493'493'493,
+ };
+
+ struct
+ {
+ const void* Data;
+ size_t Size;
+ } MagicTests[] = {
+ {
+ Magics,
+ sizeof(Magics),
+ },
+ {
+ Magics + 1,
+ sizeof(Magics[0]),
+ },
+ };
+ for (auto& MagicTest : MagicTests)
+ {
+ WriteFile(TempFile, IoBuffer(IoBuffer::Wrap, MagicTest.Data, MagicTest.Size));
+
+ FileContents MagicsReadback = ReadFile(TempFile);
+ CHECK_EQ(MagicsReadback.Data.size(), 1);
+ CHECK_EQ(MagicsReadback.Data[0].GetSize(), MagicTest.Size);
+ CHECK_EQ(memcmp(MagicTest.Data, MagicsReadback.Data[0].Data(), MagicTest.Size), 0);
+ }
+
+ std::filesystem::remove(TempFile);
+}
+
+TEST_CASE("DiskSpaceInfo")
+{
+ std::filesystem::path BinPath = GetRunningExecutablePath();
+
+ DiskSpace Space = {};
+
+ std::error_code Error;
+ Space = DiskSpaceInfo(BinPath, Error);
+ CHECK(!Error);
+
+ bool Okay = DiskSpaceInfo(BinPath, Space);
+ CHECK(Okay);
+
+ CHECK(int64_t(Space.Total) > 0);
+ CHECK(int64_t(Space.Free) > 0); // Hopefully there's at least one byte free
+}
+
+TEST_CASE("PathBuilder")
+{
+# if ZEN_PLATFORM_WINDOWS
+ const char* foo_bar = "/foo\\bar";
+# else
+ const char* foo_bar = "/foo/bar";
+# endif
+
+ ExtendablePathBuilder<32> Path;
+ for (const char* Prefix : {"/foo", "/foo/"})
+ {
+ Path.Reset();
+ Path.Append(Prefix);
+ Path /= "bar";
+ CHECK(Path.ToPath() == foo_bar);
+ }
+
+ using fspath = std::filesystem::path;
+
+ Path.Reset();
+ Path.Append(fspath("/foo/"));
+ Path /= (fspath("bar"));
+ CHECK(Path.ToPath() == foo_bar);
+
+# if ZEN_PLATFORM_WINDOWS
+ Path.Reset();
+ Path.Append(fspath(L"/\u0119oo/"));
+ Path /= L"bar";
+ printf("%ls\n", Path.ToPath().c_str());
+ CHECK(Path.ToView() == L"/\u0119oo/bar");
+ CHECK(Path.ToPath() == L"\\\u0119oo\\bar");
+# endif
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/atomic.h b/src/zencore/include/zencore/atomic.h
new file mode 100644
index 000000000..bf549e21d
--- /dev/null
+++ b/src/zencore/include/zencore/atomic.h
@@ -0,0 +1,74 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if ZEN_COMPILER_MSC
+# include <intrin.h>
+#else
+# include <atomic>
+#endif
+
+#include <cinttypes>
+
+namespace zen {
+
+inline uint32_t
+AtomicIncrement(volatile uint32_t& value)
+{
+#if ZEN_COMPILER_MSC
+ return _InterlockedIncrement((long volatile*)&value);
+#else
+ return ((std::atomic<uint32_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1;
+#endif
+}
+inline uint32_t
+AtomicDecrement(volatile uint32_t& value)
+{
+#if ZEN_COMPILER_MSC
+ return _InterlockedDecrement((long volatile*)&value);
+#else
+ return ((std::atomic<uint32_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1;
+#endif
+}
+
+inline uint64_t
+AtomicIncrement(volatile uint64_t& value)
+{
+#if ZEN_COMPILER_MSC
+ return _InterlockedIncrement64((__int64 volatile*)&value);
+#else
+ return ((std::atomic<uint64_t>*)(&value))->fetch_add(1, std::memory_order_seq_cst) + 1;
+#endif
+}
+inline uint64_t
+AtomicDecrement(volatile uint64_t& value)
+{
+#if ZEN_COMPILER_MSC
+ return _InterlockedDecrement64((__int64 volatile*)&value);
+#else
+ return ((std::atomic<uint64_t>*)(&value))->fetch_sub(1, std::memory_order_seq_cst) - 1;
+#endif
+}
+
+inline uint32_t
+AtomicAdd(volatile uint32_t& value, uint32_t amount)
+{
+#if ZEN_COMPILER_MSC
+ return _InterlockedExchangeAdd((long volatile*)&value, amount);
+#else
+ return ((std::atomic<uint32_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst);
+#endif
+}
+inline uint64_t
+AtomicAdd(volatile uint64_t& value, uint64_t amount)
+{
+#if ZEN_COMPILER_MSC
+ return _InterlockedExchangeAdd64((__int64 volatile*)&value, amount);
+#else
+ return ((std::atomic<uint64_t>*)(&value))->fetch_add(amount, std::memory_order_seq_cst);
+#endif
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/base64.h b/src/zencore/include/zencore/base64.h
new file mode 100644
index 000000000..4d78b085f
--- /dev/null
+++ b/src/zencore/include/zencore/base64.h
@@ -0,0 +1,17 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+namespace zen {
+
+struct Base64
+{
+ template<typename CharType>
+ static uint32_t Encode(const uint8_t* Source, uint32_t Length, CharType* Dest);
+
+ static inline constexpr int32_t GetEncodedDataSize(uint32_t Size) { return ((Size + 2) / 3) * 4; }
+};
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/blake3.h b/src/zencore/include/zencore/blake3.h
new file mode 100644
index 000000000..b31b710a7
--- /dev/null
+++ b/src/zencore/include/zencore/blake3.h
@@ -0,0 +1,62 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <cinttypes>
+#include <compare>
+#include <cstring>
+
+#include <zencore/memory.h>
+
+namespace zen {
+
+class CompositeBuffer;
+class StringBuilderBase;
+
+/**
+ * BLAKE3 hash - 256 bits
+ */
+struct BLAKE3
+{
+ uint8_t Hash[32];
+
+ inline auto operator<=>(const BLAKE3& Rhs) const = default;
+
+ static BLAKE3 HashBuffer(const CompositeBuffer& Buffer);
+ static BLAKE3 HashMemory(const void* Data, size_t ByteCount);
+ static BLAKE3 FromHexString(const char* String);
+ const char* ToHexString(char* OutString /* 40 characters + NUL terminator */) const;
+ StringBuilderBase& ToHexString(StringBuilderBase& OutBuilder) const;
+
+ static const int StringLength = 64;
+ typedef char String_t[StringLength + 1];
+
+ static BLAKE3 Zero; // Initialized to all zeroes
+
+ struct Hasher
+ {
+ size_t operator()(const BLAKE3& v) const
+ {
+ size_t h;
+ memcpy(&h, v.Hash, sizeof h);
+ return h;
+ }
+ };
+};
+
+struct BLAKE3Stream
+{
+ BLAKE3Stream();
+
+ void Reset(); // Begin streaming hash compute (not needed on freshly constructed instance)
+ BLAKE3Stream& Append(const void* data, size_t byteCount); // Append another chunk
+ BLAKE3Stream& Append(MemoryView DataView) { return Append(DataView.GetData(), DataView.GetSize()); } // Append another chunk
+ BLAKE3 GetHash(); // Obtain final hash. If you wish to reuse the instance call reset()
+
+private:
+ alignas(16) uint8_t m_HashState[2048];
+};
+
+void blake3_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/blockingqueue.h b/src/zencore/include/zencore/blockingqueue.h
new file mode 100644
index 000000000..f92df5a54
--- /dev/null
+++ b/src/zencore/include/zencore/blockingqueue.h
@@ -0,0 +1,76 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <atomic>
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+
+namespace zen {
+
+template<typename T>
+class BlockingQueue
+{
+public:
+ BlockingQueue() = default;
+
+ ~BlockingQueue() { CompleteAdding(); }
+
+ void Enqueue(T&& Item)
+ {
+ {
+ std::lock_guard Lock(m_Lock);
+ m_Queue.emplace_back(std::move(Item));
+ m_Size++;
+ }
+
+ m_NewItemSignal.notify_one();
+ }
+
+ bool WaitAndDequeue(T& Item)
+ {
+ if (m_CompleteAdding.load())
+ {
+ return false;
+ }
+
+ std::unique_lock Lock(m_Lock);
+ m_NewItemSignal.wait(Lock, [this]() { return !m_Queue.empty() || m_CompleteAdding.load(); });
+
+ if (!m_Queue.empty())
+ {
+ Item = std::move(m_Queue.front());
+ m_Queue.pop_front();
+ m_Size--;
+
+ return true;
+ }
+
+ return false;
+ }
+
+ void CompleteAdding()
+ {
+ if (!m_CompleteAdding.load())
+ {
+ m_CompleteAdding.store(true);
+ m_NewItemSignal.notify_all();
+ }
+ }
+
+ std::size_t Size() const
+ {
+ std::unique_lock Lock(m_Lock);
+ return m_Queue.size();
+ }
+
+private:
+ mutable std::mutex m_Lock;
+ std::condition_variable m_NewItemSignal;
+ std::deque<T> m_Queue;
+ std::atomic_bool m_CompleteAdding{false};
+ std::atomic_uint32_t m_Size;
+};
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/compactbinary.h b/src/zencore/include/zencore/compactbinary.h
new file mode 100644
index 000000000..b546f97aa
--- /dev/null
+++ b/src/zencore/include/zencore/compactbinary.h
@@ -0,0 +1,1475 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <zencore/enumflags.h>
+#include <zencore/intmath.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/memory.h>
+#include <zencore/meta.h>
+#include <zencore/sharedbuffer.h>
+#include <zencore/uid.h>
+#include <zencore/varint.h>
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <type_traits>
+#include <vector>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+class CbObjectView;
+class CbArrayView;
+class BinaryReader;
+class BinaryWriter;
+class CompressedBuffer;
+class CbValue;
+
+class DateTime
+{
+public:
+ explicit DateTime(uint64_t InTicks) : Ticks(InTicks) {}
+ inline DateTime(int Year, int Month, int Day, int Hours = 0, int Minutes = 0, int Seconds = 0, int MilliSeconds = 0)
+ {
+ Set(Year, Month, Day, Hours, Minutes, Seconds, MilliSeconds);
+ }
+
+ inline uint64_t GetTicks() const { return Ticks; }
+
+ static uint64_t NowTicks();
+ static DateTime Now();
+
+ int GetYear() const;
+ int GetMonth() const;
+ int GetDay() const;
+ int GetHour() const;
+ int GetHour12() const;
+ int GetMinute() const;
+ int GetSecond() const;
+ int GetMillisecond() const;
+ void GetDate(int& Year, int& Month, int& Day) const;
+
+ inline bool operator==(const DateTime& Rhs) const { return Ticks == Rhs.Ticks; }
+ inline auto operator<=>(const DateTime& Rhs) const { return Ticks - Rhs.Ticks; }
+
+ std::string ToString(const char* Format) const;
+ std::string ToIso8601() const;
+
+private:
+ void Set(int Year, int Month, int Day, int Hours, int Minutes, int Seconds, int MilliSecond);
+ uint64_t Ticks; // 1 tick == 0.1us == 100ns, epoch == Jan 1st 0001
+};
+
+class TimeSpan
+{
+public:
+ explicit TimeSpan(uint64_t InTicks) : Ticks(InTicks) {}
+ inline TimeSpan(int Hours, int Minutes, int Seconds) { Set(0, Hours, Minutes, Seconds, 0); }
+ inline TimeSpan(int Days, int Hours, int Minutes, int Seconds) { Set(Days, Hours, Minutes, Seconds, 0); }
+ inline TimeSpan(int Days, int Hours, int Minutes, int Seconds, int Nanos) { Set(Days, Hours, Minutes, Seconds, Nanos); }
+
+ inline uint64_t GetTicks() const { return Ticks; }
+ inline bool operator==(const TimeSpan& Rhs) const { return Ticks == Rhs.Ticks; }
+ inline auto operator<=>(const TimeSpan& Rhs) const { return Ticks - Rhs.Ticks; }
+
+ /**
+ * Time span related constants.
+ */
+
+ /** The maximum number of ticks that can be represented in FTimespan. */
+ static constexpr int64_t MaxTicks = 9223372036854775807;
+
+ /** The minimum number of ticks that can be represented in FTimespan. */
+ static constexpr int64_t MinTicks = -9223372036854775807 - 1;
+
+ /** The number of nanoseconds per tick. */
+ static constexpr int64_t NanosecondsPerTick = 100;
+
+ /** The number of timespan ticks per day. */
+ static constexpr int64_t TicksPerDay = 864000000000;
+
+ /** The number of timespan ticks per hour. */
+ static constexpr int64_t TicksPerHour = 36000000000;
+
+ /** The number of timespan ticks per microsecond. */
+ static constexpr int64_t TicksPerMicrosecond = 10;
+
+ /** The number of timespan ticks per millisecond. */
+ static constexpr int64_t TicksPerMillisecond = 10000;
+
+ /** The number of timespan ticks per minute. */
+ static constexpr int64_t TicksPerMinute = 600000000;
+
+ /** The number of timespan ticks per second. */
+ static constexpr int64_t TicksPerSecond = 10000000;
+
+ /** The number of timespan ticks per week. */
+ static constexpr int64_t TicksPerWeek = 6048000000000;
+
+ /** The number of timespan ticks per year (365 days, not accounting for leap years). */
+ static constexpr int64_t TicksPerYear = 365 * TicksPerDay;
+
+ int GetFractionTicks() const { return (int)(Ticks % TicksPerSecond); }
+
+ int GetFractionMicro() const { return (int)((Ticks % TicksPerSecond) / TicksPerMicrosecond); }
+
+ int GetFractionMilli() const { return (int)((Ticks % TicksPerSecond) / TicksPerMillisecond); }
+
+ int GetFractionNano() const { return (int)((Ticks % TicksPerSecond) * NanosecondsPerTick); }
+
+ int GetDays() const { return (int)(Ticks / TicksPerDay); }
+
+ int GetHours() const { return (int)((Ticks / TicksPerHour) % 24); }
+
+ int GetMinutes() const { return (int)((Ticks / TicksPerMinute) % 60); }
+
+ int GetSeconds() const { return (int)((Ticks / TicksPerSecond) % 60); }
+
+ ZENCORE_API std::string ToString(const char* Format) const;
+ ZENCORE_API std::string ToString() const;
+
+private:
+ void Set(int Days, int Hours, int Minutes, int Seconds, int FractionNano);
+
+ uint64_t Ticks;
+};
+
+struct Guid
+{
+ uint32_t A, B, C, D;
+
+ StringBuilderBase& ToString(StringBuilderBase& OutString) const;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * Field types and flags for CbField.
+ *
+ * This is a private type and is only declared here to enable inline use below.
+ *
+ * DO NOT CHANGE THE VALUE OF ANY MEMBERS OF THIS ENUM!
+ * BACKWARD COMPATIBILITY REQUIRES THAT THESE VALUES BE FIXED!
+ * SERIALIZATION USES HARD-CODED CONSTANTS BASED ON THESE VALUES!
+ */
+enum class CbFieldType : uint8_t
+{
+ /** A field type that does not occur in a valid object. */
+ None = 0x00,
+
+ /** Null. Payload is empty. */
+ Null = 0x01,
+
+ /**
+ * Object is an array of fields with unique non-empty names.
+ *
+ * Payload is a VarUInt byte count for the encoded fields followed by the fields.
+ */
+ Object = 0x02,
+ /**
+ * UniformObject is an array of fields with the same field types and unique non-empty names.
+ *
+ * Payload is a VarUInt byte count for the encoded fields followed by the fields.
+ */
+ UniformObject = 0x03,
+
+ /**
+ * Array is an array of fields with no name that may be of different types.
+ *
+ * Payload is a VarUInt byte count, followed by a VarUInt item count, followed by the fields.
+ */
+ Array = 0x04,
+ /**
+ * UniformArray is an array of fields with no name and with the same field type.
+ *
+ * Payload is a VarUInt byte count, followed by a VarUInt item count, followed by field type,
+ * followed by the fields without their field type.
+ */
+ UniformArray = 0x05,
+
+ /** Binary. Payload is a VarUInt byte count followed by the data. */
+ Binary = 0x06,
+
+ /** String in UTF-8. Payload is a VarUInt byte count then an unterminated UTF-8 string. */
+ String = 0x07,
+
+ /**
+ * Non-negative integer with the range of a 64-bit unsigned integer.
+ *
+ * Payload is the value encoded as a VarUInt.
+ */
+ IntegerPositive = 0x08,
+ /**
+ * Negative integer with the range of a 64-bit signed integer.
+ *
+ * Payload is the ones' complement of the value encoded as a VarUInt.
+ */
+ IntegerNegative = 0x09,
+
+ /** Single precision float. Payload is one big endian IEEE 754 binary32 float. */
+ Float32 = 0x0a,
+ /** Double precision float. Payload is one big endian IEEE 754 binary64 float. */
+ Float64 = 0x0b,
+
+ /** Boolean false value. Payload is empty. */
+ BoolFalse = 0x0c,
+ /** Boolean true value. Payload is empty. */
+ BoolTrue = 0x0d,
+
+ /**
+ * ObjectAttachment is a reference to a compact binary attachment stored externally.
+ *
+ * Payload is a 160-bit hash digest of the referenced compact binary data.
+ */
+ ObjectAttachment = 0x0e,
+ /**
+ * BinaryAttachment is a reference to a binary attachment stored externally.
+ *
+ * Payload is a 160-bit hash digest of the referenced binary data.
+ */
+ BinaryAttachment = 0x0f,
+
+ /** Hash. Payload is a 160-bit hash digest. */
+ Hash = 0x10,
+ /** UUID/GUID. Payload is a 128-bit UUID as defined by RFC 4122. */
+ Uuid = 0x11,
+
+ /**
+ * Date and time between 0001-01-01 00:00:00.0000000 and 9999-12-31 23:59:59.9999999.
+ *
+ * Payload is a big endian int64 count of 100ns ticks since 0001-01-01 00:00:00.0000000.
+ */
+ DateTime = 0x12,
+ /**
+ * Difference between two date/time values.
+ *
+ * Payload is a big endian int64 count of 100ns ticks in the span, and may be negative.
+ */
+ TimeSpan = 0x13,
+
+ /**
+ * Object ID
+ *
+ * Payload is a 12-byte opaque identifier
+ */
+ ObjectId = 0x14,
+
+ /**
+ * CustomById identifies the sub-type of its payload by an integer identifier.
+ *
+ * Payload is a VarUInt byte count of the sub-type identifier and the sub-type payload, followed
+ * by a VarUInt of the sub-type identifier then the payload of the sub-type.
+ */
+ CustomById = 0x1e,
+ /**
+ * CustomByType identifies the sub-type of its payload by a string identifier.
+ *
+ * Payload is a VarUInt byte count of the sub-type identifier and the sub-type payload, followed
+ * by a VarUInt byte count of the unterminated sub-type identifier, then the sub-type identifier
+ * without termination, then the payload of the sub-type.
+ */
+ CustomByName = 0x1f,
+
+ /** Reserved for future use as a flag. Do not add types in this range. */
+ Reserved = 0x20,
+
+ /**
+ * A transient flag which indicates that the object or array containing this field has stored
+ * the field type before the payload and name. Non-uniform objects and fields will set this.
+ *
+ * Note: Since the flag must never be serialized, this bit may be repurposed in the future.
+ */
+ HasFieldType = 0x40,
+
+ /** A persisted flag which indicates that the field has a name stored before the payload. */
+ HasFieldName = 0x80,
+};
+
+ENUM_CLASS_FLAGS(CbFieldType);
+
+/** Functions that operate on CbFieldType. */
+class CbFieldTypeOps
+{
+ static constexpr CbFieldType SerializedTypeMask = CbFieldType(0b1011'1111);
+ static constexpr CbFieldType TypeMask = CbFieldType(0b0011'1111);
+ static constexpr CbFieldType ObjectMask = CbFieldType(0b0011'1110);
+ static constexpr CbFieldType ObjectBase = CbFieldType(0b0000'0010);
+ static constexpr CbFieldType ArrayMask = CbFieldType(0b0011'1110);
+ static constexpr CbFieldType ArrayBase = CbFieldType(0b0000'0100);
+ static constexpr CbFieldType IntegerMask = CbFieldType(0b0011'1110);
+ static constexpr CbFieldType IntegerBase = CbFieldType(0b0000'1000);
+ static constexpr CbFieldType FloatMask = CbFieldType(0b0011'1100);
+ static constexpr CbFieldType FloatBase = CbFieldType(0b0000'1000);
+ static constexpr CbFieldType BoolMask = CbFieldType(0b0011'1110);
+ static constexpr CbFieldType BoolBase = CbFieldType(0b0000'1100);
+ static constexpr CbFieldType AttachmentMask = CbFieldType(0b0011'1110);
+ static constexpr CbFieldType AttachmentBase = CbFieldType(0b0000'1110);
+
+ static void StaticAssertTypeConstants();
+
+public:
+ /** The type with flags removed. */
+ static constexpr inline CbFieldType GetType(CbFieldType Type) { return Type & TypeMask; }
+ /** The type with transient flags removed. */
+ static constexpr inline CbFieldType GetSerializedType(CbFieldType Type) { return Type & SerializedTypeMask; }
+
+ static constexpr inline bool HasFieldType(CbFieldType Type) { return EnumHasAnyFlags(Type, CbFieldType::HasFieldType); }
+ static constexpr inline bool HasFieldName(CbFieldType Type) { return EnumHasAnyFlags(Type, CbFieldType::HasFieldName); }
+
+ static constexpr inline bool IsNone(CbFieldType Type) { return GetType(Type) == CbFieldType::None; }
+ static constexpr inline bool IsNull(CbFieldType Type) { return GetType(Type) == CbFieldType::Null; }
+
+ static constexpr inline bool IsObject(CbFieldType Type) { return (Type & ObjectMask) == ObjectBase; }
+ static constexpr inline bool IsArray(CbFieldType Type) { return (Type & ArrayMask) == ArrayBase; }
+
+ static constexpr inline bool IsBinary(CbFieldType Type) { return GetType(Type) == CbFieldType::Binary; }
+ static constexpr inline bool IsString(CbFieldType Type) { return GetType(Type) == CbFieldType::String; }
+
+ static constexpr inline bool IsInteger(CbFieldType Type) { return (Type & IntegerMask) == IntegerBase; }
+ /** Whether the field is a float, or integer due to implicit conversion. */
+ static constexpr inline bool IsFloat(CbFieldType Type) { return (Type & FloatMask) == FloatBase; }
+ static constexpr inline bool IsBool(CbFieldType Type) { return (Type & BoolMask) == BoolBase; }
+
+ static constexpr inline bool IsObjectAttachment(CbFieldType Type) { return GetType(Type) == CbFieldType::ObjectAttachment; }
+ static constexpr inline bool IsBinaryAttachment(CbFieldType Type) { return GetType(Type) == CbFieldType::BinaryAttachment; }
+ static constexpr inline bool IsAttachment(CbFieldType Type) { return (Type & AttachmentMask) == AttachmentBase; }
+
+ static constexpr inline bool IsHash(CbFieldType Type)
+ {
+ switch (GetType(Type))
+ {
+ case CbFieldType::Hash:
+ case CbFieldType::BinaryAttachment:
+ case CbFieldType::ObjectAttachment:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ static constexpr inline bool IsUuid(CbFieldType Type) { return GetType(Type) == CbFieldType::Uuid; }
+ static constexpr inline bool IsObjectId(CbFieldType Type) { return GetType(Type) == CbFieldType::ObjectId; }
+
+ static constexpr inline bool IsCustomById(CbFieldType Type) { return GetType(Type) == CbFieldType::CustomById; }
+ static constexpr inline bool IsCustomByName(CbFieldType Type) { return GetType(Type) == CbFieldType::CustomByName; }
+
+ static constexpr inline bool IsDateTime(CbFieldType Type) { return GetType(Type) == CbFieldType::DateTime; }
+ static constexpr inline bool IsTimeSpan(CbFieldType Type) { return GetType(Type) == CbFieldType::TimeSpan; }
+
+ /** Whether the type is or may contain fields of any attachment type. */
+ static constexpr inline bool MayContainAttachments(CbFieldType Type)
+ {
+ return int(IsObject(Type) == true) | int(IsArray(Type) == true) | int(IsAttachment(Type) == true);
+ }
+};
+
+/** Errors that can occur when accessing a field. */
+enum class CbFieldError : uint8_t
+{
+ /** The field is not in an error state. */
+ None,
+ /** The value type does not match the requested type. */
+ TypeError,
+ /** The value is out of range for the requested type. */
+ RangeError,
+};
+
+class ICbVisitor
+{
+public:
+ virtual void SetName(std::string_view Name) = 0;
+ virtual void BeginObject() = 0;
+ virtual void EndObject() = 0;
+ virtual void BeginArray() = 0;
+ virtual void EndArray() = 0;
+ virtual void VisitNull() = 0;
+ virtual void VisitBinary(SharedBuffer Value) = 0;
+ virtual void VisitString(std::string_view Value) = 0;
+ virtual void VisitInteger(int64_t Value) = 0;
+ virtual void VisitInteger(uint64_t Value) = 0;
+ virtual void VisitFloat(float Value) = 0;
+ virtual void VisitDouble(double Value) = 0;
+ virtual void VisitBool(bool value) = 0;
+ virtual void VisitCbAttachment(const IoHash& Value) = 0;
+ virtual void VisitBinaryAttachment(const IoHash& Value) = 0;
+ virtual void VisitHash(const IoHash& Value) = 0;
+ virtual void VisitUuid(const Guid& Value) = 0;
+ virtual void VisitObjectId(const Oid& Value) = 0;
+ virtual void VisitDateTime(DateTime Value) = 0;
+ virtual void VisitTimeSpan(TimeSpan Value) = 0;
+};
+
+/** A custom compact binary field type with an integer identifier. */
+struct CbCustomById
+{
+ /** An identifier for the sub-type of the field. */
+ uint64_t Id = 0;
+ /** A view of the value. Lifetime is tied to the field that the value is associated with. */
+ MemoryView Data;
+};
+
+/** A custom compact binary field type with a string identifier. */
+struct CbCustomByName
+{
+ /** An identifier for the sub-type of the field. Lifetime is tied to the field that the name is associated with. */
+ std::u8string_view Name;
+ /** A view of the value. Lifetime is tied to the field that the value is associated with. */
+ MemoryView Data;
+};
+
+namespace CompactBinaryPrivate {
+ /** Parameters for converting to an integer. */
+ struct IntegerParams
+ {
+ /** Whether the output type has a sign bit. */
+ uint32_t IsSigned : 1;
+ /** Bits of magnitude. (7 for int8) */
+ uint32_t MagnitudeBits : 31;
+ };
+
+ /** Make integer params for the given integer type. */
+ template<typename IntType>
+ static constexpr inline IntegerParams MakeIntegerParams()
+ {
+ IntegerParams Params;
+ Params.IsSigned = IntType(-1) < IntType(0);
+ Params.MagnitudeBits = 8 * sizeof(IntType) - Params.IsSigned;
+ return Params;
+ }
+
+} // namespace CompactBinaryPrivate
+
+/**
+ * An atom of data in the compact binary format.
+ *
+ * Accessing the value of a field is always a safe operation, even if accessed as the wrong type.
+ * An invalid access will return a default value for the requested type, and set an error code on
+ * the field that can be checked with GetLastError and HasLastError. A valid access will clear an
+ * error from a previous invalid access.
+ *
+ * A field is encoded in one or more bytes, depending on its type and the type of object or array
+ * that contains it. A field of an object or array which is non-uniform encodes its field type in
+ * the first byte, and includes the HasFieldName flag for a field in an object. The field name is
+ * encoded in a variable-length unsigned integer of its size in bytes, for named fields, followed
+ * by that many bytes of the UTF-8 encoding of the name with no null terminator. The remainder of
+ * the field is the payload and is described in the field type enum. Every field must be uniquely
+ * addressable when encoded, which means a zero-byte field is not permitted, and only arises in a
+ * uniform array of fields with no payload, where the answer is to encode as a non-uniform array.
+ *
+ * This type only provides a view into memory and does not perform any memory management itself.
+ * Use CbFieldRef to hold a reference to the underlying memory when necessary.
+ */
+
+class CbFieldView
+{
+public:
+ CbFieldView() = default;
+
+ ZENCORE_API CbFieldView(const void* DataPointer, CbFieldType FieldType = CbFieldType::HasFieldType);
+
+ /** Construct a field from a value, without access to the name. */
+ inline explicit CbFieldView(const CbValue& Value);
+
+ /** Returns the name of the field if it has a name, otherwise an empty view. */
+ constexpr inline std::string_view GetName() const { return std::string_view(static_cast<const char*>(Payload) - NameLen, NameLen); }
+ /** Returns the name of the field if it has a name, otherwise an empty view. */
+ constexpr inline std::u8string_view GetU8Name() const
+ {
+ return std::u8string_view(static_cast<const char8_t*>(Payload) - NameLen, NameLen);
+ }
+
+ /** Returns the value for unchecked access. Prefer the typed accessors below. */
+ inline CbValue GetValue() const;
+
+ ZENCORE_API MemoryView AsBinaryView(MemoryView Default = MemoryView());
+ ZENCORE_API CbObjectView AsObjectView();
+ ZENCORE_API CbArrayView AsArrayView();
+ ZENCORE_API std::string_view AsString(std::string_view Default = std::string_view());
+ ZENCORE_API std::u8string_view AsU8String(std::u8string_view Default = std::u8string_view());
+
+ ZENCORE_API void IterateAttachments(std::function<void(CbFieldView)> Visitor) const;
+
+ /** Access the field as an int8. Returns the provided default on error. */
+ inline int8_t AsInt8(int8_t Default = 0) { return AsInteger<int8_t>(Default); }
+ /** Access the field as an int16. Returns the provided default on error. */
+ inline int16_t AsInt16(int16_t Default = 0) { return AsInteger<int16_t>(Default); }
+ /** Access the field as an int32. Returns the provided default on error. */
+ inline int32_t AsInt32(int32_t Default = 0) { return AsInteger<int32_t>(Default); }
+ /** Access the field as an int64. Returns the provided default on error. */
+ inline int64_t AsInt64(int64_t Default = 0) { return AsInteger<int64_t>(Default); }
+ /** Access the field as a uint8. Returns the provided default on error. */
+ inline uint8_t AsUInt8(uint8_t Default = 0) { return AsInteger<uint8_t>(Default); }
+ /** Access the field as a uint16. Returns the provided default on error. */
+ inline uint16_t AsUInt16(uint16_t Default = 0) { return AsInteger<uint16_t>(Default); }
+ /** Access the field as a uint32. Returns the provided default on error. */
+ inline uint32_t AsUInt32(uint32_t Default = 0) { return AsInteger<uint32_t>(Default); }
+ /** Access the field as a uint64. Returns the provided default on error. */
+ inline uint64_t AsUInt64(uint64_t Default = 0) { return AsInteger<uint64_t>(Default); }
+
+ /** Access the field as a float. Returns the provided default on error. */
+ ZENCORE_API float AsFloat(float Default = 0.0f);
+ /** Access the field as a double. Returns the provided default on error. */
+ ZENCORE_API double AsDouble(double Default = 0.0);
+
+ /** Access the field as a bool. Returns the provided default on error. */
+ ZENCORE_API bool AsBool(bool bDefault = false);
+
+ /** Access the field as a hash referencing a compact binary attachment. Returns the provided default on error. */
+ ZENCORE_API IoHash AsObjectAttachment(const IoHash& Default = IoHash());
+ /** Access the field as a hash referencing a binary attachment. Returns the provided default on error. */
+ ZENCORE_API IoHash AsBinaryAttachment(const IoHash& Default = IoHash());
+ /** Access the field as a hash referencing an attachment. Returns the provided default on error. */
+ ZENCORE_API IoHash AsAttachment(const IoHash& Default = IoHash());
+
+ /** Access the field as a hash. Returns the provided default on error. */
+ ZENCORE_API IoHash AsHash(const IoHash& Default = IoHash());
+
+ /** Access the field as a UUID. Returns a nil UUID on error. */
+ ZENCORE_API Guid AsUuid();
+ /** Access the field as a UUID. Returns the provided default on error. */
+ ZENCORE_API Guid AsUuid(const Guid& Default);
+
+ /** Access the field as an OID. Returns a nil OID on error. */
+ ZENCORE_API Oid AsObjectId();
+ /** Access the field as a OID. Returns the provided default on error. */
+ ZENCORE_API Oid AsObjectId(const Oid& Default);
+
+ /** Access the field as a custom sub-type with an integer identifier. Returns the provided default on error. */
+ ZENCORE_API CbCustomById AsCustomById(CbCustomById Default);
+ /** Access the field as a custom sub-type with a string identifier. Returns the provided default on error. */
+ ZENCORE_API CbCustomByName AsCustomByName(CbCustomByName Default);
+
+ /** Access the field as a date/time tick count. Returns the provided default on error. */
+ ZENCORE_API int64_t AsDateTimeTicks(int64_t Default = 0);
+
+ /** Access the field as a date/time. Returns a date/time at the epoch on error. */
+ ZENCORE_API DateTime AsDateTime();
+ /** Access the field as a date/time. Returns the provided default on error. */
+ ZENCORE_API DateTime AsDateTime(DateTime Default);
+
+ /** Access the field as a timespan tick count. Returns the provided default on error. */
+ ZENCORE_API int64_t AsTimeSpanTicks(int64_t Default = 0);
+
+ /** Access the field as a timespan. Returns an empty timespan on error. */
+ ZENCORE_API TimeSpan AsTimeSpan();
+ /** Access the field as a timespan. Returns the provided default on error. */
+ ZENCORE_API TimeSpan AsTimeSpan(TimeSpan Default);
+
+ /** True if the field has a name. */
+ constexpr inline bool HasName() const { return CbFieldTypeOps::HasFieldName(Type); }
+
+ constexpr inline bool IsNull() const { return CbFieldTypeOps::IsNull(Type); }
+
+ constexpr inline bool IsObject() const { return CbFieldTypeOps::IsObject(Type); }
+ constexpr inline bool IsArray() const { return CbFieldTypeOps::IsArray(Type); }
+
+ constexpr inline bool IsBinary() const { return CbFieldTypeOps::IsBinary(Type); }
+ constexpr inline bool IsString() const { return CbFieldTypeOps::IsString(Type); }
+
+ /** Whether the field is an integer of unspecified range and sign. */
+ constexpr inline bool IsInteger() const { return CbFieldTypeOps::IsInteger(Type); }
+ /** Whether the field is a float, or integer that supports implicit conversion. */
+ constexpr inline bool IsFloat() const { return CbFieldTypeOps::IsFloat(Type); }
+ constexpr inline bool IsBool() const { return CbFieldTypeOps::IsBool(Type); }
+
+ constexpr inline bool IsObjectAttachment() const { return CbFieldTypeOps::IsObjectAttachment(Type); }
+ constexpr inline bool IsBinaryAttachment() const { return CbFieldTypeOps::IsBinaryAttachment(Type); }
+ constexpr inline bool IsAttachment() const { return CbFieldTypeOps::IsAttachment(Type); }
+
+ constexpr inline bool IsHash() const { return CbFieldTypeOps::IsHash(Type); }
+ constexpr inline bool IsUuid() const { return CbFieldTypeOps::IsUuid(Type); }
+ constexpr inline bool IsObjectId() const { return CbFieldTypeOps::IsObjectId(Type); }
+
+ constexpr inline bool IsDateTime() const { return CbFieldTypeOps::IsDateTime(Type); }
+ constexpr inline bool IsTimeSpan() const { return CbFieldTypeOps::IsTimeSpan(Type); }
+
+ /** Whether the field has a value. */
+ constexpr inline explicit operator bool() const { return HasValue(); }
+
+ /**
+ * Whether the field has a value.
+ *
+ * All fields in a valid object or array have a value. A field with no value is returned when
+ * finding a field by name fails or when accessing an iterator past the end.
+ */
+ constexpr inline bool HasValue() const { return !CbFieldTypeOps::IsNone(Type); };
+
+ /** Whether the last field access encountered an error. */
+ constexpr inline bool HasError() const { return Error != CbFieldError::None; }
+
+ /** The type of error that occurred on the last field access, or None. */
+ constexpr inline CbFieldError GetError() const { return Error; }
+
+ /** Returns the size of the field in bytes, including the type and name. */
+ ZENCORE_API uint64_t GetSize() const;
+
+ /** Calculate the hash of the field, including the type and name. */
+ ZENCORE_API IoHash GetHash() const;
+
+ ZENCORE_API void GetHash(IoHashStream& HashStream) const;
+
+ /** Feed the field (including type and name) to the stream function */
+ inline void WriteToStream(auto Hash) const
+ {
+ const CbFieldType SerializedType = CbFieldTypeOps::GetSerializedType(Type);
+ Hash(&SerializedType, sizeof(SerializedType));
+ auto View = GetViewNoType();
+ Hash(View.GetData(), View.GetSize());
+ }
+
+ /** Copy the field into a buffer of exactly GetSize() bytes, including the type and name. */
+ ZENCORE_API void CopyTo(MutableMemoryView Buffer) const;
+
+ /** Copy the field into an archive, including its type and name. */
+ ZENCORE_API void CopyTo(BinaryWriter& Ar) const;
+
+ /**
+ * Whether this field is identical to the other field.
+ *
+ * Performs a deep comparison of any contained arrays or objects and their fields. Comparison
+ * assumes that both fields are valid and are written in the canonical format. Fields must be
+ * written in the same order in arrays and objects, and name comparison is case sensitive. If
+ * these assumptions do not hold, this may return false for equivalent inputs. Validation can
+ * be performed with ValidateCompactBinary, except for field order and field name case.
+ */
+ ZENCORE_API bool Equals(const CbFieldView& Other) const;
+
+ /** Returns a view of the field, including the type and name when present. */
+ ZENCORE_API MemoryView GetView() const;
+
+ /**
+ * Try to get a view of the field as it would be serialized, such as by CopyTo.
+ *
+ * A serialized view is not available if the field has an externally-provided type.
+ * Access the serialized form of such fields using CopyTo or FCbFieldRef::Clone.
+ */
+ inline bool TryGetSerializedView(MemoryView& OutView) const
+ {
+ if (CbFieldTypeOps::HasFieldType(Type))
+ {
+ OutView = GetView();
+ return true;
+ }
+ return false;
+ }
+
+protected:
+ /** Returns a view of the name and value payload, which excludes the type. */
+ ZENCORE_API MemoryView GetViewNoType() const;
+
+ /** Returns a view of the value payload, which excludes the type and name. */
+ inline MemoryView GetPayloadView() const { return MemoryView(Payload, GetPayloadSize()); }
+
+ /** Returns the type of the field including flags. */
+ constexpr inline CbFieldType GetType() const { return Type; }
+
+ /** Returns the start of the value payload. */
+ constexpr inline const void* GetPayload() const { return Payload; }
+
+ /** Returns the end of the value payload. */
+ inline const void* GetPayloadEnd() const { return static_cast<const uint8_t*>(Payload) + GetPayloadSize(); }
+
+ /** Returns the size of the value payload in bytes, which is the field excluding the type and name. */
+ ZENCORE_API uint64_t GetPayloadSize() const;
+
+ /** Assign a field from a pointer to its data and an optional externally-provided type. */
+ inline void Assign(const void* InData, const CbFieldType InType)
+ {
+ static_assert(std::is_trivially_destructible<CbFieldView>::value,
+ "This optimization requires CbField to be trivially destructible!");
+ new (this) CbFieldView(InData, InType);
+ }
+
+private:
+ /**
+ * Access the field as the given integer type.
+ *
+ * Returns the provided default if the value cannot be represented in the output type.
+ */
+ template<typename IntType>
+ inline IntType AsInteger(IntType Default)
+ {
+ return IntType(AsInteger(uint64_t(Default), CompactBinaryPrivate::MakeIntegerParams<IntType>()));
+ }
+
+ ZENCORE_API uint64_t AsInteger(uint64_t Default, CompactBinaryPrivate::IntegerParams Params);
+
+private:
+ /** The field type, with the transient HasFieldType flag if the field contains its type. */
+ CbFieldType Type = CbFieldType::None;
+ /** The error (if any) that occurred on the last field access. */
+ CbFieldError Error = CbFieldError::None;
+ /** The number of bytes for the name stored before the payload. */
+ uint32_t NameLen = 0;
+ /** The value payload, which also points to the end of the name. */
+ const void* Payload = nullptr;
+};
+
+template<typename FieldType>
+class TCbFieldIterator : public FieldType
+{
+public:
+ /** Construct an empty field range. */
+ constexpr TCbFieldIterator() = default;
+
+ inline TCbFieldIterator& operator++()
+ {
+ const void* const PayloadEnd = FieldType::GetPayloadEnd();
+ const int64_t AtEndMask = int64_t(PayloadEnd == FieldsEnd) - 1;
+ const CbFieldType NextType = CbFieldType(int64_t(FieldType::GetType()) & AtEndMask);
+ const void* const NextField = reinterpret_cast<const void*>(int64_t(PayloadEnd) & AtEndMask);
+ const void* const NextFieldsEnd = reinterpret_cast<const void*>(int64_t(FieldsEnd) & AtEndMask);
+
+ FieldType::Assign(NextField, NextType);
+ FieldsEnd = NextFieldsEnd;
+ return *this;
+ }
+
+ inline TCbFieldIterator operator++(int)
+ {
+ TCbFieldIterator It(*this);
+ ++*this;
+ return It;
+ }
+
+ constexpr inline FieldType& operator*() { return *this; }
+ constexpr inline FieldType* operator->() { return this; }
+
+ /** Reset this to an empty field range. */
+ inline void Reset() { *this = TCbFieldIterator(); }
+
+ /** Returns the size of the fields in the range in bytes. */
+ ZENCORE_API uint64_t GetRangeSize() const;
+
+ /** Calculate the hash of every field in the range. */
+ ZENCORE_API IoHash GetRangeHash() const;
+ ZENCORE_API void GetRangeHash(IoHashStream& Hash) const;
+
+ using FieldType::Equals;
+
+ template<typename OtherFieldType>
+ constexpr inline bool Equals(const TCbFieldIterator<OtherFieldType>& Other) const
+ {
+ return FieldType::GetPayload() == Other.OtherFieldType::GetPayload() && FieldsEnd == Other.FieldsEnd;
+ }
+
+ template<typename OtherFieldType>
+ constexpr inline bool operator==(const TCbFieldIterator<OtherFieldType>& Other) const
+ {
+ return Equals(Other);
+ }
+
+ template<typename OtherFieldType>
+ constexpr inline bool operator!=(const TCbFieldIterator<OtherFieldType>& Other) const
+ {
+ return !Equals(Other);
+ }
+
+ /** Copy the field range into a buffer of exactly GetRangeSize() bytes. */
+ ZENCORE_API void CopyRangeTo(MutableMemoryView Buffer) const;
+
+ /** Invoke the visitor for every attachment in the field range. */
+ ZENCORE_API void IterateRangeAttachments(std::function<void(CbFieldView)> Visitor) const;
+
+ /** Create a view of every field in the range. */
+ inline MemoryView GetRangeView() const { return MemoryView(FieldType::GetView().GetData(), FieldsEnd); }
+
+ /**
+ * Try to get a view of every field in the range as they would be serialized.
+ *
+ * A serialized view is not available if the underlying fields have an externally-provided type.
+ * Access the serialized form of such ranges using CbFieldRefIterator::CloneRange.
+ */
+ inline bool TryGetSerializedRangeView(MemoryView& OutView) const
+ {
+ if (CbFieldTypeOps::HasFieldType(FieldType::GetType()))
+ {
+ OutView = GetRangeView();
+ return true;
+ }
+ return false;
+ }
+
+protected:
+ /** Construct a field range that contains exactly one field. */
+ constexpr inline explicit TCbFieldIterator(FieldType InField) : FieldType(std::move(InField)), FieldsEnd(FieldType::GetPayloadEnd()) {}
+
+ /**
+ * Construct a field range from the first field and a pointer to the end of the last field.
+ *
+ * @param InField The first field, or the default field if there are no fields.
+ * @param InFieldsEnd A pointer to the end of the payload of the last field, or null.
+ */
+ constexpr inline TCbFieldIterator(FieldType&& InField, const void* InFieldsEnd) : FieldType(std::move(InField)), FieldsEnd(InFieldsEnd)
+ {
+ }
+
+ /** Returns the end of the last field, or null for an iterator at the end. */
+ template<typename OtherFieldType>
+ static inline const void* GetFieldsEnd(const TCbFieldIterator<OtherFieldType>& It)
+ {
+ return It.FieldsEnd;
+ }
+
+private:
+ friend inline TCbFieldIterator begin(const TCbFieldIterator& Iterator) { return Iterator; }
+ friend inline TCbFieldIterator end(const TCbFieldIterator&) { return TCbFieldIterator(); }
+
+private:
+ template<typename OtherType>
+ friend class TCbFieldIterator;
+
+ friend class CbFieldViewIterator;
+
+ friend class CbFieldIterator;
+
+ /** Pointer to the first byte past the end of the last field. Set to null at the end. */
+ const void* FieldsEnd = nullptr;
+};
+
+/**
+ * Iterator for CbField.
+ *
+ * @see CbFieldIterator
+ */
+class CbFieldViewIterator : public TCbFieldIterator<CbFieldView>
+{
+public:
+ constexpr CbFieldViewIterator() = default;
+
+ /** Construct a field range that contains exactly one field. */
+ static inline CbFieldViewIterator MakeSingle(const CbFieldView& Field) { return CbFieldViewIterator(Field); }
+
+ /**
+ * Construct a field range from a buffer containing zero or more valid fields.
+ *
+ * @param View A buffer containing zero or more valid fields.
+ * @param Type HasFieldType means that View contains the type. Otherwise, use the given type.
+ */
+ static inline CbFieldViewIterator MakeRange(MemoryView View, CbFieldType Type = CbFieldType::HasFieldType)
+ {
+ return !View.IsEmpty() ? TCbFieldIterator(CbFieldView(View.GetData(), Type), View.GetDataEnd()) : CbFieldViewIterator();
+ }
+
+ /** Construct an iterator from another iterator. */
+ template<typename OtherFieldType>
+ inline CbFieldViewIterator(const TCbFieldIterator<OtherFieldType>& It)
+ : TCbFieldIterator(ImplicitConv<CbFieldView>(It), GetFieldsEnd(It))
+ {
+ }
+
+private:
+ using TCbFieldIterator::TCbFieldIterator;
+};
+
+/**
+ * Serialize a compact binary array to JSON.
+ */
+ZENCORE_API void CompactBinaryToJson(const CbArrayView& Object, StringBuilderBase& Builder);
+
+/**
+ * Array of CbField that have no names.
+ *
+ * Accessing a field of the array requires iteration. Access by index is not provided because the
+ * cost of accessing an item by index scales linearly with the index.
+ *
+ * This type only provides a view into memory and does not perform any memory management itself.
+ * Use CbArrayRef to hold a reference to the underlying memory when necessary.
+ */
+class CbArrayView : protected CbFieldView
+{
+ friend class CbFieldView;
+
+public:
+ /** @see CbField::CbField */
+ using CbFieldView::CbFieldView;
+
+ /** Construct an array with no fields. */
+ ZENCORE_API CbArrayView();
+
+ /** Returns the number of items in the array. */
+ ZENCORE_API uint64_t Num() const;
+
+ /** Create an iterator for the fields of this array. */
+ ZENCORE_API CbFieldViewIterator CreateViewIterator() const;
+
+ /** Visit the fields of this array. */
+ ZENCORE_API void VisitFields(ICbVisitor& Visitor);
+
+ /** Access the array as an array field. */
+ inline CbFieldView AsFieldView() const { return static_cast<const CbFieldView&>(*this); }
+
+ /** Construct an array from an array field. No type check is performed! */
+ static inline CbArrayView FromFieldView(const CbFieldView& Field) { return CbArrayView(Field); }
+
+ /** Whether the array has any fields. */
+ inline explicit operator bool() const { return Num() > 0; }
+
+ /** Returns the size of the array in bytes if serialized by itself with no name. */
+ ZENCORE_API uint64_t GetSize() const;
+
+ /** Calculate the hash of the array if serialized by itself with no name. */
+ ZENCORE_API IoHash GetHash() const;
+
+ ZENCORE_API void GetHash(IoHashStream& Stream) const;
+
+ /**
+ * Whether this array is identical to the other array.
+ *
+ * Performs a deep comparison of any contained arrays or objects and their fields. Comparison
+ * assumes that both fields are valid and are written in the canonical format. Fields must be
+ * written in the same order in arrays and objects, and name comparison is case sensitive. If
+ * these assumptions do not hold, this may return false for equivalent inputs. Validation can
+ * be done with the All mode to check these assumptions about the format of the inputs.
+ */
+ ZENCORE_API bool Equals(const CbArrayView& Other) const;
+
+ /** Copy the array into a buffer of exactly GetSize() bytes, with no name. */
+ ZENCORE_API void CopyTo(MutableMemoryView Buffer) const;
+
+ /** Copy the array into an archive, including its type and name. */
+ ZENCORE_API void CopyTo(BinaryWriter& Ar) const;
+
+ ///** Invoke the visitor for every attachment in the array. */
+ inline void IterateAttachments(std::function<void(CbFieldView)> Visitor) const
+ {
+ CreateViewIterator().IterateRangeAttachments(Visitor);
+ }
+
+ /** Returns a view of the array, including the type and name when present. */
+ using CbFieldView::GetView;
+
+ StringBuilderBase& ToJson(StringBuilderBase& Builder) const
+ {
+ CompactBinaryToJson(*this, Builder);
+ return Builder;
+ }
+
+private:
+ friend inline CbFieldViewIterator begin(const CbArrayView& Array) { return Array.CreateViewIterator(); }
+ friend inline CbFieldViewIterator end(const CbArrayView&) { return CbFieldViewIterator(); }
+
+ /** Construct an array from an array field. No type check is performed! Use via FromField. */
+ inline explicit CbArrayView(const CbFieldView& Field) : CbFieldView(Field) {}
+};
+
+/**
+ * Serialize a compact binary object to JSON.
+ */
+ZENCORE_API void CompactBinaryToJson(const CbObjectView& Object, StringBuilderBase& Builder);
+
+class CbObjectView : protected CbFieldView
+{
+ friend class CbFieldView;
+
+public:
+ /** @see CbField::CbField */
+ using CbFieldView::CbFieldView;
+
+ using CbFieldView::TryGetSerializedView;
+
+ /** Construct an object with no fields. */
+ ZENCORE_API CbObjectView();
+
+ /** Create an iterator for the fields of this object. */
+ ZENCORE_API CbFieldViewIterator CreateViewIterator() const;
+
+ /** Visit the fields of this object. */
+ ZENCORE_API void VisitFields(ICbVisitor& Visitor);
+
+ /**
+ * Find a field by case-sensitive name comparison.
+ *
+ * The cost of this operation scales linearly with the number of fields in the object. Prefer
+ * to iterate over the fields only once when consuming an object.
+ *
+ * @param Name The name of the field.
+ * @return The matching field if found, otherwise a field with no value.
+ */
+ ZENCORE_API CbFieldView FindView(std::string_view Name) const;
+
+ /** Find a field by case-insensitive name comparison. */
+ ZENCORE_API CbFieldView FindViewIgnoreCase(std::string_view Name) const;
+
+ /** Find a field by case-sensitive name comparison. */
+ inline CbFieldView operator[](std::string_view Name) const { return FindView(Name); }
+
+ /** Access the object as an object field. */
+ inline CbFieldView AsFieldView() const { return static_cast<const CbFieldView&>(*this); }
+
+ /** Construct an object from an object field. No type check is performed! */
+ static inline CbObjectView FromFieldView(const CbFieldView& Field) { return CbObjectView(Field); }
+
+ /** Whether the object has any fields. */
+ ZENCORE_API explicit operator bool() const;
+
+ /** Returns the size of the object in bytes if serialized by itself with no name. */
+ ZENCORE_API uint64_t GetSize() const;
+
+ /** Calculate the hash of the object if serialized by itself with no name. */
+ ZENCORE_API IoHash GetHash() const;
+
+ ZENCORE_API void GetHash(IoHashStream& HashStream) const;
+
+ /**
+ * Whether this object is identical to the other object.
+ *
+ * Performs a deep comparison of any contained arrays or objects and their fields. Comparison
+ * assumes that both fields are valid and are written in the canonical format. Fields must be
+ * written in the same order in arrays and objects, and name comparison is case sensitive. If
+ * these assumptions do not hold, this may return false for equivalent inputs. Validation can
+ * be done with the All mode to check these assumptions about the format of the inputs.
+ */
+ ZENCORE_API bool Equals(const CbObjectView& Other) const;
+
+ /** Copy the object into a buffer of exactly GetSize() bytes, with no name. */
+ ZENCORE_API void CopyTo(MutableMemoryView Buffer) const;
+
+ /** Copy the field into an archive, including its type and name. */
+ ZENCORE_API void CopyTo(BinaryWriter& Ar) const;
+
+ ///** Invoke the visitor for every attachment in the object. */
+ inline void IterateAttachments(std::function<void(CbFieldView)> Visitor) const
+ {
+ CreateViewIterator().IterateRangeAttachments(Visitor);
+ }
+
+ /** Returns a view of the object, including the type and name when present. */
+ using CbFieldView::GetView;
+
+ /** Whether the field has a value. */
+ using CbFieldView::operator bool;
+
+ StringBuilderBase& ToJson(StringBuilderBase& Builder) const
+ {
+ CompactBinaryToJson(*this, Builder);
+ return Builder;
+ }
+
+private:
+ friend inline CbFieldViewIterator begin(const CbObjectView& Object) { return Object.CreateViewIterator(); }
+ friend inline CbFieldViewIterator end(const CbObjectView&) { return CbFieldViewIterator(); }
+
+ /** Construct an object from an object field. No type check is performed! Use via FromField. */
+ inline explicit CbObjectView(const CbFieldView& Field) : CbFieldView(Field) {}
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+/** A reference to a function that is used to allocate buffers for compact binary data. */
+using BufferAllocator = std::function<UniqueBuffer(uint64_t Size)>;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/** A wrapper that holds a reference to the buffer that contains its compact binary value. */
+template<typename BaseType>
+class CbBuffer : public BaseType
+{
+public:
+ /** Construct a default value. */
+ CbBuffer() = default;
+
+ /**
+ * Construct a value from a pointer to its data and an optional externally-provided type.
+ *
+ * @param ValueBuffer A buffer that exactly contains the value.
+ * @param Type HasFieldType means that ValueBuffer contains the type. Otherwise, use the given type.
+ */
+ inline explicit CbBuffer(SharedBuffer ValueBuffer, CbFieldType Type = CbFieldType::HasFieldType)
+ {
+ if (ValueBuffer)
+ {
+ BaseType::operator=(BaseType(ValueBuffer.GetData(), Type));
+ ZEN_ASSERT(ValueBuffer.GetView().Contains(BaseType::GetView()));
+ Buffer = std::move(ValueBuffer);
+ }
+ }
+
+ /** Construct a value that holds a reference to the buffer that contains it. */
+ inline CbBuffer(const BaseType& Value, SharedBuffer OuterBuffer) : BaseType(Value)
+ {
+ if (OuterBuffer)
+ {
+ ZEN_ASSERT(OuterBuffer.GetView().Contains(BaseType::GetView()));
+ Buffer = std::move(OuterBuffer);
+ }
+ }
+
+ /** Construct a value that holds a reference to the buffer of the outer that contains it. */
+ template<typename OtherBaseType>
+ inline CbBuffer(const BaseType& Value, CbBuffer<OtherBaseType> OuterRef) : CbBuffer(Value, std::move(OuterRef.Buffer))
+ {
+ }
+
+ /** Reset this to a default value and null buffer. */
+ inline void Reset() { *this = CbBuffer(); }
+
+ /** Whether this reference has ownership of the memory in its buffer. */
+ inline bool IsOwned() const { return Buffer && Buffer.IsOwned(); }
+
+ /** Clone the value, if necessary, to a buffer that this reference has ownership of. */
+ inline void MakeOwned()
+ {
+ if (!IsOwned())
+ {
+ UniqueBuffer MutableBuffer = UniqueBuffer::Alloc(BaseType::GetSize());
+ BaseType::CopyTo(MutableBuffer);
+ BaseType::operator=(BaseType(MutableBuffer.GetData()));
+ Buffer = std::move(MutableBuffer);
+ }
+ }
+
+ /** Returns a buffer that exactly contains this value. */
+ inline SharedBuffer GetBuffer() const
+ {
+ const MemoryView View = BaseType::GetView();
+ const SharedBuffer& OuterBuffer = GetOuterBuffer();
+ return View == OuterBuffer.GetView() ? OuterBuffer : SharedBuffer::MakeView(View, OuterBuffer);
+ }
+
+ /** Returns the outer buffer (if any) that contains this value. */
+ inline const SharedBuffer& GetOuterBuffer() const& { return Buffer; }
+ inline SharedBuffer GetOuterBuffer() && { return std::move(Buffer); }
+
+private:
+ template<typename OtherType>
+ friend class CbBuffer;
+
+ SharedBuffer Buffer;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * Factory functions for types derived from CbBuffer.
+ *
+ * This uses the curiously recurring template pattern to construct the correct type of reference.
+ * The derived type inherits from CbBufferRef and this type to expose the factory functions.
+ */
+template<typename RefType, typename BaseType>
+class CbBufferFactory
+{
+public:
+ /** Construct a value from an owned clone of its memory. */
+ static inline RefType Clone(const void* const Data) { return Clone(BaseType(Data)); }
+
+ /** Construct a value from an owned clone of its memory. */
+ static inline RefType Clone(const BaseType& Value)
+ {
+ RefType Ref = MakeView(Value);
+ Ref.MakeOwned();
+ return Ref;
+ }
+
+ /** Construct a value from a read-only view of its memory and its optional outer buffer. */
+ static inline RefType MakeView(const void* const Data, SharedBuffer OuterBuffer = SharedBuffer())
+ {
+ return MakeView(BaseType(Data), std::move(OuterBuffer));
+ }
+
+ /** Construct a value from a read-only view of its memory and its optional outer buffer. */
+ static inline RefType MakeView(const BaseType& Value, SharedBuffer OuterBuffer = SharedBuffer())
+ {
+ return RefType(Value, std::move(OuterBuffer));
+ }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+class CbArray;
+class CbObject;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * A field that can hold a reference to the memory that contains it.
+ *
+ * @see CbBufferRef
+ */
+class CbField : public CbBuffer<CbFieldView>, public CbBufferFactory<CbField, CbFieldView>
+{
+public:
+ using CbBuffer::CbBuffer;
+
+ /** Access the field as an object. Defaults to an empty object on error. */
+ inline CbObject AsObject() &;
+ inline CbObject AsObject() &&;
+
+ /** Access the field as an array. Defaults to an empty array on error. */
+ inline CbArray AsArray() &;
+ inline CbArray AsArray() &&;
+
+ /** Access the field as binary. Returns the provided default on error. */
+ inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &;
+ inline SharedBuffer AsBinary(const SharedBuffer& Default = SharedBuffer()) &&;
+};
+
+/**
+ * Iterator for CbFieldRef.
+ *
+ * @see CbFieldIterator
+ */
+class CbFieldIterator : public TCbFieldIterator<CbField>
+{
+public:
+ /** Construct a field range from an owned clone of a range. */
+ ZENCORE_API static CbFieldIterator CloneRange(const CbFieldViewIterator& It);
+
+ /** Construct a field range from an owned clone of a range. */
+ static inline CbFieldIterator CloneRange(const CbFieldIterator& It) { return CloneRange(CbFieldViewIterator(It)); }
+
+ /** Construct a field range that contains exactly one field. */
+ static inline CbFieldIterator MakeSingle(CbField Field) { return CbFieldIterator(std::move(Field)); }
+
+ /**
+ * Construct a field range from a buffer containing zero or more valid fields.
+ *
+ * @param Buffer A buffer containing zero or more valid fields.
+ * @param Type HasFieldType means that Buffer contains the type. Otherwise, use the given type.
+ */
+ static inline CbFieldIterator MakeRange(SharedBuffer Buffer, CbFieldType Type = CbFieldType::HasFieldType)
+ {
+ if (Buffer.GetSize())
+ {
+ const void* const DataEnd = Buffer.GetView().GetDataEnd();
+ return CbFieldIterator(CbField(std::move(Buffer), Type), DataEnd);
+ }
+ return CbFieldIterator();
+ }
+
+ /** Construct a field range from an iterator and its optional outer buffer. */
+ static inline CbFieldIterator MakeRangeView(const CbFieldViewIterator& It, SharedBuffer OuterBuffer = SharedBuffer())
+ {
+ return CbFieldIterator(CbField(It, std::move(OuterBuffer)), GetFieldsEnd(It));
+ }
+
+ /** Construct an empty field range. */
+ constexpr CbFieldIterator() = default;
+
+ /** Clone the range, if necessary, to a buffer that this reference has ownership of. */
+ inline void MakeRangeOwned()
+ {
+ if (!IsOwned())
+ {
+ *this = CloneRange(*this);
+ }
+ }
+
+ /** Returns a buffer that exactly contains the field range. */
+ ZENCORE_API SharedBuffer GetRangeBuffer() const;
+
+private:
+ using TCbFieldIterator::TCbFieldIterator;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * An array that can hold a reference to the memory that contains it.
+ *
+ * @see CbBuffer
+ */
+class CbArray : public CbBuffer<CbArrayView>, public CbBufferFactory<CbArray, CbArrayView>
+{
+public:
+ using CbBuffer::CbBuffer;
+
+ /** Create an iterator for the fields of this array. */
+ inline CbFieldIterator CreateIterator() const { return CbFieldIterator::MakeRangeView(CreateViewIterator(), GetOuterBuffer()); }
+
+ /** Access the array as an array field. */
+ inline CbField AsField() const& { return CbField(CbArrayView::AsFieldView(), *this); }
+
+ /** Access the array as an array field. */
+ inline CbField AsField() && { return CbField(CbArrayView::AsFieldView(), std::move(*this)); }
+
+private:
+ friend inline CbFieldIterator begin(const CbArray& Array) { return Array.CreateIterator(); }
+ friend inline CbFieldIterator end(const CbArray&) { return CbFieldIterator(); }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * An object that can hold a reference to the memory that contains it.
+ *
+ * @see CbBuffer
+ */
+class CbObject : public CbBuffer<CbObjectView>, public CbBufferFactory<CbObject, CbObjectView>
+{
+public:
+ using CbBuffer::CbBuffer;
+
+ /** Create an iterator for the fields of this object. */
+ inline CbFieldIterator CreateIterator() const { return CbFieldIterator::MakeRangeView(CreateViewIterator(), GetOuterBuffer()); }
+
+ /** Find a field by case-sensitive name comparison. */
+ inline CbField Find(std::string_view Name) const
+ {
+ if (CbFieldView Field = FindView(Name))
+ {
+ return CbField(Field, *this);
+ }
+ return CbField();
+ }
+
+ /** Find a field by case-insensitive name comparison. */
+ inline CbField FindIgnoreCase(std::string_view Name) const
+ {
+ if (CbFieldView Field = FindViewIgnoreCase(Name))
+ {
+ return CbField(Field, *this);
+ }
+ return CbField();
+ }
+
+ /** Find a field by case-sensitive name comparison. */
+ inline CbFieldView operator[](std::string_view Name) const { return Find(Name); }
+
+ /** Access the object as an object field. */
+ inline CbField AsField() const& { return CbField(CbObjectView::AsFieldView(), *this); }
+
+ /** Access the object as an object field. */
+ inline CbField AsField() && { return CbField(CbObjectView::AsFieldView(), std::move(*this)); }
+
+private:
+ friend inline CbFieldIterator begin(const CbObject& Object) { return Object.CreateIterator(); }
+ friend inline CbFieldIterator end(const CbObject&) { return CbFieldIterator(); }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline CbObject
+CbField::AsObject() &
+{
+ return IsObject() ? CbObject(AsObjectView(), *this) : CbObject();
+}
+
+inline CbObject
+CbField::AsObject() &&
+{
+ return IsObject() ? CbObject(AsObjectView(), std::move(*this)) : CbObject();
+}
+
+inline CbArray
+CbField::AsArray() &
+{
+ return IsArray() ? CbArray(AsArrayView(), *this) : CbArray();
+}
+
+inline CbArray
+CbField::AsArray() &&
+{
+ return IsArray() ? CbArray(AsArrayView(), std::move(*this)) : CbArray();
+}
+
+inline SharedBuffer
+CbField::AsBinary(const SharedBuffer& Default) &
+{
+ const MemoryView View = AsBinaryView();
+ return !HasError() ? SharedBuffer::MakeView(View, GetOuterBuffer()) : Default;
+}
+
+inline SharedBuffer
+CbField::AsBinary(const SharedBuffer& Default) &&
+{
+ const MemoryView View = AsBinaryView();
+ return !HasError() ? SharedBuffer::MakeView(View, std::move(*this).GetOuterBuffer()) : Default;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * Load a compact binary field from an archive.
+ *
+ * The field may be an array or an object, which the caller can convert to by using AsArray or
+ * AsObject as appropriate. The buffer allocator is called to provide the buffer for the field
+ * to load into once its size has been determined.
+ *
+ * @param Ar Archive to read the field from. An error state is set on failure.
+ * @param Allocator Allocator for the buffer that the field is loaded into.
+ * @return A field with a reference to the allocated buffer, or a default field on failure.
+ */
+ZENCORE_API CbField LoadCompactBinary(BinaryReader& Ar, BufferAllocator Allocator);
+
+ZENCORE_API CbObject LoadCompactBinaryObject(IoBuffer&& Payload);
+ZENCORE_API CbObject LoadCompactBinaryObject(const IoBuffer& Payload);
+ZENCORE_API CbObject LoadCompactBinaryObject(CompressedBuffer&& Payload);
+ZENCORE_API CbObject LoadCompactBinaryObject(const CompressedBuffer& Payload);
+
+/**
+ * Load a compact binary from JSON.
+ */
+ZENCORE_API CbFieldIterator LoadCompactBinaryFromJson(std::string_view Json, std::string& Error);
+ZENCORE_API CbFieldIterator LoadCompactBinaryFromJson(std::string_view Json);
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * Determine the size in bytes of the compact binary field at the start of the view.
+ *
+ * This may be called on an incomplete or invalid field, in which case the returned size is zero.
+ * A size can always be extracted from a valid field with no name if a view of at least the first
+ * 10 bytes is provided, regardless of field size. For fields with names, the size of view needed
+ * to calculate a size is at most 10 + MaxNameLen + MeasureVarUInt(MaxNameLen).
+ *
+ * This function can be used when streaming a field, for example, to determine the size of buffer
+ * to fill before attempting to construct a field from it.
+ *
+ * @param View A memory view that may contain the start of a field.
+ * @param Type HasFieldType means that View contains the type. Otherwise, use the given type.
+ */
+ZENCORE_API uint64_t MeasureCompactBinary(MemoryView View, CbFieldType Type = CbFieldType::HasFieldType);
+
+/**
+ * Try to determine the type and size of the compact binary field at the start of the view.
+ *
+ * This may be called on an incomplete or invalid field, in which case it will return false, with
+ * OutSize being 0 for invalid fields, otherwise the minimum view size necessary to make progress
+ * in measuring the field on the next call to this function.
+ *
+ * @note A return of true from this function does not indicate that the entire field is valid.
+ *
+ * @param InView A memory view that may contain the start of a field.
+ * @param OutType The type (with flags) of the field. None is written until a value is available.
+ * @param OutSize The total field size for a return of true, 0 for invalid fields, or the size to
+ * make progress in measuring the field on the next call to this function.
+ * @param InType HasFieldType means that InView contains the type. Otherwise, use the given type.
+ * @return true if the size of the field was determined, otherwise false.
+ */
+ZENCORE_API bool TryMeasureCompactBinary(MemoryView InView,
+ CbFieldType& OutType,
+ uint64_t& OutSize,
+ CbFieldType InType = CbFieldType::HasFieldType);
+
+inline CbFieldViewIterator
+begin(CbFieldView& View)
+{
+ if (View.IsArray())
+ {
+ return View.AsArrayView().CreateViewIterator();
+ }
+ else if (View.IsObject())
+ {
+ return View.AsObjectView().CreateViewIterator();
+ }
+
+ return CbFieldViewIterator();
+}
+
+inline CbFieldViewIterator
+end(CbFieldView&)
+{
+ return CbFieldViewIterator();
+}
+
+void uson_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/compactbinarybuilder.h b/src/zencore/include/zencore/compactbinarybuilder.h
new file mode 100644
index 000000000..4be8c2ba5
--- /dev/null
+++ b/src/zencore/include/zencore/compactbinarybuilder.h
@@ -0,0 +1,661 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <zencore/compactbinary.h>
+
+#include <zencore/enumflags.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/refcount.h>
+#include <zencore/sha1.h>
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <type_traits>
+#include <vector>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+class CbAttachment;
+class BinaryWriter;
+
+/**
+ * A writer for compact binary object, arrays, and fields.
+ *
+ * The writer produces a sequence of fields that can be saved to a provided memory buffer or into
+ * a new owned buffer. The typical use case is to write a single object, which can be accessed by
+ * calling Save().AsObjectRef() or Save(Buffer).AsObject().
+ *
+ * The writer will assert on most incorrect usage and will always produce valid compact binary if
+ * provided with valid input. The writer does not check for invalid UTF-8 string encoding, object
+ * fields with duplicate names, or invalid compact binary being copied from another source.
+ *
+ * It is most convenient to use the streaming API for the writer, as demonstrated in the example.
+ *
+ * When writing a small amount of compact binary data, TCbWriter can be more efficient as it uses
+ * a fixed-size stack buffer for storage before spilling onto the heap.
+ *
+ * @see TCbWriter
+ *
+ * Example:
+ *
+ * CbObjectRef WriteObject()
+ * {
+ * CbWriter<256> Writer;
+ * Writer.BeginObject();
+ *
+ * Writer << "Resize" << true;
+ * Writer << "MaxWidth" << 1024;
+ * Writer << "MaxHeight" << 1024;
+ *
+ * Writer.BeginArray();
+ * Writer << "FormatA" << "FormatB" << "FormatC";
+ * Writer.EndArray();
+ *
+ * Writer.EndObject();
+ * return Writer.Save().AsObjectRef();
+ * }
+ */
+class CbWriter
+{
+public:
+ ZENCORE_API CbWriter();
+ ZENCORE_API ~CbWriter();
+
+ CbWriter(const CbWriter&) = delete;
+ CbWriter& operator=(const CbWriter&) = delete;
+
+ /** Empty the writer without releasing any allocated memory. */
+ ZENCORE_API void Reset();
+
+ /**
+ * Serialize the field(s) to an owned buffer and return it as an iterator.
+ *
+ * It is not valid to call this function in the middle of writing an object, array, or field.
+ * The writer remains valid for further use when this function returns.
+ */
+ ZENCORE_API CbFieldIterator Save();
+
+ /**
+ * Serialize the field(s) to memory.
+ *
+ * It is not valid to call this function in the middle of writing an object, array, or field.
+ * The writer remains valid for further use when this function returns.
+ *
+ * @param Buffer A mutable memory view to write to. Must be exactly GetSaveSize() bytes.
+ * @return An iterator for the field(s) written to the buffer.
+ */
+ ZENCORE_API CbFieldViewIterator Save(MutableMemoryView Buffer);
+
+ ZENCORE_API void Save(BinaryWriter& Writer);
+
+ /**
+ * The size of buffer (in bytes) required to serialize the fields that have been written.
+ *
+ * It is not valid to call this function in the middle of writing an object, array, or field.
+ */
+ ZENCORE_API uint64_t GetSaveSize() const;
+
+ /**
+ * Sets the name of the next field to be written.
+ *
+ * It is not valid to call this function when writing a field inside an array.
+ * Names must be valid UTF-8 and must be unique within an object.
+ */
+ ZENCORE_API CbWriter& SetName(std::string_view Name);
+
+ /** Copy the value (not the name) of an existing field. */
+ inline void AddField(std::string_view Name, const CbFieldView& Value)
+ {
+ SetName(Name);
+ AddField(Value);
+ }
+
+ ZENCORE_API void AddField(const CbFieldView& Value);
+
+ /** Copy the value (not the name) of an existing field. Holds a reference if owned. */
+ inline void AddField(std::string_view Name, const CbField& Value)
+ {
+ SetName(Name);
+ AddField(Value);
+ }
+ ZENCORE_API void AddField(const CbField& Value);
+
+ /** Begin a new object. Must have a matching call to EndObject. */
+ inline void BeginObject(std::string_view Name)
+ {
+ SetName(Name);
+ BeginObject();
+ }
+ ZENCORE_API void BeginObject();
+ /** End an object after its fields have been written. */
+ ZENCORE_API void EndObject();
+
+ /** Copy the value (not the name) of an existing object. */
+ inline void AddObject(std::string_view Name, const CbObjectView& Value)
+ {
+ SetName(Name);
+ AddObject(Value);
+ }
+ ZENCORE_API void AddObject(const CbObjectView& Value);
+ /** Copy the value (not the name) of an existing object. Holds a reference if owned. */
+ inline void AddObject(std::string_view Name, const CbObject& Value)
+ {
+ SetName(Name);
+ AddObject(Value);
+ }
+ ZENCORE_API void AddObject(const CbObject& Value);
+
+ /** Begin a new array. Must have a matching call to EndArray. */
+ inline void BeginArray(std::string_view Name)
+ {
+ SetName(Name);
+ BeginArray();
+ }
+ ZENCORE_API void BeginArray();
+ /** End an array after its fields have been written. */
+ ZENCORE_API void EndArray();
+
+ /** Copy the value (not the name) of an existing array. */
+ inline void AddArray(std::string_view Name, const CbArrayView& Value)
+ {
+ SetName(Name);
+ AddArray(Value);
+ }
+ ZENCORE_API void AddArray(const CbArrayView& Value);
+ /** Copy the value (not the name) of an existing array. Holds a reference if owned. */
+ inline void AddArray(std::string_view Name, const CbArray& Value)
+ {
+ SetName(Name);
+ AddArray(Value);
+ }
+ ZENCORE_API void AddArray(const CbArray& Value);
+
+ /** Write a null field. */
+ inline void AddNull(std::string_view Name)
+ {
+ SetName(Name);
+ AddNull();
+ }
+ ZENCORE_API void AddNull();
+
+ /** Write a binary field by copying Size bytes from Value. */
+ inline void AddBinary(std::string_view Name, const void* Value, uint64_t Size)
+ {
+ SetName(Name);
+ AddBinary(Value, Size);
+ }
+ ZENCORE_API void AddBinary(const void* Value, uint64_t Size);
+ /** Write a binary field by copying the view. */
+ inline void AddBinary(std::string_view Name, MemoryView Value)
+ {
+ SetName(Name);
+ AddBinary(Value);
+ }
+ inline void AddBinary(MemoryView Value) { AddBinary(Value.GetData(), Value.GetSize()); }
+
+ /** Write a binary field by copying the buffer. Holds a reference if owned. */
+ inline void AddBinary(std::string_view Name, IoBuffer Value)
+ {
+ SetName(Name);
+ AddBinary(std::move(Value));
+ }
+ ZENCORE_API void AddBinary(IoBuffer Value);
+ ZENCORE_API void AddBinary(SharedBuffer Value);
+
+ inline void AddBinary(std::string_view Name, const CompositeBuffer& Buffer)
+ {
+ SetName(Name);
+ AddBinary(Buffer);
+ }
+ ZENCORE_API void AddBinary(const CompositeBuffer& Buffer);
+
+ /** Write a string field by copying the UTF-8 value. */
+ inline void AddString(std::string_view Name, std::string_view Value)
+ {
+ SetName(Name);
+ AddString(Value);
+ }
+ ZENCORE_API void AddString(std::string_view Value);
+ /** Write a string field by converting the UTF-16 value to UTF-8. */
+ inline void AddString(std::string_view Name, std::wstring_view Value)
+ {
+ SetName(Name);
+ AddString(Value);
+ }
+ ZENCORE_API void AddString(std::wstring_view Value);
+
+ /** Write an integer field. */
+ inline void AddInteger(std::string_view Name, int32_t Value)
+ {
+ SetName(Name);
+ AddInteger(Value);
+ }
+ ZENCORE_API void AddInteger(int32_t Value);
+ /** Write an integer field. */
+ inline void AddInteger(std::string_view Name, int64_t Value)
+ {
+ SetName(Name);
+ AddInteger(Value);
+ }
+ ZENCORE_API void AddInteger(int64_t Value);
+ /** Write an integer field. */
+ inline void AddInteger(std::string_view Name, uint32_t Value)
+ {
+ SetName(Name);
+ AddInteger(Value);
+ }
+ ZENCORE_API void AddInteger(uint32_t Value);
+ /** Write an integer field. */
+ inline void AddInteger(std::string_view Name, uint64_t Value)
+ {
+ SetName(Name);
+ AddInteger(Value);
+ }
+ ZENCORE_API void AddInteger(uint64_t Value);
+
+ /** Write a float field from a 32-bit float value. */
+ inline void AddFloat(std::string_view Name, float Value)
+ {
+ SetName(Name);
+ AddFloat(Value);
+ }
+ ZENCORE_API void AddFloat(float Value);
+
+ /** Write a float field from a 64-bit float value. */
+ inline void AddFloat(std::string_view Name, double Value)
+ {
+ SetName(Name);
+ AddFloat(Value);
+ }
+ ZENCORE_API void AddFloat(double Value);
+
+ /** Write a bool field. */
+ inline void AddBool(std::string_view Name, bool bValue)
+ {
+ SetName(Name);
+ AddBool(bValue);
+ }
+ ZENCORE_API void AddBool(bool bValue);
+
+ /** Write a field referencing a compact binary attachment by its hash. */
+ inline void AddObjectAttachment(std::string_view Name, const IoHash& Value)
+ {
+ SetName(Name);
+ AddObjectAttachment(Value);
+ }
+ ZENCORE_API void AddObjectAttachment(const IoHash& Value);
+
+ /** Write a field referencing a binary attachment by its hash. */
+ inline void AddBinaryAttachment(std::string_view Name, const IoHash& Value)
+ {
+ SetName(Name);
+ AddBinaryAttachment(Value);
+ }
+ ZENCORE_API void AddBinaryAttachment(const IoHash& Value);
+
+ /** Write a field referencing the attachment by its hash. */
+ inline void AddAttachment(std::string_view Name, const CbAttachment& Attachment)
+ {
+ SetName(Name);
+ AddAttachment(Attachment);
+ }
+ ZENCORE_API void AddAttachment(const CbAttachment& Attachment);
+
+ /** Write a hash field. */
+ inline void AddHash(std::string_view Name, const IoHash& Value)
+ {
+ SetName(Name);
+ AddHash(Value);
+ }
+ ZENCORE_API void AddHash(const IoHash& Value);
+
+ /** Write a UUID field. */
+ inline void AddUuid(std::string_view Name, const Guid& Value)
+ {
+ SetName(Name);
+ AddUuid(Value);
+ }
+ ZENCORE_API void AddUuid(const Guid& Value);
+
+ /** Write an ObjectId field. */
+ inline void AddObjectId(std::string_view Name, const Oid& Value)
+ {
+ SetName(Name);
+ AddObjectId(Value);
+ }
+ ZENCORE_API void AddObjectId(const Oid& Value);
+
+ /** Write a date/time field with the specified count of 100ns ticks since the epoch. */
+ inline void AddDateTimeTicks(std::string_view Name, int64_t Ticks)
+ {
+ SetName(Name);
+ AddDateTimeTicks(Ticks);
+ }
+ ZENCORE_API void AddDateTimeTicks(int64_t Ticks);
+
+ /** Write a date/time field. */
+ inline void AddDateTime(std::string_view Name, DateTime Value)
+ {
+ SetName(Name);
+ AddDateTime(Value);
+ }
+ ZENCORE_API void AddDateTime(DateTime Value);
+
+ /** Write a time span field with the specified count of 100ns ticks. */
+ inline void AddTimeSpanTicks(std::string_view Name, int64_t Ticks)
+ {
+ SetName(Name);
+ AddTimeSpanTicks(Ticks);
+ }
+ ZENCORE_API void AddTimeSpanTicks(int64_t Ticks);
+
+ /** Write a time span field. */
+ inline void AddTimeSpan(std::string_view Name, TimeSpan Value)
+ {
+ SetName(Name);
+ AddTimeSpan(Value);
+ }
+ ZENCORE_API void AddTimeSpan(TimeSpan Value);
+
+ /** Private flags that are public to work with ENUM_CLASS_FLAGS. */
+ enum class StateFlags : uint8_t;
+
+protected:
+ /** Reserve the specified size up front until the format is optimized. */
+ ZENCORE_API explicit CbWriter(int64_t InitialSize);
+
+private:
+ friend CbWriter& operator<<(CbWriter& Writer, std::string_view NameOrValue);
+
+ /** Begin writing a field. May be called twice for named fields. */
+ void BeginField();
+
+ /** Finish writing a field by writing its type. */
+ void EndField(CbFieldType Type);
+
+ /** Set the field name if valid in this state, otherwise write add a string field. */
+ ZENCORE_API void SetNameOrAddString(std::string_view NameOrValue);
+
+ /** Returns a view of the name of the active field, if any, otherwise the empty view. */
+ std::string_view GetActiveName() const;
+
+ /** Remove field types after the first to make the sequence uniform. */
+ void MakeFieldsUniform(int64_t FieldBeginOffset, int64_t FieldEndOffset);
+
+ /** State of the object, array, or top-level field being written. */
+ struct WriterState
+ {
+ StateFlags Flags{};
+ /** The type of the fields in the sequence if uniform, otherwise None. */
+ CbFieldType UniformType{};
+ /** The offset of the start of the current field. */
+ int64_t Offset{};
+ /** The number of fields written in this state. */
+ uint64_t Count{};
+ };
+
+private:
+ // This is a prototype-quality format for the writer. Using an array of bytes is inefficient,
+ // and will lead to many unnecessary copies and moves of the data to resize the array, insert
+ // object and array sizes, and remove field types for uniform objects and uniform arrays. The
+ // optimized format will be a list of power-of-two blocks and an optional first block that is
+ // provided externally, such as on the stack. That format will store the offsets that require
+ // object or array sizes to be inserted and field types to be removed, and will perform those
+ // operations only when saving to a buffer.
+ std::vector<uint8_t> Data;
+ std::vector<WriterState> States;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * A writer for compact binary object, arrays, and fields that uses a fixed-size stack buffer.
+ *
+ * @see CbWriter
+ */
+template<uint32_t InlineBufferSize>
+class FixedCbWriter : public CbWriter
+{
+public:
+ inline FixedCbWriter() : CbWriter(InlineBufferSize) {}
+
+ FixedCbWriter(const FixedCbWriter&) = delete;
+ FixedCbWriter& operator=(const FixedCbWriter&) = delete;
+
+private:
+ // Reserve the inline buffer now even though we are unable to use it. This will avoid causing
+ // new stack overflows when this functionality is properly implemented in the future.
+ uint8_t Buffer[InlineBufferSize];
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+class CbObjectWriter : public CbWriter
+{
+public:
+ CbObjectWriter() { BeginObject(); }
+
+ ZENCORE_API CbObject Save()
+ {
+ Finalize();
+ return CbWriter::Save().AsObject();
+ }
+
+ ZENCORE_API void Save(BinaryWriter& Writer)
+ {
+ Finalize();
+ return CbWriter::Save(Writer);
+ }
+
+ ZENCORE_API CbFieldViewIterator Save(MutableMemoryView Buffer)
+ {
+ ZEN_ASSERT(m_Finalized);
+ return CbWriter::Save(Buffer);
+ }
+
+ uint64_t GetSaveSize()
+ {
+ ZEN_ASSERT(m_Finalized);
+ return CbWriter::GetSaveSize();
+ }
+
+ void Finalize()
+ {
+ if (m_Finalized == false)
+ {
+ EndObject();
+ m_Finalized = true;
+ }
+ }
+
+ CbObjectWriter(const CbWriter&) = delete;
+ CbObjectWriter& operator=(const CbWriter&) = delete;
+
+private:
+ bool m_Finalized = false;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/** Write the field name if valid in this state, otherwise write the string value. */
+inline CbWriter&
+operator<<(CbWriter& Writer, std::string_view NameOrValue)
+{
+ Writer.SetNameOrAddString(NameOrValue);
+ return Writer;
+}
+
+/** Write the field name if valid in this state, otherwise write the string value. */
+inline CbWriter&
+operator<<(CbWriter& Writer, const char* NameOrValue)
+{
+ return Writer << std::string_view(NameOrValue);
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const CbFieldView& Value)
+{
+ Writer.AddField(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const CbField& Value)
+{
+ Writer.AddField(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const CbObjectView& Value)
+{
+ Writer.AddObject(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const CbObject& Value)
+{
+ Writer.AddObject(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const CbArrayView& Value)
+{
+ Writer.AddArray(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const CbArray& Value)
+{
+ Writer.AddArray(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, std::nullptr_t)
+{
+ Writer.AddNull();
+ return Writer;
+}
+
+#if defined(__clang__) && defined(__APPLE__)
+/* Apple Clang has different types for uint64_t and size_t so an override is
+ needed here. Without it, Clang can't disambiguate integer overloads */
+inline CbWriter&
+operator<<(CbWriter& Writer, std::size_t Value)
+{
+ Writer.AddInteger(uint64_t(Value));
+ return Writer;
+}
+#endif
+
+inline CbWriter&
+operator<<(CbWriter& Writer, std::wstring_view Value)
+{
+ Writer.AddString(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const wchar_t* Value)
+{
+ Writer.AddString(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, int32_t Value)
+{
+ Writer.AddInteger(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, int64_t Value)
+{
+ Writer.AddInteger(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, uint32_t Value)
+{
+ Writer.AddInteger(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, uint64_t Value)
+{
+ Writer.AddInteger(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, float Value)
+{
+ Writer.AddFloat(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, double Value)
+{
+ Writer.AddFloat(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, bool Value)
+{
+ Writer.AddBool(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const CbAttachment& Attachment)
+{
+ Writer.AddAttachment(Attachment);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const IoHash& Value)
+{
+ Writer.AddHash(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const Guid& Value)
+{
+ Writer.AddUuid(Value);
+ return Writer;
+}
+
+inline CbWriter&
+operator<<(CbWriter& Writer, const Oid& Value)
+{
+ Writer.AddObjectId(Value);
+ return Writer;
+}
+
+ZENCORE_API CbWriter& operator<<(CbWriter& Writer, DateTime Value);
+ZENCORE_API CbWriter& operator<<(CbWriter& Writer, TimeSpan Value);
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void usonbuilder_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/compactbinarypackage.h b/src/zencore/include/zencore/compactbinarypackage.h
new file mode 100644
index 000000000..16f723edc
--- /dev/null
+++ b/src/zencore/include/zencore/compactbinarypackage.h
@@ -0,0 +1,341 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compress.h>
+#include <zencore/iohash.h>
+
+#include <functional>
+#include <span>
+#include <variant>
+
+#ifdef GetObject
+# error "windows.h pollution"
+# undef GetObject
+#endif
+
+namespace zen {
+
+class CbWriter;
+class BinaryReader;
+class BinaryWriter;
+class IoBuffer;
+class CbAttachment;
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * An attachment is either binary or compact binary and is identified by its hash.
+ *
+ * A compact binary attachment is also a valid binary attachment and may be accessed as binary.
+ *
+ * Attachments are serialized as one or two compact binary fields with no name. A Binary field is
+ * written first with its content. The content hash is omitted when the content size is zero, and
+ * is otherwise written as a BinaryReference or CompactBinaryReference depending on the type.
+ */
+class CbAttachment
+{
+public:
+ /** Construct a null attachment. */
+ CbAttachment() = default;
+
+ /** Construct a compact binary attachment. Value is cloned if not owned. */
+ inline explicit CbAttachment(const CbObject& InValue) : CbAttachment(InValue, nullptr) {}
+
+ /** Construct a compact binary attachment. Value is cloned if not owned. Hash must match Value. */
+ inline explicit CbAttachment(const CbObject& InValue, const IoHash& Hash) : CbAttachment(InValue, &Hash) {}
+
+ /** Construct a raw binary attachment. Value is cloned if not owned. */
+ ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue);
+
+ /** Construct a raw binary attachment. Value is cloned if not owned. Hash must match Value. */
+ ZENCORE_API explicit CbAttachment(const SharedBuffer& InValue, const IoHash& Hash);
+
+ /** Construct a raw binary attachment. Value is cloned if not owned. */
+ ZENCORE_API explicit CbAttachment(const CompositeBuffer& InValue);
+
+ /** Construct a raw binary attachment. Value is cloned if not owned. */
+ ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue);
+
+ /** Construct a raw binary attachment. Value is cloned if not owned. */
+ ZENCORE_API explicit CbAttachment(CompositeBuffer&& InValue, const IoHash& Hash);
+
+ /** Construct a compressed binary attachment. Value is cloned if not owned. */
+ ZENCORE_API explicit CbAttachment(const CompressedBuffer& InValue, const IoHash& Hash);
+ ZENCORE_API explicit CbAttachment(CompressedBuffer&& InValue, const IoHash& Hash);
+
+ /** Reset this to a null attachment. */
+ inline void Reset() { *this = CbAttachment(); }
+
+ /** Whether the attachment has a value. */
+ inline explicit operator bool() const { return !IsNull(); }
+
+ /** Whether the attachment has a value. */
+ ZENCORE_API [[nodiscard]] bool IsNull() const;
+
+ /** Access the attachment as binary. Defaults to a null buffer on error. */
+ ZENCORE_API [[nodiscard]] SharedBuffer AsBinary() const;
+
+ /** Access the attachment as raw binary. Defaults to a null buffer on error. */
+ ZENCORE_API [[nodiscard]] CompositeBuffer AsCompositeBinary() const;
+
+ /** Access the attachment as compressed binary. Defaults to a null buffer if the attachment is null. */
+ ZENCORE_API [[nodiscard]] CompressedBuffer AsCompressedBinary() const;
+
+ /** Access the attachment as compact binary. Defaults to a field iterator with no value on error. */
+ ZENCORE_API [[nodiscard]] CbObject AsObject() const;
+
+ /** Returns true if the attachment is binary */
+ ZENCORE_API [[nodiscard]] bool IsBinary() const;
+
+ /** Returns true if the attachment is compressed binary */
+ ZENCORE_API [[nodiscard]] bool IsCompressedBinary() const;
+
+ /** Returns whether the attachment is an object. */
+ ZENCORE_API [[nodiscard]] bool IsObject() const;
+
+ /** Returns the hash of the attachment value. */
+ ZENCORE_API [[nodiscard]] IoHash GetHash() const;
+
+ /** Compares attachments by their hash. Any discrepancy in type must be handled externally. */
+ inline bool operator==(const CbAttachment& Attachment) const { return GetHash() == Attachment.GetHash(); }
+ inline bool operator!=(const CbAttachment& Attachment) const { return GetHash() != Attachment.GetHash(); }
+ inline bool operator<(const CbAttachment& Attachment) const { return GetHash() < Attachment.GetHash(); }
+
+ /**
+ * Load the attachment from compact binary as written by Save.
+ *
+ * The attachment references the input iterator if it is owned, and otherwise clones the value.
+ *
+ * The iterator is advanced as attachment fields are consumed from it.
+ */
+ ZENCORE_API bool TryLoad(CbFieldIterator& Fields);
+
+ /**
+ * Load the attachment from compact binary as written by Save.
+ */
+ ZENCORE_API bool TryLoad(BinaryReader& Reader, BufferAllocator Allocator = UniqueBuffer::Alloc);
+
+ /**
+ * Load the attachment from compact binary as written by Save.
+ */
+ ZENCORE_API bool TryLoad(IoBuffer& Buffer, BufferAllocator Allocator = UniqueBuffer::Alloc);
+
+ /** Save the attachment into the writer as a stream of compact binary fields. */
+ ZENCORE_API void Save(CbWriter& Writer) const;
+
+ /** Save the attachment into the writer as a stream of compact binary fields. */
+ ZENCORE_API void Save(BinaryWriter& Writer) const;
+
+private:
+ ZENCORE_API CbAttachment(const CbObject& Value, const IoHash* Hash);
+
+ IoHash Hash;
+ std::variant<std::nullptr_t, CbObject, CompositeBuffer, CompressedBuffer> Value;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * A package is a compact binary object with attachments for its external references.
+ *
+ * A package is basically a Merkle tree with compact binary as its root and other non-leaf nodes,
+ * and either binary or compact binary as its leaf nodes. A node references its child nodes using
+ * BinaryHash or FieldHash fields in its compact binary representation.
+ *
+ * It is invalid for a package to include attachments that are not referenced by its object or by
+ * one of its referenced compact binary attachments. When attachments are added explicitly, it is
+ * the responsibility of the package creator to follow this requirement. Attachments that are not
+ * referenced may not survive a round-trip through certain storage systems.
+ *
+ * It is valid for a package to exclude referenced attachments, but then it is the responsibility
+ * of the package consumer to have a mechanism for resolving those references when necessary.
+ *
+ * A package is serialized as a sequence of compact binary fields with no name. The object may be
+ * both preceded and followed by attachments. The object itself is written as an Object field and
+ * followed by its hash in a CompactBinaryReference field when the object is non-empty. A package
+ * ends with a Null field. The canonical order of components is the object and its hash, followed
+ * by the attachments ordered by hash, followed by a Null field. It is valid for the a package to
+ * have its components serialized in any order, provided there is at most one object and the null
+ * field is written last.
+ */
+class CbPackage
+{
+public:
+ /**
+ * A function that resolves a hash to a buffer containing the data matching that hash.
+ *
+ * The resolver may return a null buffer to skip resolving an attachment for the hash.
+ */
+ using AttachmentResolver = std::function<SharedBuffer(const IoHash& Hash)>;
+
+ /** Construct a null package. */
+ CbPackage() = default;
+
+ /**
+ * Construct a package from a root object without gathering attachments.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ */
+ inline explicit CbPackage(CbObject InObject) { SetObject(std::move(InObject)); }
+
+ /**
+ * Construct a package from a root object and gather attachments using the resolver.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ * @param InResolver A function that is invoked for every reference and binary reference field.
+ */
+ inline explicit CbPackage(CbObject InObject, AttachmentResolver InResolver) { SetObject(std::move(InObject), InResolver); }
+
+ /**
+ * Construct a package from a root object without gathering attachments.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ * @param InObjectHash The hash of the object, which must match to avoid validation errors.
+ */
+ inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash) { SetObject(std::move(InObject), InObjectHash); }
+
+ /**
+ * Construct a package from a root object and gather attachments using the resolver.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ * @param InObjectHash The hash of the object, which must match to avoid validation errors.
+ * @param InResolver A function that is invoked for every reference and binary reference field.
+ */
+ inline explicit CbPackage(CbObject InObject, const IoHash& InObjectHash, AttachmentResolver InResolver)
+ {
+ SetObject(std::move(InObject), InObjectHash, InResolver);
+ }
+
+ /** Reset this to a null package. */
+ inline void Reset() { *this = CbPackage(); }
+
+ /** Whether the package has a non-empty object or attachments. */
+ inline explicit operator bool() const { return !IsNull(); }
+
+ /** Whether the package has an empty object and no attachments. */
+ inline bool IsNull() const { return !Object && Attachments.size() == 0; }
+
+ /** Returns the compact binary object for the package. */
+ inline const CbObject& GetObject() const { return Object; }
+
+ /** Returns the has of the compact binary object for the package. */
+ inline const IoHash& GetObjectHash() const { return ObjectHash; }
+
+ /**
+ * Set the root object without gathering attachments.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ */
+ inline void SetObject(CbObject InObject) { SetObject(std::move(InObject), nullptr, nullptr); }
+
+ /**
+ * Set the root object and gather attachments using the resolver.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ * @param InResolver A function that is invoked for every reference and binary reference field.
+ */
+ inline void SetObject(CbObject InObject, AttachmentResolver InResolver) { SetObject(std::move(InObject), nullptr, &InResolver); }
+
+ /**
+ * Set the root object without gathering attachments.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ * @param InObjectHash The hash of the object, which must match to avoid validation errors.
+ */
+ inline void SetObject(CbObject InObject, const IoHash& InObjectHash) { SetObject(std::move(InObject), &InObjectHash, nullptr); }
+
+ /**
+ * Set the root object and gather attachments using the resolver.
+ *
+ * @param InObject The root object, which will be cloned unless it is owned.
+ * @param InObjectHash The hash of the object, which must match to avoid validation errors.
+ * @param InResolver A function that is invoked for every reference and binary reference field.
+ */
+ inline void SetObject(CbObject InObject, const IoHash& InObjectHash, AttachmentResolver InResolver)
+ {
+ SetObject(std::move(InObject), &InObjectHash, &InResolver);
+ }
+
+ /** Returns the attachments in this package. */
+ inline std::span<const CbAttachment> GetAttachments() const { return Attachments; }
+
+ /**
+ * Find an attachment by its hash.
+ *
+ * @return The attachment, or null if the attachment is not found.
+ * @note The returned pointer is only valid until the attachments on this package are modified.
+ */
+ ZENCORE_API const CbAttachment* FindAttachment(const IoHash& Hash) const;
+
+ /** Find an attachment if it exists in the package. */
+ inline const CbAttachment* FindAttachment(const CbAttachment& Attachment) const { return FindAttachment(Attachment.GetHash()); }
+
+ /** Add the attachment to this package. */
+ inline void AddAttachment(const CbAttachment& Attachment) { AddAttachment(Attachment, nullptr); }
+
+ /** Add the attachment to this package, along with any references that can be resolved. */
+ inline void AddAttachment(const CbAttachment& Attachment, AttachmentResolver Resolver) { AddAttachment(Attachment, &Resolver); }
+
+ void AddAttachments(std::span<const CbAttachment> Attachments);
+
+ /**
+ * Remove an attachment by hash.
+ *
+ * @return Number of attachments removed, which will be either 0 or 1.
+ */
+ ZENCORE_API int32_t RemoveAttachment(const IoHash& Hash);
+ inline int32_t RemoveAttachment(const CbAttachment& Attachment) { return RemoveAttachment(Attachment.GetHash()); }
+
+ /** Compares packages by their object and attachment hashes. */
+ ZENCORE_API bool Equals(const CbPackage& Package) const;
+ inline bool operator==(const CbPackage& Package) const { return Equals(Package); }
+ inline bool operator!=(const CbPackage& Package) const { return !Equals(Package); }
+
+ /**
+ * Load the object and attachments from compact binary as written by Save.
+ *
+ * The object and attachments reference the input iterator, if it is owned, and otherwise clones
+ * the object and attachments individually to make owned copies.
+ *
+ * The iterator is advanced as object and attachment fields are consumed from it.
+ */
+ ZENCORE_API bool TryLoad(CbFieldIterator& Fields);
+ ZENCORE_API bool TryLoad(IoBuffer Buffer, BufferAllocator Allocator = UniqueBuffer::Alloc, AttachmentResolver* Mapper = nullptr);
+ ZENCORE_API bool TryLoad(BinaryReader& Reader, BufferAllocator Allocator = UniqueBuffer::Alloc, AttachmentResolver* Mapper = nullptr);
+
+ /** Save the object and attachments into the writer as a stream of compact binary fields. */
+ ZENCORE_API void Save(CbWriter& Writer) const;
+
+ /** Save the object and attachments into the writer as a stream of compact binary fields. */
+ ZENCORE_API void Save(BinaryWriter& Writer) const;
+
+private:
+ ZENCORE_API void SetObject(CbObject Object, const IoHash* Hash, AttachmentResolver* Resolver);
+ ZENCORE_API void AddAttachment(const CbAttachment& Attachment, AttachmentResolver* Resolver);
+
+ void GatherAttachments(const CbObject& Object, AttachmentResolver Resolver);
+
+ /** Attachments ordered by their hash. */
+ std::vector<CbAttachment> Attachments;
+ CbObject Object;
+ IoHash ObjectHash;
+};
+
+namespace legacy {
+ void SaveCbAttachment(const CbAttachment& Attachment, CbWriter& Writer);
+ void SaveCbPackage(const CbPackage& Package, CbWriter& Writer);
+ void SaveCbPackage(const CbPackage& Package, BinaryWriter& Ar);
+ bool TryLoadCbPackage(CbPackage& Package, IoBuffer Buffer, BufferAllocator Allocator, CbPackage::AttachmentResolver* Mapper = nullptr);
+ bool TryLoadCbPackage(CbPackage& Package,
+ BinaryReader& Reader,
+ BufferAllocator Allocator,
+ CbPackage::AttachmentResolver* Mapper = nullptr);
+} // namespace legacy
+
+void usonpackage_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/compactbinaryvalidation.h b/src/zencore/include/zencore/compactbinaryvalidation.h
new file mode 100644
index 000000000..b1fab9572
--- /dev/null
+++ b/src/zencore/include/zencore/compactbinaryvalidation.h
@@ -0,0 +1,197 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/enumflags.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/refcount.h>
+#include <zencore/sha1.h>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+/** Flags for validating compact binary data. */
+enum class CbValidateMode : uint32_t
+{
+ /** Skip validation if no other validation modes are enabled. */
+ None = 0,
+
+ /**
+ * Validate that the value can be read and stays inside the bounds of the memory view.
+ *
+ * This is the minimum level of validation required to be able to safely read a field, array,
+ * or object without the risk of crashing or reading out of bounds.
+ */
+ Default = 1 << 0,
+
+ /**
+ * Validate that object fields have unique non-empty names and array fields have no names.
+ *
+ * Name validation failures typically do not inhibit reading the input, but duplicated fields
+ * cannot be looked up by name other than the first, and converting to other data formats can
+ * fail in the presence of naming issues.
+ */
+ Names = 1 << 1,
+
+ /**
+ * Validate that fields are serialized in the canonical format.
+ *
+ * Format validation failures typically do not inhibit reading the input. Values that fail in
+ * this mode require more memory than in the canonical format, and comparisons of such values
+ * for equality are not reliable. Examples of failures include uniform arrays or objects that
+ * were not encoded uniformly, variable-length integers that could be encoded in fewer bytes,
+ * or 64-bit floats that could be encoded in 32 bits without loss of precision.
+ */
+ Format = 1 << 2,
+
+ /**
+ * Validate that there is no padding after the value before the end of the memory view.
+ *
+ * Padding validation failures have no impact on the ability to read the input, but are using
+ * more memory than necessary.
+ */
+ Padding = 1 << 3,
+
+ /**
+ * Validate that a package or attachment has the expected fields.
+ */
+ Package = 1 << 4,
+
+ /**
+ * Validate that a package or attachment matches its saved hashes.
+ */
+ PackageHash = 1 << 5,
+
+ /** Perform all validation described above. */
+ All = Default | Names | Format | Padding | Package | PackageHash,
+};
+
+ENUM_CLASS_FLAGS(CbValidateMode);
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/** Flags for compact binary validation errors. Multiple flags may be combined. */
+enum class CbValidateError : uint32_t
+{
+ /** The input had no validation errors. */
+ None = 0,
+
+ // Mode: Default
+
+ /** The input cannot be read without reading out of bounds. */
+ OutOfBounds = 1 << 0,
+ /** The input has a field with an unrecognized or invalid type. */
+ InvalidType = 1 << 1,
+
+ // Mode: Names
+
+ /** An object had more than one field with the same name. */
+ DuplicateName = 1 << 2,
+ /** An object had a field with no name. */
+ MissingName = 1 << 3,
+ /** An array field had a name. */
+ ArrayName = 1 << 4,
+
+ // Mode: Format
+
+ /** A name or string payload is not valid UTF-8. */
+ InvalidString = 1 << 5,
+ /** A size or integer payload can be encoded in fewer bytes. */
+ InvalidInteger = 1 << 6,
+ /** A float64 payload can be encoded as a float32 without loss of precision. */
+ InvalidFloat = 1 << 7,
+ /** An object has the same type for every field but is not uniform. */
+ NonUniformObject = 1 << 8,
+ /** An array has the same type for every field and non-empty payloads but is not uniform. */
+ NonUniformArray = 1 << 9,
+
+ // Mode: Padding
+
+ /** A value did not use the entire memory view given for validation. */
+ Padding = 1 << 10,
+
+ // Mode: Package
+
+ /** The package or attachment had missing fields or fields out of order. */
+ InvalidPackageFormat = 1 << 11,
+ /** The object or an attachment did not match the hash stored for it. */
+ InvalidPackageHash = 1 << 12,
+ /** The package contained more than one copy of the same attachment. */
+ DuplicateAttachments = 1 << 13,
+ /** The package contained more than one object. */
+ MultiplePackageObjects = 1 << 14,
+ /** The package contained an object with no fields. */
+ NullPackageObject = 1 << 15,
+ /** The package contained a null attachment. */
+ NullPackageAttachment = 1 << 16,
+};
+
+ENUM_CLASS_FLAGS(CbValidateError);
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+ * Validate the compact binary data for one field in the view as specified by the mode flags.
+ *
+ * Only one top-level field is processed from the view, and validation recurses into any array or
+ * object within that field. To validate multiple consecutive top-level fields, call the function
+ * once for each top-level field. If the given view might contain multiple top-level fields, then
+ * either exclude the Padding flag from the Mode or use MeasureCompactBinary to break up the view
+ * into its constituent fields before validating.
+ *
+ * @param View A memory view containing at least one top-level field.
+ * @param Mode A combination of the flags for the types of validation to perform.
+ * @param Type HasFieldType means that View contains the type. Otherwise, use the given type.
+ * @return None on success, otherwise the flags for the types of errors that were detected.
+ */
+ZENCORE_API CbValidateError ValidateCompactBinary(MemoryView View, CbValidateMode Mode, CbFieldType Type = CbFieldType::HasFieldType);
+
+/**
+ * Validate the compact binary data for every field in the view as specified by the mode flags.
+ *
+ * This function expects the entire view to contain fields. Any trailing region of the view which
+ * does not contain a valid field will produce an OutOfBounds or InvalidType error instead of the
+ * Padding error that would be produced by the single field validation function.
+ *
+ * @see ValidateCompactBinary
+ */
+ZENCORE_API CbValidateError ValidateCompactBinaryRange(MemoryView View, CbValidateMode Mode);
+
+/**
+ * Validate the compact binary attachment pointed to by the view as specified by the mode flags.
+ *
+ * The attachment is validated with ValidateCompactBinary by using the validation mode specified.
+ * Include ECbValidateMode::Package to validate the attachment format and hash.
+ *
+ * @see ValidateCompactBinary
+ *
+ * @param View A memory view containing a package.
+ * @param Mode A combination of the flags for the types of validation to perform.
+ * @return None on success, otherwise the flags for the types of errors that were detected.
+ */
+ZENCORE_API CbValidateError ValidateObjectAttachment(MemoryView View, CbValidateMode Mode);
+
+/**
+ * Validate the compact binary package pointed to by the view as specified by the mode flags.
+ *
+ * The package, and attachments, are validated with ValidateCompactBinary by using the validation
+ * mode specified. Include ECbValidateMode::Package to validate the package format and hashes.
+ *
+ * @see ValidateCompactBinary
+ *
+ * @param View A memory view containing a package.
+ * @param Mode A combination of the flags for the types of validation to perform.
+ * @return None on success, otherwise the flags for the types of errors that were detected.
+ */
+ZENCORE_API CbValidateError ValidateCompactBinaryPackage(MemoryView View, CbValidateMode Mode);
+
+///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void usonvalidation_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/compactbinaryvalue.h b/src/zencore/include/zencore/compactbinaryvalue.h
new file mode 100644
index 000000000..0124a8983
--- /dev/null
+++ b/src/zencore/include/zencore/compactbinaryvalue.h
@@ -0,0 +1,290 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/endian.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/memory.h>
+
+namespace zen {
+
+namespace CompactBinaryPrivate {
+
+ template<typename T>
+ static constexpr inline T ReadUnaligned(const void* const Memory)
+ {
+#if ZEN_PLATFORM_SUPPORTS_UNALIGNED_LOADS
+ return *static_cast<const T*>(Memory);
+#else
+ T Value;
+ memcpy(&Value, Memory, sizeof(Value));
+ return Value;
+#endif
+ }
+} // namespace CompactBinaryPrivate
+/**
+ * A type that provides unchecked access to compact binary values.
+ *
+ * The main purpose of the type is to efficiently switch on field type. For every other use case,
+ * prefer to use the field, array, and object types directly. The accessors here do not check the
+ * type before reading the value, which means they can read out of bounds even on a valid compact
+ * binary value if the wrong accessor is used.
+ */
+class CbValue
+{
+public:
+ CbValue(CbFieldType Type, const void* Value);
+
+ CbObjectView AsObjectView() const;
+ CbArrayView AsArrayView() const;
+
+ MemoryView AsBinary() const;
+
+ /** Access as a string. Checks for range errors and uses the default if OutError is not null. */
+ std::string_view AsString(CbFieldError* OutError = nullptr, std::string_view Default = std::string_view()) const;
+
+ /** Access as a string as UTF8. Checks for range errors and uses the default if OutError is not null. */
+ std::u8string_view AsU8String(CbFieldError* OutError = nullptr, std::u8string_view Default = std::u8string_view()) const;
+
+ /**
+ * Access as an integer, with both positive and negative values returned as unsigned.
+ *
+ * Checks for range errors and uses the default if OutError is not null.
+ */
+ uint64_t AsInteger(CompactBinaryPrivate::IntegerParams Params, CbFieldError* OutError = nullptr, uint64_t Default = 0) const;
+
+ uint64_t AsIntegerPositive() const;
+ int64_t AsIntegerNegative() const;
+
+ float AsFloat32() const;
+ double AsFloat64() const;
+
+ bool AsBool() const;
+
+ inline IoHash AsObjectAttachment() const { return AsHash(); }
+ inline IoHash AsBinaryAttachment() const { return AsHash(); }
+ inline IoHash AsAttachment() const { return AsHash(); }
+
+ IoHash AsHash() const;
+ Guid AsUuid() const;
+
+ int64_t AsDateTimeTicks() const;
+ int64_t AsTimeSpanTicks() const;
+
+ Oid AsObjectId() const;
+
+ CbCustomById AsCustomById() const;
+ CbCustomByName AsCustomByName() const;
+
+ inline CbFieldType GetType() const { return Type; }
+ inline const void* GetData() const { return Data; }
+
+private:
+ const void* Data;
+ CbFieldType Type;
+};
+
+inline CbFieldView::CbFieldView(const CbValue& InValue) : Type(InValue.GetType()), Payload(InValue.GetData())
+{
+}
+
+inline CbValue
+CbFieldView::GetValue() const
+{
+ return CbValue(CbFieldTypeOps::GetType(Type), Payload);
+}
+
+inline CbValue::CbValue(CbFieldType InType, const void* InValue) : Data(InValue), Type(InType)
+{
+}
+
+inline CbObjectView
+CbValue::AsObjectView() const
+{
+ return CbObjectView(*this);
+}
+
+inline CbArrayView
+CbValue::AsArrayView() const
+{
+ return CbArrayView(*this);
+}
+
+inline MemoryView
+CbValue::AsBinary() const
+{
+ const uint8_t* const Bytes = static_cast<const uint8_t*>(Data);
+ uint32_t ValueSizeByteCount;
+ const uint64_t ValueSize = ReadVarUInt(Bytes, ValueSizeByteCount);
+ return MakeMemoryView(Bytes + ValueSizeByteCount, ValueSize);
+}
+
+inline std::string_view
+CbValue::AsString(CbFieldError* OutError, std::string_view Default) const
+{
+ const char* const Chars = static_cast<const char*>(Data);
+ uint32_t ValueSizeByteCount;
+ const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount);
+
+ if (OutError)
+ {
+ if (ValueSize >= (uint64_t(1) << 31))
+ {
+ *OutError = CbFieldError::RangeError;
+ return Default;
+ }
+ *OutError = CbFieldError::None;
+ }
+
+ return std::string_view(Chars + ValueSizeByteCount, int32_t(ValueSize));
+}
+
+inline std::u8string_view
+CbValue::AsU8String(CbFieldError* OutError, std::u8string_view Default) const
+{
+ const char8_t* const Chars = static_cast<const char8_t*>(Data);
+ uint32_t ValueSizeByteCount;
+ const uint64_t ValueSize = ReadVarUInt(Chars, ValueSizeByteCount);
+
+ if (OutError)
+ {
+ if (ValueSize >= (uint64_t(1) << 31))
+ {
+ *OutError = CbFieldError::RangeError;
+ return Default;
+ }
+ *OutError = CbFieldError::None;
+ }
+
+ return std::u8string_view(Chars + ValueSizeByteCount, int32_t(ValueSize));
+}
+
+inline uint64_t
+CbValue::AsInteger(CompactBinaryPrivate::IntegerParams Params, CbFieldError* OutError, uint64_t Default) const
+{
+ // A shift of a 64-bit value by 64 is undefined so shift by one less because magnitude is never zero.
+ const uint64_t OutOfRangeMask = uint64_t(-2) << (Params.MagnitudeBits - 1);
+ const uint64_t IsNegative = uint8_t(Type) & 1;
+
+ uint32_t MagnitudeByteCount;
+ const uint64_t Magnitude = ReadVarUInt(Data, MagnitudeByteCount);
+ const uint64_t Value = Magnitude ^ -int64_t(IsNegative);
+
+ if (OutError)
+ {
+ const uint64_t IsInRange = (!(Magnitude & OutOfRangeMask)) & ((!IsNegative) | Params.IsSigned);
+ *OutError = IsInRange ? CbFieldError::None : CbFieldError::RangeError;
+
+ const uint64_t UseValueMask = -int64_t(IsInRange);
+ return (Value & UseValueMask) | (Default & ~UseValueMask);
+ }
+
+ return Value;
+}
+
+inline uint64_t
+CbValue::AsIntegerPositive() const
+{
+ uint32_t MagnitudeByteCount;
+ return ReadVarUInt(Data, MagnitudeByteCount);
+}
+
+inline int64_t
+CbValue::AsIntegerNegative() const
+{
+ uint32_t MagnitudeByteCount;
+ return int64_t(ReadVarUInt(Data, MagnitudeByteCount)) ^ -int64_t(1);
+}
+
+inline float
+CbValue::AsFloat32() const
+{
+ const uint32_t Value = FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<uint32_t>(Data));
+ return reinterpret_cast<const float&>(Value);
+}
+
+inline double
+CbValue::AsFloat64() const
+{
+ const uint64_t Value = FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<uint64_t>(Data));
+ return reinterpret_cast<const double&>(Value);
+}
+
+inline bool
+CbValue::AsBool() const
+{
+ return uint8_t(Type) & 1;
+}
+
+inline IoHash
+CbValue::AsHash() const
+{
+ return IoHash::MakeFrom(Data);
+}
+
+inline Guid
+CbValue::AsUuid() const
+{
+ Guid Value;
+ memcpy(&Value, Data, sizeof(Guid));
+ Value.A = FromNetworkOrder(Value.A);
+ Value.B = FromNetworkOrder(Value.B);
+ Value.C = FromNetworkOrder(Value.C);
+ Value.D = FromNetworkOrder(Value.D);
+ return Value;
+}
+
+inline int64_t
+CbValue::AsDateTimeTicks() const
+{
+ return FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<int64_t>(Data));
+}
+
+inline int64_t
+CbValue::AsTimeSpanTicks() const
+{
+ return FromNetworkOrder(CompactBinaryPrivate::ReadUnaligned<int64_t>(Data));
+}
+
+inline Oid
+CbValue::AsObjectId() const
+{
+ return Oid::FromMemory(Data);
+}
+
+inline CbCustomById
+CbValue::AsCustomById() const
+{
+ const uint8_t* Bytes = static_cast<const uint8_t*>(Data);
+ uint32_t DataSizeByteCount;
+ const uint64_t DataSize = ReadVarUInt(Bytes, DataSizeByteCount);
+ Bytes += DataSizeByteCount;
+
+ CbCustomById Value;
+ uint32_t TypeIdByteCount;
+ Value.Id = ReadVarUInt(Bytes, TypeIdByteCount);
+ Value.Data = MakeMemoryView(Bytes + TypeIdByteCount, DataSize - TypeIdByteCount);
+ return Value;
+}
+
+inline CbCustomByName
+CbValue::AsCustomByName() const
+{
+ const uint8_t* Bytes = static_cast<const uint8_t*>(Data);
+ uint32_t DataSizeByteCount;
+ const uint64_t DataSize = ReadVarUInt(Bytes, DataSizeByteCount);
+ Bytes += DataSizeByteCount;
+
+ uint32_t TypeNameLenByteCount;
+ const uint64_t TypeNameLen = ReadVarUInt(Bytes, TypeNameLenByteCount);
+ Bytes += TypeNameLenByteCount;
+
+ CbCustomByName Value;
+ Value.Name = std::u8string_view(reinterpret_cast<const char8_t*>(Bytes), static_cast<std::u8string_view::size_type>(TypeNameLen));
+ Value.Data = MakeMemoryView(Bytes + TypeNameLen, DataSize - TypeNameLen - TypeNameLenByteCount);
+ return Value;
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/compositebuffer.h b/src/zencore/include/zencore/compositebuffer.h
new file mode 100644
index 000000000..4e4b4d002
--- /dev/null
+++ b/src/zencore/include/zencore/compositebuffer.h
@@ -0,0 +1,142 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/sharedbuffer.h>
+#include <zencore/zencore.h>
+
+#include <functional>
+#include <span>
+#include <vector>
+
+namespace zen {
+
+/**
+ * CompositeBuffer is a non-contiguous buffer composed of zero or more immutable shared buffers.
+ *
+ * A composite buffer is most efficient when its segments are consumed as they are, but it can be
+ * flattened into a contiguous buffer, when necessary, by calling Flatten(). Ownership of segment
+ * buffers is not changed on construction, but if ownership of segments is required then that can
+ * be guaranteed by calling MakeOwned().
+ */
+
+class CompositeBuffer
+{
+public:
+ /**
+ * Construct a composite buffer by concatenating the buffers. Does not enforce ownership.
+ *
+ * Buffer parameters may be SharedBuffer, CompositeBuffer, or std::vector<SharedBuffer>.
+ */
+ template<typename... BufferTypes>
+ inline explicit CompositeBuffer(BufferTypes&&... Buffers)
+ {
+ if constexpr (sizeof...(Buffers) > 0)
+ {
+ m_Segments.reserve((GetBufferCount(std::forward<BufferTypes>(Buffers)) + ...));
+ (AppendBuffers(std::forward<BufferTypes>(Buffers)), ...);
+ std::erase_if(m_Segments, [](const SharedBuffer& It) { return It.IsNull(); });
+ }
+ }
+
+ /** Reset this to null. */
+ ZENCORE_API void Reset();
+
+ /** Returns the total size of the composite buffer in bytes. */
+ [[nodiscard]] ZENCORE_API uint64_t GetSize() const;
+
+ /** Returns the segments that the buffer is composed from. */
+ [[nodiscard]] inline std::span<const SharedBuffer> GetSegments() const { return std::span<const SharedBuffer>{m_Segments}; }
+
+ /** Returns true if the composite buffer is not null. */
+ [[nodiscard]] inline explicit operator bool() const { return !IsNull(); }
+
+ /** Returns true if the composite buffer is null. */
+ [[nodiscard]] inline bool IsNull() const { return m_Segments.empty(); }
+
+ /** Returns true if every segment in the composite buffer is owned. */
+ [[nodiscard]] ZENCORE_API bool IsOwned() const;
+
+ /** Returns a copy of the buffer where every segment is owned. */
+ [[nodiscard]] ZENCORE_API CompositeBuffer MakeOwned() const&;
+ [[nodiscard]] ZENCORE_API CompositeBuffer MakeOwned() &&;
+
+ /** Returns the concatenation of the segments into a contiguous buffer. */
+ [[nodiscard]] ZENCORE_API SharedBuffer Flatten() const&;
+ [[nodiscard]] ZENCORE_API SharedBuffer Flatten() &&;
+
+ /** Returns the middle part of the buffer by taking the size starting at the offset. */
+ [[nodiscard]] ZENCORE_API CompositeBuffer Mid(uint64_t Offset, uint64_t Size = ~uint64_t(0)) const;
+
+ /**
+ * Returns a view of the range if contained by one segment, otherwise a view of a copy of the range.
+ *
+ * @note CopyBuffer is reused if large enough, and otherwise allocated when needed.
+ *
+ * @param Offset The byte offset in this buffer that the range starts at.
+ * @param Size The number of bytes in the range to view or copy.
+ * @param CopyBuffer The buffer to write the copy into if a copy is required.
+ */
+ [[nodiscard]] ZENCORE_API MemoryView ViewOrCopyRange(uint64_t Offset, uint64_t Size, UniqueBuffer& CopyBuffer) const;
+
+ /**
+ * Copies a range of the buffer to a contiguous region of memory.
+ *
+ * @param Target The view to copy to. Must be no larger than the data available at the offset.
+ * @param Offset The byte offset in this buffer to start copying from.
+ */
+ ZENCORE_API void CopyTo(MutableMemoryView Target, uint64_t Offset = 0) const;
+
+ /**
+ * Invokes a visitor with a view of each segment that intersects with a range.
+ *
+ * @param Offset The byte offset in this buffer to start visiting from.
+ * @param Size The number of bytes in the range to visit.
+ * @param Visitor The visitor to invoke from zero to GetSegments().Num() times.
+ */
+ ZENCORE_API void IterateRange(uint64_t Offset, uint64_t Size, std::function<void(MemoryView View)> Visitor) const;
+ ZENCORE_API void IterateRange(uint64_t Offset,
+ uint64_t Size,
+ std::function<void(MemoryView View, const SharedBuffer& ViewOuter)> Visitor) const;
+
+ struct Iterator
+ {
+ size_t SegmentIndex = 0;
+ uint64_t OffsetInSegment = 0;
+ };
+ ZENCORE_API Iterator GetIterator(uint64_t Offset) const;
+ ZENCORE_API MemoryView ViewOrCopyRange(Iterator& It, uint64_t Size, UniqueBuffer& CopyBuffer) const;
+ ZENCORE_API void CopyTo(MutableMemoryView Target, Iterator& It) const;
+
+ /** A null composite buffer. */
+ static const CompositeBuffer Null;
+
+private:
+ static inline size_t GetBufferCount(const CompositeBuffer& Buffer) { return Buffer.m_Segments.size(); }
+ inline void AppendBuffers(const CompositeBuffer& Buffer)
+ {
+ m_Segments.insert(m_Segments.end(), begin(Buffer.m_Segments), end(Buffer.m_Segments));
+ }
+ inline void AppendBuffers(CompositeBuffer&& Buffer)
+ {
+ // TODO: this operates just like the by-reference version above
+ m_Segments.insert(m_Segments.end(), begin(Buffer.m_Segments), end(Buffer.m_Segments));
+ }
+
+ static inline size_t GetBufferCount(const SharedBuffer&) { return 1; }
+ inline void AppendBuffers(const SharedBuffer& Buffer) { m_Segments.push_back(Buffer); }
+ inline void AppendBuffers(SharedBuffer&& Buffer) { m_Segments.push_back(std::move(Buffer)); }
+
+ static inline size_t GetBufferCount(std::vector<SharedBuffer>&& Container) { return Container.size(); }
+ inline void AppendBuffers(std::vector<SharedBuffer>&& Container)
+ {
+ m_Segments.insert(m_Segments.end(), begin(Container), end(Container));
+ }
+
+private:
+ std::vector<SharedBuffer> m_Segments;
+};
+
+void compositebuffer_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/compress.h b/src/zencore/include/zencore/compress.h
new file mode 100644
index 000000000..99ce20d8a
--- /dev/null
+++ b/src/zencore/include/zencore/compress.h
@@ -0,0 +1,165 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore/zencore.h"
+
+#include "zencore/blake3.h"
+#include "zencore/compositebuffer.h"
+
+namespace zen {
+
+enum class OodleCompressor : uint8_t
+{
+ NotSet = 0,
+ Selkie = 1,
+ Mermaid = 2,
+ Kraken = 3,
+ Leviathan = 4,
+};
+
+enum class OodleCompressionLevel : int8_t
+{
+ HyperFast4 = -4,
+ HyperFast3 = -3,
+ HyperFast2 = -2,
+ HyperFast1 = -1,
+ None = 0,
+ SuperFast = 1,
+ VeryFast = 2,
+ Fast = 3,
+ Normal = 4,
+ Optimal1 = 5,
+ Optimal2 = 6,
+ Optimal3 = 7,
+ Optimal4 = 8,
+};
+
+/**
+ * A compressed buffer stores compressed data in a self-contained format.
+ *
+ * A buffer is self-contained in the sense that it can be decompressed without external knowledge
+ * of the compression format or the size of the raw data.
+ */
+class CompressedBuffer
+{
+public:
+ /**
+ * Compress the buffer using the specified compressor and compression level.
+ *
+ * Data that does not compress will be return uncompressed, as if with level None.
+ *
+ * @note Using a level of None will return a buffer that references owned raw data.
+ *
+ * @param RawData The raw data to be compressed.
+ * @param Compressor The compressor to encode with. May use NotSet if level is None.
+ * @param CompressionLevel The compression level to encode with.
+ * @param BlockSize The power-of-two block size to encode raw data in. 0 is default.
+ * @return An owned compressed buffer, or null on error.
+ */
+ [[nodiscard]] ZENCORE_API static CompressedBuffer Compress(const CompositeBuffer& RawData,
+ OodleCompressor Compressor = OodleCompressor::Mermaid,
+ OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast,
+ uint64_t BlockSize = 0);
+ [[nodiscard]] ZENCORE_API static CompressedBuffer Compress(const SharedBuffer& RawData,
+ OodleCompressor Compressor = OodleCompressor::Mermaid,
+ OodleCompressionLevel CompressionLevel = OodleCompressionLevel::VeryFast,
+ uint64_t BlockSize = 0);
+
+ /**
+ * Construct from a compressed buffer previously created by Compress().
+ *
+ * @return A compressed buffer, or null on error, such as an invalid format or corrupt header.
+ */
+ [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(const CompositeBuffer& CompressedData,
+ IoHash& OutRawHash,
+ uint64_t& OutRawSize);
+ [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(CompositeBuffer&& CompressedData,
+ IoHash& OutRawHash,
+ uint64_t& OutRawSize);
+ [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(const SharedBuffer& CompressedData,
+ IoHash& OutRawHash,
+ uint64_t& OutRawSize);
+ [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressed(SharedBuffer&& CompressedData,
+ IoHash& OutRawHash,
+ uint64_t& OutRawSize);
+ [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressedNoValidate(IoBuffer&& CompressedData);
+ [[nodiscard]] ZENCORE_API static CompressedBuffer FromCompressedNoValidate(CompositeBuffer&& CompressedData);
+ [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(IoBuffer&& CompressedData, IoHash& OutRawHash, uint64_t& OutRawSize);
+ [[nodiscard]] ZENCORE_API static bool ValidateCompressedHeader(const IoBuffer& CompressedData,
+ IoHash& OutRawHash,
+ uint64_t& OutRawSize);
+
+ /** Reset this to null. */
+ inline void Reset() { CompressedData.Reset(); }
+
+ /** Returns true if the compressed buffer is not null. */
+ [[nodiscard]] inline explicit operator bool() const { return !IsNull(); }
+
+ /** Returns true if the compressed buffer is null. */
+ [[nodiscard]] inline bool IsNull() const { return CompressedData.IsNull(); }
+
+ /** Returns true if the composite buffer is owned. */
+ [[nodiscard]] inline bool IsOwned() const { return CompressedData.IsOwned(); }
+
+ /** Returns a copy of the compressed buffer that owns its underlying memory. */
+ [[nodiscard]] inline CompressedBuffer MakeOwned() const& { return FromCompressedNoValidate(CompressedData.MakeOwned()); }
+ [[nodiscard]] inline CompressedBuffer MakeOwned() && { return FromCompressedNoValidate(std::move(CompressedData).MakeOwned()); }
+
+ /** Returns a composite buffer containing the compressed data. May be null. May not be owned. */
+ [[nodiscard]] inline const CompositeBuffer& GetCompressed() const& { return CompressedData; }
+ [[nodiscard]] inline CompositeBuffer GetCompressed() && { return std::move(CompressedData); }
+
+ /** Returns the size of the compressed data. Zero if this is null. */
+ [[nodiscard]] inline uint64_t GetCompressedSize() const { return CompressedData.GetSize(); }
+
+ /** Returns the size of the raw data. Zero on error or if this is empty or null. */
+ [[nodiscard]] ZENCORE_API uint64_t DecodeRawSize() const;
+
+ /** Returns the hash of the raw data. Zero on error or if this is null. */
+ [[nodiscard]] ZENCORE_API IoHash DecodeRawHash() const;
+
+ [[nodiscard]] ZENCORE_API CompressedBuffer CopyRange(uint64_t RawOffset, uint64_t RawSize = ~uint64_t(0)) const;
+
+ /**
+ * Returns the compressor and compression level used by this buffer.
+ *
+ * The compressor and compression level may differ from those specified when creating the buffer
+ * because an incompressible buffer is stored with no compression. Parameters cannot be accessed
+ * if this is null or uses a method other than Oodle, in which case this returns false.
+ *
+ * @return True if parameters were written, otherwise false.
+ */
+ [[nodiscard]] ZENCORE_API bool TryGetCompressParameters(OodleCompressor& OutCompressor,
+ OodleCompressionLevel& OutCompressionLevel,
+ uint64_t& OutBlockSize) const;
+
+ /**
+ * Decompress into a memory view that is less or equal GetRawSize() bytes.
+ */
+ [[nodiscard]] ZENCORE_API bool TryDecompressTo(MutableMemoryView RawView, uint64_t RawOffset = 0) const;
+
+ /**
+ * Decompress into an owned buffer.
+ *
+ * @return An owned buffer containing the raw data, or null on error.
+ */
+ [[nodiscard]] ZENCORE_API SharedBuffer Decompress(uint64_t RawOffset = 0, uint64_t RawSize = ~uint64_t(0)) const;
+
+ /**
+ * Decompress into an owned composite buffer.
+ *
+ * @return An owned buffer containing the raw data, or null on error.
+ */
+ [[nodiscard]] ZENCORE_API CompositeBuffer DecompressToComposite() const;
+
+ /** A null compressed buffer. */
+ static const CompressedBuffer Null;
+
+private:
+ CompositeBuffer CompressedData;
+};
+
+void compress_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/config.h.in b/src/zencore/include/zencore/config.h.in
new file mode 100644
index 000000000..3372eca2a
--- /dev/null
+++ b/src/zencore/include/zencore/config.h.in
@@ -0,0 +1,16 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+// NOTE: Generated from config.h.in
+
+#define ZEN_CFG_VERSION "${VERSION}"
+#define ZEN_CFG_VERSION_MAJOR ${VERSION_MAJOR}
+#define ZEN_CFG_VERSION_MINOR ${VERSION_MINOR}
+#define ZEN_CFG_VERSION_ALTER ${VERSION_ALTER}
+#define ZEN_CFG_VERSION_BUILD ${VERSION_BUILD}
+#define ZEN_CFG_VERSION_BRANCH "${GIT_BRANCH}"
+#define ZEN_CFG_VERSION_COMMIT "${GIT_COMMIT}"
+#define ZEN_CFG_VERSION_BUILD_STRING "${VERSION}-${plat}-${arch}-${mode}"
+#define ZEN_CFG_VERSION_BUILD_STRING_FULL "${VERSION}-${VERSION_BUILD}-${plat}-${arch}-${mode}-${GIT_COMMIT}"
+#define ZEN_CFG_SCHEMA_VERSION ${ZEN_SCHEMA_VERSION}
diff --git a/src/zencore/include/zencore/crc32.h b/src/zencore/include/zencore/crc32.h
new file mode 100644
index 000000000..336bda77e
--- /dev/null
+++ b/src/zencore/include/zencore/crc32.h
@@ -0,0 +1,13 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+namespace zen {
+
+uint32_t MemCrc32(const void* InData, size_t Length, uint32_t Crc = 0);
+uint32_t MemCrc32_Deprecated(const void* InData, size_t Length, uint32_t Crc = 0);
+uint32_t StrCrc_Deprecated(const char* Data);
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/crypto.h b/src/zencore/include/zencore/crypto.h
new file mode 100644
index 000000000..83d416b0f
--- /dev/null
+++ b/src/zencore/include/zencore/crypto.h
@@ -0,0 +1,77 @@
+
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/memory.h>
+#include <zencore/zencore.h>
+
+#include <memory>
+#include <optional>
+
+namespace zen {
+
+template<size_t BitCount>
+struct CryptoBits
+{
+public:
+ static constexpr size_t ByteCount = BitCount / 8;
+
+ CryptoBits() = default;
+
+ bool IsNull() const { return memcmp(&m_Bits, &Zero, ByteCount) == 0; }
+ bool IsValid() const { return IsNull() == false; }
+
+ size_t GetSize() const { return ByteCount; }
+ size_t GetBitCount() const { return BitCount; }
+
+ MemoryView GetView() const { return MemoryView(m_Bits, ByteCount); }
+
+ static CryptoBits FromMemoryView(MemoryView Bits)
+ {
+ if (Bits.GetSize() != ByteCount)
+ {
+ return CryptoBits();
+ }
+
+ return CryptoBits(Bits);
+ }
+
+ static CryptoBits FromString(std::string_view Str) { return FromMemoryView(MakeMemoryView(Str)); }
+
+private:
+ CryptoBits(MemoryView Bits)
+ {
+ ZEN_ASSERT(Bits.GetSize() == GetSize());
+ memcpy(&m_Bits, Bits.GetData(), GetSize());
+ }
+
+ static constexpr uint8_t Zero[ByteCount] = {0};
+
+ uint8_t m_Bits[ByteCount] = {0};
+};
+
+using AesKey256Bit = CryptoBits<256>;
+using AesIV128Bit = CryptoBits<128>;
+
+class Aes
+{
+public:
+ static constexpr size_t BlockSize = 16;
+
+ static MemoryView Encrypt(const AesKey256Bit& Key,
+ const AesIV128Bit& IV,
+ MemoryView In,
+ MutableMemoryView Out,
+ std::optional<std::string>& Reason);
+
+ static MemoryView Decrypt(const AesKey256Bit& Key,
+ const AesIV128Bit& IV,
+ MemoryView In,
+ MutableMemoryView Out,
+ std::optional<std::string>& Reason);
+};
+
+void crypto_forcelink();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/endian.h b/src/zencore/include/zencore/endian.h
new file mode 100644
index 000000000..7a9e6b44c
--- /dev/null
+++ b/src/zencore/include/zencore/endian.h
@@ -0,0 +1,113 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <cstdint>
+
+namespace zen {
+
+inline uint16_t
+ByteSwap(uint16_t x)
+{
+#if ZEN_COMPILER_MSC
+ return _byteswap_ushort(x);
+#else
+ return __builtin_bswap16(x);
+#endif
+}
+
+inline uint32_t
+ByteSwap(uint32_t x)
+{
+#if ZEN_COMPILER_MSC
+ return _byteswap_ulong(x);
+#else
+ return __builtin_bswap32(x);
+#endif
+}
+
+inline uint64_t
+ByteSwap(uint64_t x)
+{
+#if ZEN_COMPILER_MSC
+ return _byteswap_uint64(x);
+#else
+ return __builtin_bswap64(x);
+#endif
+}
+
+inline uint16_t
+FromNetworkOrder(uint16_t x)
+{
+ return ByteSwap(x);
+}
+
+inline uint32_t
+FromNetworkOrder(uint32_t x)
+{
+ return ByteSwap(x);
+}
+
+inline uint64_t
+FromNetworkOrder(uint64_t x)
+{
+ return ByteSwap(x);
+}
+
+inline uint16_t
+FromNetworkOrder(int16_t x)
+{
+ return ByteSwap(uint16_t(x));
+}
+
+inline uint32_t
+FromNetworkOrder(int32_t x)
+{
+ return ByteSwap(uint32_t(x));
+}
+
+inline uint64_t
+FromNetworkOrder(int64_t x)
+{
+ return ByteSwap(uint64_t(x));
+}
+
+inline uint16_t
+ToNetworkOrder(uint16_t x)
+{
+ return ByteSwap(x);
+}
+
+inline uint32_t
+ToNetworkOrder(uint32_t x)
+{
+ return ByteSwap(x);
+}
+
+inline uint64_t
+ToNetworkOrder(uint64_t x)
+{
+ return ByteSwap(x);
+}
+
+inline uint16_t
+ToNetworkOrder(int16_t x)
+{
+ return ByteSwap(uint16_t(x));
+}
+
+inline uint32_t
+ToNetworkOrder(int32_t x)
+{
+ return ByteSwap(uint32_t(x));
+}
+
+inline uint64_t
+ToNetworkOrder(int64_t x)
+{
+ return ByteSwap(uint64_t(x));
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/enumflags.h b/src/zencore/include/zencore/enumflags.h
new file mode 100644
index 000000000..ebe747bf0
--- /dev/null
+++ b/src/zencore/include/zencore/enumflags.h
@@ -0,0 +1,61 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+namespace zen {
+
+// Enum class helpers
+
+// Defines all bitwise operators for enum classes so it can be (mostly) used as a regular flags enum
+#define ENUM_CLASS_FLAGS(Enum) \
+ inline Enum& operator|=(Enum& Lhs, Enum Rhs) { return Lhs = (Enum)((__underlying_type(Enum))Lhs | (__underlying_type(Enum))Rhs); } \
+ inline Enum& operator&=(Enum& Lhs, Enum Rhs) { return Lhs = (Enum)((__underlying_type(Enum))Lhs & (__underlying_type(Enum))Rhs); } \
+ inline Enum& operator^=(Enum& Lhs, Enum Rhs) { return Lhs = (Enum)((__underlying_type(Enum))Lhs ^ (__underlying_type(Enum))Rhs); } \
+ inline constexpr Enum operator|(Enum Lhs, Enum Rhs) { return (Enum)((__underlying_type(Enum))Lhs | (__underlying_type(Enum))Rhs); } \
+ inline constexpr Enum operator&(Enum Lhs, Enum Rhs) { return (Enum)((__underlying_type(Enum))Lhs & (__underlying_type(Enum))Rhs); } \
+ inline constexpr Enum operator^(Enum Lhs, Enum Rhs) { return (Enum)((__underlying_type(Enum))Lhs ^ (__underlying_type(Enum))Rhs); } \
+ inline constexpr bool operator!(Enum E) { return !(__underlying_type(Enum))E; } \
+ inline constexpr Enum operator~(Enum E) { return (Enum) ~(__underlying_type(Enum))E; }
+
+// Friends all bitwise operators for enum classes so the definition can be kept private / protected.
+#define FRIEND_ENUM_CLASS_FLAGS(Enum) \
+ friend Enum& operator|=(Enum& Lhs, Enum Rhs); \
+ friend Enum& operator&=(Enum& Lhs, Enum Rhs); \
+ friend Enum& operator^=(Enum& Lhs, Enum Rhs); \
+ friend constexpr Enum operator|(Enum Lhs, Enum Rhs); \
+ friend constexpr Enum operator&(Enum Lhs, Enum Rhs); \
+ friend constexpr Enum operator^(Enum Lhs, Enum Rhs); \
+ friend constexpr bool operator!(Enum E); \
+ friend constexpr Enum operator~(Enum E);
+
+template<typename Enum>
+constexpr bool
+EnumHasAllFlags(Enum Flags, Enum Contains)
+{
+ return (((__underlying_type(Enum))Flags) & (__underlying_type(Enum))Contains) == ((__underlying_type(Enum))Contains);
+}
+
+template<typename Enum>
+constexpr bool
+EnumHasAnyFlags(Enum Flags, Enum Contains)
+{
+ return (((__underlying_type(Enum))Flags) & (__underlying_type(Enum))Contains) != 0;
+}
+
+template<typename Enum>
+void
+EnumAddFlags(Enum& Flags, Enum FlagsToAdd)
+{
+ Flags |= FlagsToAdd;
+}
+
+template<typename Enum>
+void
+EnumRemoveFlags(Enum& Flags, Enum FlagsToRemove)
+{
+ Flags &= ~FlagsToRemove;
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/except.h b/src/zencore/include/zencore/except.h
new file mode 100644
index 000000000..c61db5ba9
--- /dev/null
+++ b/src/zencore/include/zencore/except.h
@@ -0,0 +1,57 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/string.h>
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#else
+# include <errno.h>
+#endif
+#if __has_include("source_location")
+# include <source_location>
+#endif
+#include <string>
+#include <system_error>
+
+namespace zen {
+
+#if ZEN_PLATFORM_WINDOWS
+ZENCORE_API void ThrowSystemException [[noreturn]] (HRESULT hRes, std::string_view Message);
+#endif // ZEN_PLATFORM_WINDOWS
+
+#if defined(__cpp_lib_source_location)
+ZENCORE_API void ThrowLastErrorImpl [[noreturn]] (std::string_view Message, const std::source_location& Location);
+# define ThrowLastError(Message) ThrowLastErrorImpl(Message, std::source_location::current())
+#else
+ZENCORE_API void ThrowLastError [[noreturn]] (std::string_view Message);
+#endif
+
+ZENCORE_API void ThrowSystemError [[noreturn]] (uint32_t ErrorCode, std::string_view Message);
+
+ZENCORE_API std::string GetLastErrorAsString();
+ZENCORE_API std::string GetSystemErrorAsString(uint32_t Win32ErrorCode);
+
+inline int32_t
+GetLastError()
+{
+#if ZEN_PLATFORM_WINDOWS
+ return ::GetLastError();
+#else
+ return errno;
+#endif
+}
+
+inline std::error_code
+MakeErrorCode(uint32_t ErrorCode) noexcept
+{
+ return std::error_code(ErrorCode, std::system_category());
+}
+
+inline std::error_code
+MakeErrorCodeFromLastError() noexcept
+{
+ return std::error_code(zen::GetLastError(), std::system_category());
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/filesystem.h b/src/zencore/include/zencore/filesystem.h
new file mode 100644
index 000000000..fa5f94170
--- /dev/null
+++ b/src/zencore/include/zencore/filesystem.h
@@ -0,0 +1,190 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+
+#include <filesystem>
+#include <functional>
+
+namespace zen {
+
+class IoBuffer;
+
+/** Delete directory (after deleting any contents)
+ */
+ZENCORE_API bool DeleteDirectories(const std::filesystem::path& dir);
+
+/** Ensure directory exists.
+
+ Will also create any required parent directories
+ */
+ZENCORE_API bool CreateDirectories(const std::filesystem::path& dir);
+
+/** Ensure directory exists and delete contents (if any) before returning
+ */
+ZENCORE_API bool CleanDirectory(const std::filesystem::path& dir);
+
+/** Map native file handle to a path
+ */
+ZENCORE_API std::filesystem::path PathFromHandle(void* NativeHandle);
+
+ZENCORE_API std::filesystem::path GetRunningExecutablePath();
+
+/** Set the max open file handle count to max allowed for the current process on Linux and MacOS
+ */
+ZENCORE_API void MaximizeOpenFileCount();
+
+struct FileContents
+{
+ std::vector<IoBuffer> Data;
+ std::error_code ErrorCode;
+
+ IoBuffer Flatten();
+};
+
+ZENCORE_API FileContents ReadStdIn();
+ZENCORE_API FileContents ReadFile(std::filesystem::path Path);
+ZENCORE_API bool ScanFile(std::filesystem::path Path, uint64_t ChunkSize, std::function<void(const void* Data, size_t Size)>&& ProcessFunc);
+ZENCORE_API void WriteFile(std::filesystem::path Path, const IoBuffer* const* Data, size_t BufferCount);
+ZENCORE_API void WriteFile(std::filesystem::path Path, IoBuffer Data);
+
+struct CopyFileOptions
+{
+ bool EnableClone = true;
+ bool MustClone = false;
+};
+
+ZENCORE_API bool CopyFile(std::filesystem::path FromPath, std::filesystem::path ToPath, const CopyFileOptions& Options);
+ZENCORE_API bool SupportsBlockRefCounting(std::filesystem::path Path);
+
+ZENCORE_API void PathToUtf8(const std::filesystem::path& Path, StringBuilderBase& Out);
+ZENCORE_API std::string PathToUtf8(const std::filesystem::path& Path);
+
+extern template class StringBuilderImpl<std::filesystem::path::value_type>;
+
+/**
+ * Helper class for building paths. Backed by a string builder.
+ *
+ */
+class PathBuilderBase : public StringBuilderImpl<std::filesystem::path::value_type>
+{
+private:
+ using Super = StringBuilderImpl<std::filesystem::path::value_type>;
+
+protected:
+ using CharType = std::filesystem::path::value_type;
+ using ViewType = std::basic_string_view<CharType>;
+
+public:
+ void Append(const std::filesystem::path& Rhs) { Super::Append(Rhs.c_str()); }
+ void operator/=(const std::filesystem::path& Rhs) { this->operator/=(Rhs.c_str()); };
+ void operator/=(const CharType* Rhs)
+ {
+ AppendSeparator();
+ Super::Append(Rhs);
+ }
+ operator ViewType() const { return ToView(); }
+ std::basic_string_view<CharType> ToView() const { return std::basic_string_view<CharType>(Data(), Size()); }
+ std::filesystem::path ToPath() const { return std::filesystem::path(ToView()); }
+
+ std::string ToUtf8() const
+ {
+#if ZEN_PLATFORM_WINDOWS
+ return WideToUtf8(ToView());
+#else
+ return std::string(ToView());
+#endif
+ }
+
+ void AppendSeparator()
+ {
+ if (ToView().ends_with(std::filesystem::path::preferred_separator)
+#if ZEN_PLATFORM_WINDOWS
+ || ToView().ends_with('/')
+#endif
+ )
+ return;
+
+ Super::Append(std::filesystem::path::preferred_separator);
+ }
+};
+
+template<size_t N>
+class PathBuilder : public PathBuilderBase
+{
+public:
+ PathBuilder() { Init(m_Buffer, N); }
+
+private:
+ PathBuilderBase::CharType m_Buffer[N];
+};
+
+template<size_t N>
+class ExtendablePathBuilder : public PathBuilder<N>
+{
+public:
+ ExtendablePathBuilder() { this->m_IsExtendable = true; }
+};
+
+struct DiskSpace
+{
+ uint64_t Free{};
+ uint64_t Total{};
+};
+
+ZENCORE_API DiskSpace DiskSpaceInfo(std::filesystem::path Directory, std::error_code& Error);
+
+inline bool
+DiskSpaceInfo(std::filesystem::path Directory, DiskSpace& Space)
+{
+ std::error_code Err;
+ Space = DiskSpaceInfo(Directory, Err);
+ return !Err;
+}
+
+/**
+ * Efficient file system traversal
+ *
+ * Uses the best available mechanism for the platform in question and could take
+ * advantage of any file system tracking mechanisms in the future
+ *
+ */
+class FileSystemTraversal
+{
+public:
+ struct TreeVisitor
+ {
+ using path_view = std::basic_string_view<std::filesystem::path::value_type>;
+ using path_string = std::filesystem::path::string_type;
+
+ virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t FileSize) = 0;
+
+ // This should return true if we should recurse into the directory
+ virtual bool VisitDirectory(const std::filesystem::path& Parent, const path_view& DirectoryName) = 0;
+ };
+
+ void TraverseFileSystem(const std::filesystem::path& RootDir, TreeVisitor& Visitor);
+};
+
+struct DirectoryContent
+{
+ static const uint8_t IncludeDirsFlag = 1u << 0;
+ static const uint8_t IncludeFilesFlag = 1u << 1;
+ static const uint8_t RecursiveFlag = 1u << 2;
+ std::vector<std::filesystem::path> Files;
+ std::vector<std::filesystem::path> Directories;
+};
+
+void GetDirectoryContent(const std::filesystem::path& RootDir, uint8_t Flags, DirectoryContent& OutContent);
+
+std::string GetEnvVariable(std::string_view VariableName);
+
+//////////////////////////////////////////////////////////////////////////
+
+void filesystem_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/fmtutils.h b/src/zencore/include/zencore/fmtutils.h
new file mode 100644
index 000000000..70867fe72
--- /dev/null
+++ b/src/zencore/include/zencore/fmtutils.h
@@ -0,0 +1,52 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <filesystem>
+#include <string_view>
+
+// Custom formatting for some zencore types
+
+template<>
+struct fmt::formatter<zen::IoHash> : formatter<string_view>
+{
+ template<typename FormatContext>
+ auto format(const zen::IoHash& Hash, FormatContext& ctx)
+ {
+ zen::IoHash::String_t String;
+ Hash.ToHexString(String);
+ return formatter<string_view>::format({String, zen::IoHash::StringLength}, ctx);
+ }
+};
+
+template<>
+struct fmt::formatter<zen::Oid> : formatter<string_view>
+{
+ template<typename FormatContext>
+ auto format(const zen::Oid& Id, FormatContext& ctx)
+ {
+ zen::StringBuilder<32> String;
+ Id.ToString(String);
+ return formatter<string_view>::format({String.c_str(), zen::Oid::StringLength}, ctx);
+ }
+};
+
+template<>
+struct fmt::formatter<std::filesystem::path> : formatter<string_view>
+{
+ template<typename FormatContext>
+ auto format(const std::filesystem::path& Path, FormatContext& ctx)
+ {
+ zen::ExtendableStringBuilder<128> String;
+ String << Path.u8string();
+ return formatter<string_view>::format(String.ToView(), ctx);
+ }
+};
diff --git a/src/zencore/include/zencore/intmath.h b/src/zencore/include/zencore/intmath.h
new file mode 100644
index 000000000..f24caed6e
--- /dev/null
+++ b/src/zencore/include/zencore/intmath.h
@@ -0,0 +1,183 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <stdint.h>
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_COMPILER_MSC || ZEN_PLATFORM_WINDOWS
+# pragma intrinsic(_BitScanReverse)
+# pragma intrinsic(_BitScanReverse64)
+#else
+inline uint8_t
+_BitScanReverse(unsigned long* Index, uint32_t Mask)
+{
+ if (Mask == 0)
+ {
+ return 0;
+ }
+
+ *Index = 31 - __builtin_clz(Mask);
+ return 1;
+}
+
+inline uint8_t
+_BitScanReverse64(unsigned long* Index, uint64_t Mask)
+{
+ if (Mask == 0)
+ {
+ return 0;
+ }
+
+ *Index = 63 - __builtin_clzll(Mask);
+ return 1;
+}
+
+inline uint8_t
+_BitScanForward64(unsigned long* Index, uint64_t Mask)
+{
+ if (Mask == 0)
+ {
+ return 0;
+ }
+
+ *Index = __builtin_ctzll(Mask);
+ return 1;
+}
+#endif
+
+namespace zen {
+
+inline constexpr bool
+IsPow2(uint64_t n)
+{
+ return 0 == (n & (n - 1));
+}
+
+/// Round an integer up to the closest integer multiplier of 'base' ('base' must be a power of two)
+template<Integral T>
+T
+RoundUp(T Value, auto Base)
+{
+ ZEN_ASSERT_SLOW(IsPow2(Base));
+ return ((Value + T(Base - 1)) & (~T(Base - 1)));
+}
+
+bool
+IsMultipleOf(Integral auto Value, auto MultiplierPow2)
+{
+ ZEN_ASSERT_SLOW(IsPow2(MultiplierPow2));
+ return (Value & (MultiplierPow2 - 1)) == 0;
+}
+
+inline uint64_t
+NextPow2(uint64_t n)
+{
+ // http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
+
+ --n;
+
+ n |= n >> 1;
+ n |= n >> 2;
+ n |= n >> 4;
+ n |= n >> 8;
+ n |= n >> 16;
+ n |= n >> 32;
+
+ return n + 1;
+}
+
+static inline uint32_t
+FloorLog2(uint32_t Value)
+{
+ // Use BSR to return the log2 of the integer
+ unsigned long Log2;
+ if (_BitScanReverse(&Log2, Value) != 0)
+ {
+ return Log2;
+ }
+
+ return 0;
+}
+
+static inline uint32_t
+CountLeadingZeros(uint32_t Value)
+{
+ unsigned long Log2 = 0;
+ _BitScanReverse64(&Log2, (uint64_t(Value) << 1) | 1);
+ return 32 - Log2;
+}
+
+static inline uint64_t
+FloorLog2_64(uint64_t Value)
+{
+ unsigned long Log2 = 0;
+ long Mask = -long(_BitScanReverse64(&Log2, Value) != 0);
+ return Log2 & Mask;
+}
+
+static inline uint64_t
+CountLeadingZeros64(uint64_t Value)
+{
+ unsigned long Log2 = 0;
+ long Mask = -long(_BitScanReverse64(&Log2, Value) != 0);
+ return ((63 - Log2) & Mask) | (64 & ~Mask);
+}
+
+static inline uint64_t
+CeilLogTwo64(uint64_t Arg)
+{
+ int64_t Bitmask = ((int64_t)(CountLeadingZeros64(Arg) << 57)) >> 63;
+ return (64 - CountLeadingZeros64(Arg - 1)) & (~Bitmask);
+}
+
+static inline uint64_t
+CountTrailingZeros64(uint64_t Value)
+{
+ if (Value == 0)
+ {
+ return 64;
+ }
+ unsigned long BitIndex; // 0-based, where the LSB is 0 and MSB is 31
+ _BitScanForward64(&BitIndex, Value); // Scans from LSB to MSB
+ return BitIndex;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static inline bool
+IsPointerAligned(const void* Ptr, uint64_t Alignment)
+{
+ ZEN_ASSERT_SLOW(IsPow2(Alignment));
+
+ return 0 == (reinterpret_cast<uintptr_t>(Ptr) & (Alignment - 1));
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_PLATFORM_WINDOWS
+# ifdef min
+# error "Looks like you did #include <windows.h> -- use <zencore/windows.h> instead"
+# endif
+#endif
+
+constexpr auto
+Min(auto x, auto y)
+{
+ return x < y ? x : y;
+}
+
+constexpr auto
+Max(auto x, auto y)
+{
+ return x > y ? x : y;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void intmath_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/iobuffer.h b/src/zencore/include/zencore/iobuffer.h
new file mode 100644
index 000000000..a39dbf6d6
--- /dev/null
+++ b/src/zencore/include/zencore/iobuffer.h
@@ -0,0 +1,423 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <memory.h>
+#include <zencore/memory.h>
+#include <atomic>
+#include "refcount.h"
+#include "zencore.h"
+
+#include <filesystem>
+
+namespace zen {
+
+struct IoHash;
+struct IoBufferExtendedCore;
+
+enum class ZenContentType : uint8_t
+{
+ kBinary = 0, // Note that since this is zero, this will be the default value in IoBuffer
+ kText = 1,
+ kJSON = 2,
+ kCbObject = 3,
+ kCbPackage = 4,
+ kYAML = 5,
+ kCbPackageOffer = 6,
+ kCompressedBinary = 7,
+ kUnknownContentType = 8,
+ kHTML = 9,
+ kJavaScript = 10,
+ kCSS = 11,
+ kPNG = 12,
+ kIcon = 13,
+ kCOUNT
+};
+
+inline std::string_view
+ToString(ZenContentType ContentType)
+{
+ using namespace std::literals;
+
+ switch (ContentType)
+ {
+ default:
+ case ZenContentType::kUnknownContentType:
+ return "unknown"sv;
+ case ZenContentType::kBinary:
+ return "binary"sv;
+ case ZenContentType::kText:
+ return "text"sv;
+ case ZenContentType::kJSON:
+ return "json"sv;
+ case ZenContentType::kCbObject:
+ return "cb-object"sv;
+ case ZenContentType::kCbPackage:
+ return "cb-package"sv;
+ case ZenContentType::kCbPackageOffer:
+ return "cb-package-offer"sv;
+ case ZenContentType::kCompressedBinary:
+ return "compressed-binary"sv;
+ case ZenContentType::kYAML:
+ return "yaml"sv;
+ case ZenContentType::kHTML:
+ return "html"sv;
+ case ZenContentType::kJavaScript:
+ return "javascript"sv;
+ case ZenContentType::kCSS:
+ return "css"sv;
+ case ZenContentType::kPNG:
+ return "png"sv;
+ case ZenContentType::kIcon:
+ return "icon"sv;
+ }
+}
+
+struct IoBufferFileReference
+{
+ void* FileHandle;
+ uint64_t FileChunkOffset;
+ uint64_t FileChunkSize;
+};
+
+struct IoBufferCore
+{
+public:
+ inline IoBufferCore() : m_Flags(kIsNull) {}
+ inline IoBufferCore(const void* DataPtr, size_t SizeBytes) : m_DataPtr(DataPtr), m_DataBytes(SizeBytes) {}
+ inline IoBufferCore(const IoBufferCore* Outer, const void* DataPtr, size_t SizeBytes)
+ : m_DataPtr(DataPtr)
+ , m_DataBytes(SizeBytes)
+ , m_OuterCore(Outer)
+ {
+ }
+
+ ZENCORE_API explicit IoBufferCore(size_t SizeBytes);
+ ZENCORE_API IoBufferCore(size_t SizeBytes, size_t Alignment);
+ ZENCORE_API ~IoBufferCore();
+
+ // Reference counting
+
+ inline uint32_t AddRef() const { return AtomicIncrement(const_cast<IoBufferCore*>(this)->m_RefCount); }
+ inline uint32_t Release() const
+ {
+ const uint32_t NewRefCount = AtomicDecrement(const_cast<IoBufferCore*>(this)->m_RefCount);
+ if (NewRefCount == 0)
+ {
+ DeleteThis();
+ }
+ return NewRefCount;
+ }
+
+ // Copying reference counted objects doesn't make a lot of sense generally, so let's prevent it
+
+ IoBufferCore(const IoBufferCore&) = delete;
+ IoBufferCore(IoBufferCore&&) = delete;
+ IoBufferCore& operator=(const IoBufferCore&) = delete;
+ IoBufferCore& operator=(IoBufferCore&&) = delete;
+
+ //
+
+ ZENCORE_API void Materialize() const;
+ ZENCORE_API void DeleteThis() const;
+ ZENCORE_API void MakeOwned(bool Immutable = true);
+
+ inline void EnsureDataValid() const
+ {
+ const uint32_t LocalFlags = m_Flags.load(std::memory_order_acquire);
+ if ((LocalFlags & kIsExtended) && !(LocalFlags & kIsMaterialized))
+ {
+ Materialize();
+ }
+ }
+
+ inline bool IsOwnedByThis() const { return !!(m_Flags.load(std::memory_order_relaxed) & kIsOwnedByThis); }
+
+ inline void SetIsOwnedByThis(bool NewState)
+ {
+ if (NewState)
+ {
+ m_Flags.fetch_or(kIsOwnedByThis, std::memory_order_relaxed);
+ }
+ else
+ {
+ m_Flags.fetch_and(~kIsOwnedByThis, std::memory_order_relaxed);
+ }
+ }
+
+ inline bool IsOwned() const
+ {
+ if (IsOwnedByThis())
+ {
+ return true;
+ }
+ return m_OuterCore && m_OuterCore->IsOwned();
+ }
+
+ inline bool IsImmutable() const { return (m_Flags.load(std::memory_order_relaxed) & kIsMutable) == 0; }
+ inline bool IsWholeFile() const { return (m_Flags.load(std::memory_order_relaxed) & kIsWholeFile) != 0; }
+ inline bool IsNull() const { return (m_Flags.load(std::memory_order_relaxed) & kIsNull) != 0; }
+
+ inline IoBufferExtendedCore* ExtendedCore();
+ inline const IoBufferExtendedCore* ExtendedCore() const;
+
+ ZENCORE_API void* MutableDataPointer() const;
+
+ inline const void* DataPointer() const
+ {
+ EnsureDataValid();
+ return m_DataPtr;
+ }
+
+ inline size_t DataBytes() const { return m_DataBytes; }
+
+ inline void Set(const void* Ptr, size_t Sz)
+ {
+ m_DataPtr = Ptr;
+ m_DataBytes = Sz;
+ }
+
+ inline void SetIsImmutable(bool NewState)
+ {
+ if (!NewState)
+ {
+ m_Flags.fetch_or(kIsMutable, std::memory_order_relaxed);
+ }
+ else
+ {
+ m_Flags.fetch_and(~kIsMutable, std::memory_order_relaxed);
+ }
+ }
+
+ inline void SetIsWholeFile(bool NewState)
+ {
+ if (NewState)
+ {
+ m_Flags.fetch_or(kIsWholeFile, std::memory_order_relaxed);
+ }
+ else
+ {
+ m_Flags.fetch_and(~kIsWholeFile, std::memory_order_relaxed);
+ }
+ }
+
+ inline void SetContentType(ZenContentType ContentType)
+ {
+ ZEN_ASSERT_SLOW((uint32_t(ContentType) & kContentTypeMask) == uint32_t(ContentType));
+ uint32_t OldValue = m_Flags.load(std::memory_order_relaxed);
+ uint32_t NewValue;
+ do
+ {
+ NewValue = (OldValue & ~(kContentTypeMask << kContentTypeShift)) | (uint32_t(ContentType) << kContentTypeShift);
+ } while (!m_Flags.compare_exchange_weak(OldValue, NewValue, std::memory_order_relaxed, std::memory_order_relaxed));
+ }
+
+ inline ZenContentType GetContentType() const
+ {
+ return ZenContentType((m_Flags.load(std::memory_order_relaxed) >> kContentTypeShift) & kContentTypeMask);
+ }
+
+ inline uint32_t GetRefCount() const { return m_RefCount; }
+
+protected:
+ uint32_t m_RefCount = 0;
+ mutable std::atomic<uint32_t> m_Flags{0};
+ mutable const void* m_DataPtr = nullptr;
+ size_t m_DataBytes = 0;
+ RefPtr<const IoBufferCore> m_OuterCore;
+
+ enum
+ {
+ kContentTypeShift = 24,
+ kContentTypeMask = 0xf
+ };
+
+ static_assert((uint32_t(ZenContentType::kUnknownContentType) & ~kContentTypeMask) == 0);
+
+ enum Flags : uint32_t
+ {
+ kIsNull = 1 << 0, // This is a null IoBuffer
+ kIsMutable = 1 << 1,
+ kIsExtended = 1 << 2, // Is actually a SharedBufferExtendedCore
+ kIsMaterialized = 1 << 3, // Data pointers are valid
+ kLowLevelAlloc = 1 << 4, // Using direct memory allocation
+ kIsWholeFile = 1 << 5, // References an entire file
+ kIoBufferAlloc = 1 << 6, // Using IoBuffer allocator
+ kIsOwnedByThis = 1 << 7,
+
+ // Note that we have some extended flags defined below
+ // so not all bits are available to use here
+
+ kContentTypeBit0 = 1 << (24 + 0), // These constants
+ kContentTypeBit1 = 1 << (24 + 1), // are here mostly to
+ kContentTypeBit2 = 1 << (24 + 2), // indicate that these
+ kContentTypeBit3 = 1 << (24 + 3), // bits are reserved
+ };
+
+ void AllocateBuffer(size_t InSize, size_t Alignment) const;
+ void FreeBuffer();
+};
+
+/**
+ * An "Extended" core references a segment of a file
+ */
+
+struct IoBufferExtendedCore : public IoBufferCore
+{
+ IoBufferExtendedCore(void* FileHandle, uint64_t Offset, uint64_t Size, bool TransferHandleOwnership);
+ IoBufferExtendedCore(const IoBufferExtendedCore* Outer, uint64_t Offset, uint64_t Size);
+ ~IoBufferExtendedCore();
+
+ enum ExtendedFlags
+ {
+ kOwnsFile = 1 << 16,
+ kOwnsMmap = 1 << 17
+ };
+
+ void Materialize() const;
+ bool GetFileReference(IoBufferFileReference& OutRef) const;
+ void MarkAsDeleteOnClose();
+
+private:
+ void* m_FileHandle = nullptr;
+ uint64_t m_FileOffset = 0;
+ mutable void* m_MmapHandle = nullptr;
+ mutable void* m_MappedPointer = nullptr;
+ bool m_DeleteOnClose = false;
+};
+
+inline IoBufferExtendedCore*
+IoBufferCore::ExtendedCore()
+{
+ if (m_Flags.load(std::memory_order_relaxed) & kIsExtended)
+ {
+ return static_cast<IoBufferExtendedCore*>(this);
+ }
+
+ return nullptr;
+}
+
+inline const IoBufferExtendedCore*
+IoBufferCore::ExtendedCore() const
+{
+ if (m_Flags.load(std::memory_order_relaxed) & kIsExtended)
+ {
+ return static_cast<const IoBufferExtendedCore*>(this);
+ }
+
+ return nullptr;
+}
+
+/**
+ * I/O buffer
+ *
+ * This represents a reference to a payload in memory or on disk
+ *
+ */
+class IoBuffer
+{
+public:
+ enum ECloneTag
+ {
+ Clone
+ };
+ enum EWrapTag
+ {
+ Wrap
+ };
+ enum EFileTag
+ {
+ File
+ };
+ enum EBorrowedFileTag
+ {
+ BorrowedFile
+ };
+
+ inline IoBuffer() = default;
+ inline IoBuffer(IoBuffer&& Rhs) noexcept = default;
+ inline IoBuffer(const IoBuffer& Rhs) = default;
+ inline IoBuffer& operator=(const IoBuffer& Rhs) = default;
+ inline IoBuffer& operator=(IoBuffer&& Rhs) noexcept = default;
+
+ /** Create an uninitialized buffer of the given size
+ */
+ ZENCORE_API explicit IoBuffer(size_t InSize);
+
+ /** Create an uninitialized buffer of the given size with the specified alignment
+ */
+ ZENCORE_API explicit IoBuffer(size_t InSize, uint64_t InAlignment);
+
+ /** Create a buffer which references a sequence of bytes inside another buffer
+ */
+ ZENCORE_API IoBuffer(const IoBuffer& OuterBuffer, size_t Offset, size_t SizeBytes = ~0ull);
+
+ /** Create a buffer which references a range of bytes which we assume will live
+ * for the entire life time.
+ */
+ inline IoBuffer(EWrapTag, const void* DataPtr, size_t SizeBytes) : m_Core(new IoBufferCore(DataPtr, SizeBytes)) {}
+
+ inline IoBuffer(ECloneTag, const void* DataPtr, size_t SizeBytes) : m_Core(new IoBufferCore(SizeBytes))
+ {
+ memcpy(const_cast<void*>(m_Core->DataPointer()), DataPtr, SizeBytes);
+ }
+
+ ZENCORE_API IoBuffer(EFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize);
+ ZENCORE_API IoBuffer(EBorrowedFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize);
+
+ inline explicit operator bool() const { return !m_Core->IsNull(); }
+ inline operator MemoryView() const& { return MemoryView(m_Core->DataPointer(), m_Core->DataBytes()); }
+ inline void MakeOwned() { return m_Core->MakeOwned(); }
+ [[nodiscard]] inline bool IsOwned() const { return m_Core->IsOwned(); }
+ [[nodiscard]] inline bool IsWholeFile() const { return m_Core->IsWholeFile(); }
+ [[nodiscard]] void* MutableData() const { return m_Core->MutableDataPointer(); }
+ void MakeImmutable() { m_Core->SetIsImmutable(true); }
+ [[nodiscard]] const void* Data() const { return m_Core->DataPointer(); }
+ [[nodiscard]] const void* GetData() const { return m_Core->DataPointer(); }
+ [[nodiscard]] size_t Size() const { return m_Core->DataBytes(); }
+ [[nodiscard]] size_t GetSize() const { return m_Core->DataBytes(); }
+ inline void SetContentType(ZenContentType ContentType) { m_Core->SetContentType(ContentType); }
+ [[nodiscard]] inline ZenContentType GetContentType() const { return m_Core->GetContentType(); }
+ [[nodiscard]] ZENCORE_API bool GetFileReference(IoBufferFileReference& OutRef) const;
+ void MarkAsDeleteOnClose();
+
+ inline MemoryView GetView() const { return MemoryView(m_Core->DataPointer(), m_Core->DataBytes()); }
+ inline MutableMemoryView GetMutableView() { return MutableMemoryView(m_Core->MutableDataPointer(), m_Core->DataBytes()); }
+
+ template<typename T>
+ [[nodiscard]] const T* Data() const
+ {
+ return reinterpret_cast<const T*>(m_Core->DataPointer());
+ }
+
+ template<typename T>
+ [[nodiscard]] T* MutableData() const
+ {
+ return reinterpret_cast<T*>(m_Core->MutableDataPointer());
+ }
+
+private:
+ RefPtr<IoBufferCore> m_Core = new IoBufferCore;
+
+ IoBuffer(IoBufferCore* Core) : m_Core(Core) {}
+
+ friend class SharedBuffer;
+ friend class IoBufferBuilder;
+};
+
+class IoBufferBuilder
+{
+public:
+ ZENCORE_API static IoBuffer MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset = 0, uint64_t Size = ~0ull);
+ ZENCORE_API static IoBuffer MakeFromTemporaryFile(const std::filesystem::path& FileName);
+ ZENCORE_API static IoBuffer MakeFromFileHandle(void* FileHandle, uint64_t Offset = 0, uint64_t Size = ~0ull);
+ ZENCORE_API static IoBuffer ReadFromFileMaybe(IoBuffer& InBuffer);
+ inline static IoBuffer MakeCloneFromMemory(const void* Ptr, size_t Sz) { return IoBuffer(IoBuffer::Clone, Ptr, Sz); }
+ inline static IoBuffer MakeCloneFromMemory(MemoryView Memory) { return IoBuffer(IoBuffer::Clone, Memory.GetData(), Memory.GetSize()); }
+};
+
+IoHash HashBuffer(IoBuffer& Buffer);
+
+void iobuffer_forcelink();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/iohash.h b/src/zencore/include/zencore/iohash.h
new file mode 100644
index 000000000..fd0f4b2a7
--- /dev/null
+++ b/src/zencore/include/zencore/iohash.h
@@ -0,0 +1,115 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <zencore/blake3.h>
+#include <zencore/memory.h>
+
+#include <compare>
+#include <string_view>
+
+namespace zen {
+
+class StringBuilderBase;
+class CompositeBuffer;
+
+/**
+ * Hash used for content addressable storage
+ *
+ * This is basically a BLAKE3-160 hash (note: this is probably not an officially
+ * recognized identifier). It is generated by computing a 32-byte BLAKE3 hash and
+ * picking the first 20 bytes of the resulting hash.
+ *
+ */
+struct IoHash
+{
+ alignas(uint32_t) uint8_t Hash[20] = {};
+
+ static IoHash MakeFrom(const void* data /* 20 bytes */)
+ {
+ IoHash Io;
+ memcpy(Io.Hash, data, sizeof Io);
+ return Io;
+ }
+
+ static IoHash FromBLAKE3(const BLAKE3& Blake3)
+ {
+ IoHash Io;
+ memcpy(Io.Hash, Blake3.Hash, sizeof Io.Hash);
+ return Io;
+ }
+
+ static IoHash HashBuffer(const void* data, size_t byteCount);
+ static IoHash HashBuffer(MemoryView Data) { return HashBuffer(Data.GetData(), Data.GetSize()); }
+ static IoHash HashBuffer(const CompositeBuffer& Buffer);
+ static IoHash FromHexString(const char* string);
+ static IoHash FromHexString(const std::string_view string);
+ const char* ToHexString(char* outString /* 40 characters + NUL terminator */) const;
+ StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const;
+ std::string ToHexString() const;
+
+ static const int StringLength = 40;
+ typedef char String_t[StringLength + 1];
+
+ static const IoHash Zero; // Initialized to all zeros
+
+ inline auto operator<=>(const IoHash& rhs) const = default;
+
+ struct Hasher
+ {
+ size_t operator()(const IoHash& v) const
+ {
+ size_t h;
+ memcpy(&h, v.Hash, sizeof h);
+ return h;
+ }
+ };
+};
+
+struct IoHashStream
+{
+ /// Begin streaming hash compute (not needed on freshly constructed instance)
+ void Reset() { m_Blake3Stream.Reset(); }
+
+ /// Append another chunk
+ IoHashStream& Append(const void* data, size_t byteCount)
+ {
+ m_Blake3Stream.Append(data, byteCount);
+ return *this;
+ }
+
+ /// Append another chunk
+ IoHashStream& Append(MemoryView Data)
+ {
+ m_Blake3Stream.Append(Data.GetData(), Data.GetSize());
+ return *this;
+ }
+
+ /// Obtain final hash. If you wish to reuse the instance call reset()
+ IoHash GetHash()
+ {
+ BLAKE3 b3 = m_Blake3Stream.GetHash();
+
+ IoHash Io;
+ memcpy(Io.Hash, b3.Hash, sizeof Io.Hash);
+
+ return Io;
+ }
+
+private:
+ BLAKE3Stream m_Blake3Stream;
+};
+
+} // namespace zen
+
+namespace std {
+
+template<>
+struct hash<zen::IoHash> : public zen::IoHash::Hasher
+{
+};
+
+} // namespace std
diff --git a/src/zencore/include/zencore/logging.h b/src/zencore/include/zencore/logging.h
new file mode 100644
index 000000000..5cbe034cf
--- /dev/null
+++ b/src/zencore/include/zencore/logging.h
@@ -0,0 +1,136 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <spdlog/spdlog.h>
+#undef GetObject
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <string_view>
+
+namespace zen::logging {
+
+spdlog::logger& Default();
+void SetDefault(std::shared_ptr<spdlog::logger> NewDefaultLogger);
+spdlog::logger& ConsoleLog();
+spdlog::logger& Get(std::string_view Name);
+spdlog::logger* ErrorLog();
+void SetErrorLog(std::shared_ptr<spdlog::logger>&& NewErrorLogger);
+
+void InitializeLogging();
+void ShutdownLogging();
+
+} // namespace zen::logging
+
+namespace zen {
+extern spdlog::logger* TheDefaultLogger;
+
+inline spdlog::logger&
+Log()
+{
+ return *TheDefaultLogger;
+}
+
+using logging::ConsoleLog;
+using logging::ErrorLog;
+} // namespace zen
+
+using zen::ConsoleLog;
+using zen::ErrorLog;
+using zen::Log;
+
+struct LogCategory
+{
+ LogCategory(std::string_view InCategory) : Category(InCategory) {}
+
+ spdlog::logger& Logger()
+ {
+ static spdlog::logger& Inst = zen::logging::Get(Category);
+ return Inst;
+ }
+
+ std::string Category;
+};
+
+inline consteval bool
+LogIsErrorLevel(int level)
+{
+ return (level == spdlog::level::err || level == spdlog::level::critical);
+};
+
+#define ZEN_LOG_WITH_LOCATION(logger, loc, level, fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ if (logger.should_log(level)) \
+ { \
+ logger.log(loc, level, fmtstr, ##__VA_ARGS__); \
+ if (LogIsErrorLevel(level)) \
+ { \
+ if (auto ErrLogger = zen::logging::ErrorLog(); ErrLogger != nullptr) \
+ { \
+ ErrLogger->log(loc, level, fmtstr, ##__VA_ARGS__); \
+ } \
+ } \
+ } \
+ } while (false);
+
+#define ZEN_LOG(logger, level, fmtstr, ...) ZEN_LOG_WITH_LOCATION(logger, spdlog::source_loc{}, level, fmtstr, ##__VA_ARGS__)
+
+#define ZEN_DEFINE_LOG_CATEGORY_STATIC(Category, Name) \
+ static struct LogCategory##Category : public LogCategory \
+ { \
+ LogCategory##Category() : LogCategory(Name) {} \
+ } Category;
+
+#define ZEN_LOG_TRACE(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::trace, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_LOG_DEBUG(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::debug, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_LOG_INFO(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::info, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_LOG_WARN(Category, fmtstr, ...) ZEN_LOG(Category.Logger(), spdlog::level::warn, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_LOG_ERROR(Category, fmtstr, ...) \
+ ZEN_LOG_WITH_LOCATION(Category.Logger(), \
+ spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), \
+ spdlog::level::err, \
+ fmtstr##sv, \
+ ##__VA_ARGS__)
+
+#define ZEN_LOG_CRITICAL(Category, fmtstr, ...) \
+ ZEN_LOG_WITH_LOCATION(Category.Logger(), \
+ spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), \
+ spdlog::level::critical, \
+ fmtstr##sv, \
+ ##__VA_ARGS__)
+
+ // Helper macros for logging
+
+#define ZEN_TRACE(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::trace, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_DEBUG(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::debug, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_INFO(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::info, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_WARN(fmtstr, ...) ZEN_LOG(Log(), spdlog::level::warn, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_ERROR(fmtstr, ...) \
+ ZEN_LOG_WITH_LOCATION(Log(), spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), spdlog::level::err, fmtstr##sv, ##__VA_ARGS__)
+
+#define ZEN_CRITICAL(fmtstr, ...) \
+ ZEN_LOG_WITH_LOCATION(Log(), \
+ spdlog::source_loc(__FILE__, __LINE__, SPDLOG_FUNCTION), \
+ spdlog::level::critical, \
+ fmtstr##sv, \
+ ##__VA_ARGS__)
+
+#define ZEN_CONSOLE(fmtstr, ...) \
+ do \
+ { \
+ using namespace std::literals; \
+ ConsoleLog().info(fmtstr##sv, ##__VA_ARGS__); \
+ } while (false)
diff --git a/src/zencore/include/zencore/md5.h b/src/zencore/include/zencore/md5.h
new file mode 100644
index 000000000..d934dd86b
--- /dev/null
+++ b/src/zencore/include/zencore/md5.h
@@ -0,0 +1,50 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <stdint.h>
+#include <compare>
+#include "zencore.h"
+
+namespace zen {
+
+class StringBuilderBase;
+
+struct MD5
+{
+ uint8_t Hash[16];
+
+ inline auto operator<=>(const MD5& rhs) const = default;
+
+ static const int StringLength = 32;
+ typedef char String_t[StringLength + 1];
+
+ static MD5 HashMemory(const void* data, size_t byteCount);
+ static MD5 FromHexString(const char* string);
+ const char* ToHexString(char* outString /* 32 characters + NUL terminator */) const;
+ StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const;
+
+ static MD5 Zero; // Initialized to all zeroes
+};
+
+/**
+ * Utility class for computing MD5 hashes
+ */
+class MD5Stream
+{
+public:
+ MD5Stream();
+
+ /// Begin streaming MD5 compute (not needed on freshly constructed MD5Stream instance)
+ void Reset();
+ /// Append another chunk
+ MD5Stream& Append(const void* data, size_t byteCount);
+ /// Obtain final MD5 hash. If you wish to reuse the MD5Stream instance call reset()
+ MD5 GetHash();
+
+private:
+};
+
+void md5_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/memory.h b/src/zencore/include/zencore/memory.h
new file mode 100644
index 000000000..560fa9ffc
--- /dev/null
+++ b/src/zencore/include/zencore/memory.h
@@ -0,0 +1,401 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <zencore/intmath.h>
+#include <zencore/thread.h>
+
+#include <cstddef>
+#include <cstring>
+#include <span>
+#include <vector>
+
+namespace zen {
+
+#if defined(__cpp_lib_ranges) && __cpp_lib_ranges >= 201911L
+template<typename T>
+concept ContiguousRange = std::ranges::contiguous_range<T>;
+#else
+template<typename T>
+concept ContiguousRange = true;
+#endif
+
+struct MemoryView;
+
+class MemoryArena
+{
+public:
+ ZENCORE_API MemoryArena();
+ ZENCORE_API ~MemoryArena();
+
+ ZENCORE_API void* Alloc(size_t Size, size_t Alignment);
+ ZENCORE_API void Free(void* Ptr);
+
+private:
+};
+
+class Memory
+{
+public:
+ ZENCORE_API static void* Alloc(size_t Size, size_t Alignment = sizeof(void*));
+ ZENCORE_API static void Free(void* Ptr);
+};
+
+/** Allocator which claims fixed-size blocks from the underlying allocator.
+
+ There is no way to free individual memory blocks.
+
+ \note This is not thread-safe, you will need to provide synchronization yourself
+*/
+
+class ChunkingLinearAllocator
+{
+public:
+ ChunkingLinearAllocator(uint64_t ChunkSize, uint64_t ChunkAlignment = sizeof(std::max_align_t));
+ ~ChunkingLinearAllocator();
+
+ ZENCORE_API void Reset();
+
+ ZENCORE_API void* Alloc(size_t Size, size_t Alignment = sizeof(void*));
+ inline void Free(void* Ptr) { ZEN_UNUSED(Ptr); /* no-op */ }
+
+ ChunkingLinearAllocator(const ChunkingLinearAllocator&) = delete;
+ ChunkingLinearAllocator& operator=(const ChunkingLinearAllocator&) = delete;
+
+private:
+ uint8_t* m_ChunkCursor = nullptr;
+ uint64_t m_ChunkBytesRemain = 0;
+ const uint64_t m_ChunkSize = 0;
+ const uint64_t m_ChunkAlignment = 0;
+ std::vector<void*> m_ChunkList;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct MutableMemoryView
+{
+ MutableMemoryView() = default;
+
+ MutableMemoryView(void* DataPtr, size_t DataSize)
+ : m_Data(reinterpret_cast<uint8_t*>(DataPtr))
+ , m_DataEnd(reinterpret_cast<uint8_t*>(DataPtr) + DataSize)
+ {
+ }
+
+ MutableMemoryView(void* DataPtr, void* DataEndPtr)
+ : m_Data(reinterpret_cast<uint8_t*>(DataPtr))
+ , m_DataEnd(reinterpret_cast<uint8_t*>(DataEndPtr))
+ {
+ }
+
+ inline bool IsEmpty() const { return m_Data == m_DataEnd; }
+ void* GetData() const { return m_Data; }
+ void* GetDataEnd() const { return m_DataEnd; }
+ size_t GetSize() const { return reinterpret_cast<uint8_t*>(m_DataEnd) - reinterpret_cast<uint8_t*>(m_Data); }
+
+ inline bool EqualBytes(const MutableMemoryView& InView) const
+ {
+ const size_t Size = GetSize();
+
+ return Size == InView.GetSize() && (memcmp(m_Data, InView.m_Data, Size) == 0);
+ }
+
+ /** Modifies the view to be the given number of bytes from the right. */
+ inline void RightInline(uint64_t InSize)
+ {
+ const uint64_t OldSize = GetSize();
+ const uint64_t NewSize = zen::Min(OldSize, InSize);
+ m_Data = GetDataAtOffsetNoCheck(OldSize - NewSize);
+ m_DataEnd = m_Data + NewSize;
+ }
+
+ /** Returns the right-most part of the view by taking the given number of bytes from the right. */
+ [[nodiscard]] inline MutableMemoryView Right(uint64_t InSize) const
+ {
+ MutableMemoryView View(*this);
+ View.RightChopInline(InSize);
+ return View;
+ }
+
+ /** Modifies the view by chopping the given number of bytes from the left. */
+ inline void RightChopInline(uint64_t InSize)
+ {
+ const uint64_t Offset = zen::Min(GetSize(), InSize);
+ m_Data = GetDataAtOffsetNoCheck(Offset);
+ }
+
+ /** Returns the left-most part of the view by taking the given number of bytes from the left. */
+ constexpr inline MutableMemoryView Left(uint64_t InSize) const
+ {
+ MutableMemoryView View(*this);
+ View.LeftInline(InSize);
+ return View;
+ }
+
+ /** Modifies the view to be the given number of bytes from the left. */
+ constexpr inline void LeftInline(uint64_t InSize) { m_DataEnd = zen::Min(m_DataEnd, m_Data + InSize); }
+
+ /** Modifies the view to be the middle part by taking up to the given number of bytes from the given offset. */
+ inline void MidInline(uint64_t InOffset, uint64_t InSize = ~uint64_t(0))
+ {
+ RightChopInline(InOffset);
+ LeftInline(InSize);
+ }
+
+ /** Returns the middle part of the view by taking up to the given number of bytes from the given position. */
+ [[nodiscard]] inline MutableMemoryView Mid(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) const
+ {
+ MutableMemoryView View(*this);
+ View.MidInline(InOffset, InSize);
+ return View;
+ }
+
+ /** Returns the right-most part of the view by chopping the given number of bytes from the left. */
+ [[nodiscard]] inline MutableMemoryView RightChop(uint64_t InSize) const
+ {
+ MutableMemoryView View(*this);
+ View.RightChopInline(InSize);
+ return View;
+ }
+
+ inline MutableMemoryView& operator+=(size_t InSize)
+ {
+ RightChopInline(InSize);
+ return *this;
+ }
+
+ /** Copies bytes from the input view into this view, and returns the remainder of this view. */
+ inline MutableMemoryView CopyFrom(MemoryView InView) const;
+
+private:
+ uint8_t* m_Data = nullptr;
+ uint8_t* m_DataEnd = nullptr;
+
+ /** Returns the data pointer advanced by an offset in bytes. */
+ inline constexpr uint8_t* GetDataAtOffsetNoCheck(uint64_t InOffset) const { return m_Data + InOffset; }
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct MemoryView
+{
+ MemoryView() = default;
+
+ MemoryView(const MutableMemoryView& MutableView)
+ : m_Data(reinterpret_cast<const uint8_t*>(MutableView.GetData()))
+ , m_DataEnd(m_Data + MutableView.GetSize())
+ {
+ }
+
+ MemoryView(const void* DataPtr, size_t DataSize)
+ : m_Data(reinterpret_cast<const uint8_t*>(DataPtr))
+ , m_DataEnd(reinterpret_cast<const uint8_t*>(DataPtr) + DataSize)
+ {
+ }
+
+ MemoryView(const void* DataPtr, const void* DataEndPtr)
+ : m_Data(reinterpret_cast<const uint8_t*>(DataPtr))
+ , m_DataEnd(reinterpret_cast<const uint8_t*>(DataEndPtr))
+ {
+ }
+
+ inline bool Contains(const MemoryView& Other) const { return (m_Data <= Other.m_Data) && (m_DataEnd >= Other.m_DataEnd); }
+ inline bool IsEmpty() const { return m_Data == m_DataEnd; }
+ const void* GetData() const { return m_Data; }
+ const void* GetDataEnd() const { return m_DataEnd; }
+ size_t GetSize() const { return reinterpret_cast<const uint8_t*>(m_DataEnd) - reinterpret_cast<const uint8_t*>(m_Data); }
+ inline bool operator==(const MemoryView& Rhs) const { return m_Data == Rhs.m_Data && m_DataEnd == Rhs.m_DataEnd; }
+
+ inline bool EqualBytes(const MemoryView& InView) const
+ {
+ const size_t Size = GetSize();
+
+ return Size == InView.GetSize() && (memcmp(m_Data, InView.GetData(), Size) == 0);
+ }
+
+ inline MemoryView& operator+=(size_t InSize)
+ {
+ RightChopInline(InSize);
+ return *this;
+ }
+
+ /** Modifies the view by chopping the given number of bytes from the left. */
+ inline void RightChopInline(uint64_t InSize)
+ {
+ const uint64_t Offset = zen::Min(GetSize(), InSize);
+ m_Data = GetDataAtOffsetNoCheck(Offset);
+ }
+
+ inline MemoryView RightChop(uint64_t InSize)
+ {
+ MemoryView View(*this);
+ View.RightChopInline(InSize);
+ return View;
+ }
+
+ /** Returns the right-most part of the view by taking the given number of bytes from the right. */
+ [[nodiscard]] inline MemoryView Right(uint64_t InSize) const
+ {
+ MemoryView View(*this);
+ View.RightInline(InSize);
+ return View;
+ }
+
+ /** Modifies the view to be the given number of bytes from the right. */
+ inline void RightInline(uint64_t InSize)
+ {
+ const uint64_t OldSize = GetSize();
+ const uint64_t NewSize = zen::Min(OldSize, InSize);
+ m_Data = GetDataAtOffsetNoCheck(OldSize - NewSize);
+ m_DataEnd = m_Data + NewSize;
+ }
+
+ /** Returns the left-most part of the view by taking the given number of bytes from the left. */
+ inline MemoryView Left(uint64_t InSize) const
+ {
+ MemoryView View(*this);
+ View.LeftInline(InSize);
+ return View;
+ }
+
+ /** Modifies the view to be the given number of bytes from the left. */
+ inline void LeftInline(uint64_t InSize)
+ {
+ InSize = zen::Min(GetSize(), InSize);
+ m_DataEnd = zen::Min(m_DataEnd, m_Data + InSize);
+ }
+
+ /** Modifies the view to be the middle part by taking up to the given number of bytes from the given offset. */
+ inline void MidInline(uint64_t InOffset, uint64_t InSize = ~uint64_t(0))
+ {
+ RightChopInline(InOffset);
+ LeftInline(InSize);
+ }
+
+ /** Returns the middle part of the view by taking up to the given number of bytes from the given position. */
+ [[nodiscard]] inline MemoryView Mid(uint64_t InOffset, uint64_t InSize = ~uint64_t(0)) const
+ {
+ MemoryView View(*this);
+ View.MidInline(InOffset, InSize);
+ return View;
+ }
+
+ constexpr void Reset()
+ {
+ m_Data = nullptr;
+ m_DataEnd = nullptr;
+ }
+
+private:
+ const uint8_t* m_Data = nullptr;
+ const uint8_t* m_DataEnd = nullptr;
+
+ /** Returns the data pointer advanced by an offset in bytes. */
+ inline constexpr const uint8_t* GetDataAtOffsetNoCheck(uint64_t InOffset) const { return m_Data + InOffset; }
+};
+
+inline MutableMemoryView
+MutableMemoryView::CopyFrom(MemoryView InView) const
+{
+ ZEN_ASSERT(InView.GetSize() <= GetSize());
+ memcpy(m_Data, InView.GetData(), InView.GetSize());
+ return RightChop(InView.GetSize());
+}
+
+/** Advances the start of the view by an offset, which is clamped to stay within the view. */
+inline MemoryView
+operator+(const MemoryView& View, uint64_t Offset)
+{
+ return MemoryView(View) += Offset;
+}
+
+/** Advances the start of the view by an offset, which is clamped to stay within the view. */
+inline MemoryView
+operator+(uint64_t Offset, const MemoryView& View)
+{
+ return MemoryView(View) += Offset;
+}
+
+/** Advances the start of the view by an offset, which is clamped to stay within the view. */
+inline MutableMemoryView
+operator+(const MutableMemoryView& View, uint64_t Offset)
+{
+ return MutableMemoryView(View) += Offset;
+}
+
+/** Advances the start of the view by an offset, which is clamped to stay within the view. */
+inline MutableMemoryView
+operator+(uint64_t Offset, const MutableMemoryView& View)
+{
+ return MutableMemoryView(View) += Offset;
+}
+
+/**
+ * Make a non-owning view of the memory of the initializer list.
+ *
+ * This overload is only available when the element type does not need to be deduced.
+ */
+template<typename T>
+[[nodiscard]] inline MemoryView
+MakeMemoryView(std::initializer_list<typename std::type_identity<T>::type> List)
+{
+ return MemoryView(List.begin(), List.size() * sizeof(T));
+}
+
+/** Make a non-owning view of the memory of the contiguous container. */
+template<ContiguousRange R>
+[[nodiscard]] constexpr inline MemoryView
+MakeMemoryView(const R& Container)
+{
+ std::span Span = Container;
+ return MemoryView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type));
+}
+
+/** Make a non-owning const view starting at Data and ending at DataEnd. */
+
+[[nodiscard]] inline MemoryView
+MakeMemoryView(const void* Data, const void* DataEnd)
+{
+ return MemoryView(Data, DataEnd);
+}
+
+[[nodiscard]] inline MemoryView
+MakeMemoryView(const void* Data, uint64_t Size)
+{
+ return MemoryView(Data, reinterpret_cast<const uint8_t*>(Data) + Size);
+}
+
+/**
+ * Make a non-owning mutable view of the memory of the initializer list.
+ *
+ * This overload is only available when the element type does not need to be deduced.
+ */
+template<typename T>
+[[nodiscard]] inline MutableMemoryView
+MakeMutableMemoryView(std::initializer_list<typename std::type_identity<T>::type> List)
+{
+ return MutableMemoryView(List.begin(), List.size() * sizeof(T));
+}
+
+/** Make a non-owning mutable view of the memory of the contiguous container. */
+template<ContiguousRange R>
+[[nodiscard]] constexpr inline MutableMemoryView
+MakeMutableMemoryView(R& Container)
+{
+ std::span Span = Container;
+ return MutableMemoryView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type));
+}
+
+/** Make a non-owning mutable view starting at Data and ending at DataEnd. */
+
+[[nodiscard]] inline MutableMemoryView
+MakeMutableMemoryView(void* Data, void* DataEnd)
+{
+ return MutableMemoryView(Data, DataEnd);
+}
+
+void memory_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/meta.h b/src/zencore/include/zencore/meta.h
new file mode 100644
index 000000000..82eb5cc30
--- /dev/null
+++ b/src/zencore/include/zencore/meta.h
@@ -0,0 +1,30 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+/* This file contains utility functions for meta programming
+ *
+ * Since you're in here you're probably quite observant, and you'll
+ * note that it's quite barren here. This is because template
+ * metaprogramming is awful and I try not to engage in it. However,
+ * sometimes these things are forced upon us.
+ *
+ */
+
+namespace zen {
+
+/**
+ * Uses implicit conversion to create an instance of a specific type.
+ * Useful to make things clearer or circumvent unintended type deduction in templates.
+ * Safer than C casts and static_casts, e.g. does not allow down-casts
+ *
+ * @param Obj The object (usually pointer or reference) to convert.
+ *
+ * @return The object converted to the specified type.
+ */
+template<typename T>
+inline T
+ImplicitConv(typename std::type_identity<T>::type Obj)
+{
+ return Obj;
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/mpscqueue.h b/src/zencore/include/zencore/mpscqueue.h
new file mode 100644
index 000000000..19e410d85
--- /dev/null
+++ b/src/zencore/include/zencore/mpscqueue.h
@@ -0,0 +1,110 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <atomic>
+#include <memory>
+#include <new>
+#include <optional>
+
+#ifdef __cpp_lib_hardware_interference_size
+using std::hardware_constructive_interference_size;
+using std::hardware_destructive_interference_size;
+#else
+// 64 bytes on x86-64 │ L1_CACHE_BYTES │ L1_CACHE_SHIFT │ __cacheline_aligned │ ...
+constexpr std::size_t hardware_constructive_interference_size = 64;
+constexpr std::size_t hardware_destructive_interference_size = 64;
+#endif
+
+namespace zen {
+
+/** An untyped array of data with compile-time alignment and size derived from another type. */
+template<typename ElementType>
+struct TypeCompatibleStorage
+{
+ ElementType* Data() { return (ElementType*)this; }
+ const ElementType* Data() const { return (const ElementType*)this; }
+
+ alignas(ElementType) char DataMember;
+};
+
+/** Fast multi-producer/single-consumer unbounded concurrent queue.
+
+ Based on http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue
+ */
+
+template<typename T>
+class MpscQueue final
+{
+public:
+ using ElementType = T;
+
+ MpscQueue()
+ {
+ Node* Sentinel = new Node;
+ Head.store(Sentinel, std::memory_order_relaxed);
+ Tail = Sentinel;
+ }
+
+ ~MpscQueue()
+ {
+ Node* Next = Tail->Next.load(std::memory_order_relaxed);
+
+ // sentinel's value is already destroyed
+ delete Tail;
+
+ while (Next != nullptr)
+ {
+ Tail = Next;
+ Next = Tail->Next.load(std::memory_order_relaxed);
+
+ std::destroy_at((ElementType*)&Tail->Value);
+ delete Tail;
+ }
+ }
+
+ template<typename... ArgTypes>
+ void Enqueue(ArgTypes&&... Args)
+ {
+ Node* New = new Node;
+ new (&New->Value) ElementType(std::forward<ArgTypes>(Args)...);
+
+ Node* Prev = Head.exchange(New, std::memory_order_acq_rel);
+ Prev->Next.store(New, std::memory_order_release);
+ }
+
+ std::optional<ElementType> Dequeue()
+ {
+ Node* Next = Tail->Next.load(std::memory_order_acquire);
+
+ if (Next == nullptr)
+ {
+ return {};
+ }
+
+ ElementType* ValuePtr = (ElementType*)&Next->Value;
+ std::optional<ElementType> Res{std::move(*ValuePtr)};
+ std::destroy_at(ValuePtr);
+
+ delete Tail; // current sentinel
+
+ Tail = Next; // new sentinel
+ return Res;
+ }
+
+private:
+ struct Node
+ {
+ std::atomic<Node*> Next{nullptr};
+ TypeCompatibleStorage<ElementType> Value;
+ };
+
+private:
+ std::atomic<Node*> Head; // accessed only by producers
+ alignas(hardware_constructive_interference_size)
+ Node* Tail; // accessed only by consumer, hence should be on a different cache line than `Head`
+};
+
+void mpscqueue_forcelink();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/refcount.h b/src/zencore/include/zencore/refcount.h
new file mode 100644
index 000000000..f0bb6b85e
--- /dev/null
+++ b/src/zencore/include/zencore/refcount.h
@@ -0,0 +1,186 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+#pragma once
+
+#include "atomic.h"
+#include "zencore.h"
+
+#include <compare>
+
+namespace zen {
+
+/**
+ * Helper base class for reference counted objects using intrusive reference counts
+ *
+ * This class is pretty straightforward but does one thing which may be unexpected:
+ *
+ * - Instances on the stack are initialized with a reference count of one to ensure
+ * nobody tries to accidentally delete it. (TODO: is this really useful?)
+ */
+class RefCounted
+{
+public:
+ RefCounted() = default;
+ virtual ~RefCounted() = default;
+
+ inline uint32_t AddRef() const { return AtomicIncrement(const_cast<RefCounted*>(this)->m_RefCount); }
+ inline uint32_t Release() const
+ {
+ uint32_t refCount = AtomicDecrement(const_cast<RefCounted*>(this)->m_RefCount);
+ if (refCount == 0)
+ {
+ delete this;
+ }
+ return refCount;
+ }
+
+ // Copying reference counted objects doesn't make a lot of sense generally, so let's prevent it
+
+ RefCounted(const RefCounted&) = delete;
+ RefCounted(RefCounted&&) = delete;
+ RefCounted& operator=(const RefCounted&) = delete;
+ RefCounted& operator=(RefCounted&&) = delete;
+
+protected:
+ inline uint32_t RefCount() const { return m_RefCount; }
+
+private:
+ uint32_t m_RefCount = 0;
+};
+
+/**
+ * Smart pointer for classes derived from RefCounted
+ */
+
+template<class T>
+class RefPtr
+{
+public:
+ inline RefPtr() = default;
+ inline RefPtr(const RefPtr& Rhs) : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); }
+ inline RefPtr(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); }
+ inline ~RefPtr() { m_Ref && m_Ref->Release(); }
+
+ [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; }
+ inline explicit operator bool() const { return m_Ref != nullptr; }
+ inline operator T*() const { return m_Ref; }
+ inline T* operator->() const { return m_Ref; }
+
+ inline std::strong_ordering operator<=>(const RefPtr& Rhs) const = default;
+
+ inline RefPtr& operator=(T* Rhs)
+ {
+ Rhs && Rhs->AddRef();
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs;
+ return *this;
+ }
+ inline RefPtr& operator=(const RefPtr& Rhs)
+ {
+ if (&Rhs != this)
+ {
+ Rhs && Rhs->AddRef();
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs.m_Ref;
+ }
+ return *this;
+ }
+ inline RefPtr& operator=(RefPtr&& Rhs) noexcept
+ {
+ if (&Rhs != this)
+ {
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs.m_Ref;
+ Rhs.m_Ref = nullptr;
+ }
+ return *this;
+ }
+ template<typename OtherType>
+ inline RefPtr& operator=(RefPtr<OtherType>&& Rhs) noexcept
+ {
+ if ((RefPtr*)&Rhs != this)
+ {
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs.m_Ref;
+ Rhs.m_Ref = nullptr;
+ }
+ return *this;
+ }
+ inline RefPtr(RefPtr&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; }
+ template<typename OtherType>
+ explicit inline RefPtr(RefPtr<OtherType>&& Rhs) noexcept : m_Ref(Rhs.m_Ref)
+ {
+ Rhs.m_Ref = nullptr;
+ }
+
+private:
+ T* m_Ref = nullptr;
+ template<typename U>
+ friend class RefPtr;
+};
+
+/**
+ * Smart pointer for classes derived from RefCounted
+ *
+ * This variant does not decay to a raw pointer
+ *
+ */
+
+template<class T>
+class Ref
+{
+public:
+ inline Ref() = default;
+ inline Ref(const Ref& Rhs) : m_Ref(Rhs.m_Ref) { m_Ref && m_Ref->AddRef(); }
+ inline explicit Ref(T* Ptr) : m_Ref(Ptr) { m_Ref && m_Ref->AddRef(); }
+ inline ~Ref() { m_Ref && m_Ref->Release(); }
+
+ template<typename DerivedType>
+ requires DerivedFrom<DerivedType, T>
+ inline Ref(const Ref<DerivedType>& Rhs) : Ref(Rhs.m_Ref) {}
+
+ [[nodiscard]] inline bool IsNull() const { return m_Ref == nullptr; }
+ inline explicit operator bool() const { return m_Ref != nullptr; }
+ inline T* operator->() const { return m_Ref; }
+ inline T* Get() const { return m_Ref; }
+
+ inline std::strong_ordering operator<=>(const Ref& Rhs) const = default;
+
+ inline Ref& operator=(T* Rhs)
+ {
+ Rhs && Rhs->AddRef();
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs;
+ return *this;
+ }
+ inline Ref& operator=(const Ref& Rhs)
+ {
+ if (&Rhs != this)
+ {
+ Rhs && Rhs->AddRef();
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs.m_Ref;
+ }
+ return *this;
+ }
+ inline Ref& operator=(Ref&& Rhs) noexcept
+ {
+ if (&Rhs != this)
+ {
+ m_Ref && m_Ref->Release();
+ m_Ref = Rhs.m_Ref;
+ Rhs.m_Ref = nullptr;
+ }
+ return *this;
+ }
+ inline Ref(Ref&& Rhs) noexcept : m_Ref(Rhs.m_Ref) { Rhs.m_Ref = nullptr; }
+
+private:
+ T* m_Ref = nullptr;
+
+ template<class U>
+ friend class Ref;
+};
+
+void refcount_forcelink();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/scopeguard.h b/src/zencore/include/zencore/scopeguard.h
new file mode 100644
index 000000000..d04c8ed9c
--- /dev/null
+++ b/src/zencore/include/zencore/scopeguard.h
@@ -0,0 +1,45 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <type_traits>
+#include "logging.h"
+#include "zencore.h"
+
+namespace zen {
+
+template<typename T>
+class [[nodiscard]] ScopeGuardImpl
+{
+public:
+ inline ScopeGuardImpl(T&& func) : m_guardFunc(func) {}
+ ~ScopeGuardImpl()
+ {
+ if (!m_dismissed)
+ {
+ try
+ {
+ m_guardFunc();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("scope guard threw exception: '{}'", Ex.what());
+ }
+ }
+ }
+
+ void Dismiss() { m_dismissed = true; }
+
+private:
+ bool m_dismissed = false;
+ T m_guardFunc;
+};
+
+template<typename T>
+ScopeGuardImpl<T>
+MakeGuard(T&& fn)
+{
+ return ScopeGuardImpl<T>(std::move(fn));
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/session.h b/src/zencore/include/zencore/session.h
new file mode 100644
index 000000000..dd90197bf
--- /dev/null
+++ b/src/zencore/include/zencore/session.h
@@ -0,0 +1,14 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+namespace zen {
+
+struct Oid;
+
+ZENCORE_API [[nodiscard]] Oid GetSessionId();
+ZENCORE_API [[nodiscard]] std::string_view GetSessionIdString();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/sha1.h b/src/zencore/include/zencore/sha1.h
new file mode 100644
index 000000000..fc26f442b
--- /dev/null
+++ b/src/zencore/include/zencore/sha1.h
@@ -0,0 +1,76 @@
+// //////////////////////////////////////////////////////////
+// sha1.h
+// Copyright (c) 2014,2015 Stephan Brumme. All rights reserved.
+// see http://create.stephan-brumme.com/disclaimer.html
+//
+
+#pragma once
+
+#include <stdint.h>
+#include <compare>
+#include "zencore.h"
+
+namespace zen {
+
+class StringBuilderBase;
+
+struct SHA1
+{
+ uint8_t Hash[20];
+
+ inline auto operator<=>(const SHA1& rhs) const = default;
+
+ static const int StringLength = 40;
+ typedef char String_t[StringLength + 1];
+
+ static SHA1 HashMemory(const void* data, size_t byteCount);
+ static SHA1 FromHexString(const char* string);
+ const char* ToHexString(char* outString /* 40 characters + NUL terminator */) const;
+ StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const;
+
+ static SHA1 Zero; // Initialized to all zeroes
+};
+
+/**
+ * Utility class for computing SHA1 hashes
+ */
+class SHA1Stream
+{
+public:
+ SHA1Stream();
+
+ /** compute SHA1 of a memory block
+
+ \note SHA1 class contains a slightly more convenient helper function for this use case
+ \see SHA1::fromMemory()
+ */
+ SHA1 Compute(const void* data, size_t byteCount);
+
+ /// Begin streaming SHA1 compute (not needed on freshly constructed SHA1Stream instance)
+ void Reset();
+ /// Append another chunk
+ SHA1Stream& Append(const void* data, size_t byteCount);
+ /// Obtain final SHA1 hash. If you wish to reuse the SHA1Stream instance call reset()
+ SHA1 GetHash();
+
+private:
+ void ProcessBlock(const void* data);
+ void ProcessBuffer();
+
+ enum
+ {
+ /// split into 64 byte blocks (=> 512 bits)
+ BlockSize = 512 / 8,
+ HashBytes = 20,
+ HashValues = HashBytes / 4
+ };
+
+ uint64_t m_NumBytes; // size of processed data in bytes
+ size_t m_BufferSize; // valid bytes in m_buffer
+ uint8_t m_Buffer[BlockSize]; // bytes not processed yet
+ uint32_t m_Hash[HashValues];
+};
+
+void sha1_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/sharedbuffer.h b/src/zencore/include/zencore/sharedbuffer.h
new file mode 100644
index 000000000..97c5a9d21
--- /dev/null
+++ b/src/zencore/include/zencore/sharedbuffer.h
@@ -0,0 +1,167 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <zencore/iobuffer.h>
+#include <zencore/memory.h>
+#include <zencore/refcount.h>
+
+#include <memory.h>
+
+namespace zen {
+
+class SharedBuffer;
+
+/**
+ * Reference to a memory buffer with a single owner
+ *
+ * Internally
+ */
+class UniqueBuffer
+{
+public:
+ UniqueBuffer() = default;
+ UniqueBuffer(UniqueBuffer&&) = default;
+ UniqueBuffer& operator=(UniqueBuffer&&) = default;
+ UniqueBuffer(const UniqueBuffer&) = delete;
+ UniqueBuffer& operator=(const UniqueBuffer&) = delete;
+
+ ZENCORE_API explicit UniqueBuffer(IoBufferCore* Owner);
+
+ [[nodiscard]] void* GetData() { return m_Buffer ? m_Buffer->MutableDataPointer() : nullptr; }
+ [[nodiscard]] const void* GetData() const { return m_Buffer ? m_Buffer->DataPointer() : nullptr; }
+ [[nodiscard]] size_t GetSize() const { return m_Buffer ? m_Buffer->DataBytes() : 0; }
+
+ operator MutableMemoryView() { return GetMutableView(); }
+ operator MemoryView() const { return GetView(); }
+
+ /**
+ * Returns true if this does not point to a buffer owner.
+ *
+ * A null buffer is always owned, materialized, and empty.
+ */
+ [[nodiscard]] inline bool IsNull() const { return m_Buffer.IsNull(); }
+
+ /** Reset this to null. */
+ ZENCORE_API void Reset();
+
+ [[nodiscard]] inline MutableMemoryView GetMutableView() { return MutableMemoryView(GetData(), GetSize()); }
+ [[nodiscard]] inline MemoryView GetView() const { return MemoryView(GetData(), GetSize()); }
+
+ /** Make an uninitialized owned buffer of the specified size. */
+ [[nodiscard]] ZENCORE_API static UniqueBuffer Alloc(uint64_t Size);
+
+ /** Make a non-owned view of the input. */
+ [[nodiscard]] ZENCORE_API static UniqueBuffer MakeMutableView(void* DataPtr, uint64_t Size);
+
+ /**
+ * Convert this to an immutable shared buffer, leaving this null.
+ *
+ * Steals the buffer owner from the unique buffer.
+ */
+ [[nodiscard]] ZENCORE_API SharedBuffer MoveToShared();
+
+private:
+ // This may be null, for a default constructed UniqueBuffer only
+ RefPtr<IoBufferCore> m_Buffer;
+
+ friend class SharedBuffer;
+};
+
+/**
+ * Reference to a memory buffer with shared ownership
+ */
+class SharedBuffer
+{
+public:
+ SharedBuffer() = default;
+ ZENCORE_API explicit SharedBuffer(UniqueBuffer&&);
+ inline explicit SharedBuffer(IoBufferCore* Owner) : m_Buffer(Owner) {}
+ ZENCORE_API explicit SharedBuffer(IoBuffer&& Buffer) : m_Buffer(std::move(Buffer.m_Core)) {}
+ ZENCORE_API explicit SharedBuffer(const IoBuffer& Buffer) : m_Buffer(Buffer.m_Core) {}
+ ZENCORE_API explicit SharedBuffer(RefPtr<IoBufferCore>&& Owner) : m_Buffer(std::move(Owner)) {}
+
+ [[nodiscard]] const void* GetData() const
+ {
+ if (m_Buffer)
+ {
+ return m_Buffer->DataPointer();
+ }
+ return nullptr;
+ }
+
+ [[nodiscard]] size_t GetSize() const
+ {
+ if (m_Buffer)
+ {
+ return m_Buffer->DataBytes();
+ }
+ return 0;
+ }
+
+ inline void MakeImmutable()
+ {
+ ZEN_ASSERT(m_Buffer);
+ m_Buffer->SetIsImmutable(true);
+ }
+
+ /** Returns a buffer that is owned, by cloning if not owned. */
+ [[nodiscard]] ZENCORE_API SharedBuffer MakeOwned() const&;
+ [[nodiscard]] ZENCORE_API SharedBuffer MakeOwned() &&;
+
+ [[nodiscard]] bool IsOwned() const { return !m_Buffer || m_Buffer->IsOwned(); }
+ [[nodiscard]] inline bool IsNull() const { return !m_Buffer; }
+ inline void Reset() { m_Buffer = nullptr; }
+
+ [[nodiscard]] MemoryView GetView() const
+ {
+ if (m_Buffer)
+ {
+ return MemoryView(m_Buffer->DataPointer(), m_Buffer->DataBytes());
+ }
+ else
+ {
+ return MemoryView();
+ }
+ }
+ operator MemoryView() const { return GetView(); }
+
+ /** Returns true if this points to a buffer owner. */
+ [[nodiscard]] inline explicit operator bool() const { return !IsNull(); }
+
+ [[nodiscard]] inline IoBuffer AsIoBuffer() const { return IoBuffer(m_Buffer); }
+
+ SharedBuffer& operator=(UniqueBuffer&& Rhs)
+ {
+ m_Buffer = std::move(Rhs.m_Buffer);
+ return *this;
+ }
+
+ std::strong_ordering operator<=>(const SharedBuffer& Rhs) const = default;
+
+ /** Make a non-owned view of the input */
+ [[nodiscard]] inline static SharedBuffer MakeView(MemoryView View) { return MakeView(View.GetData(), View.GetSize()); }
+ /** Make a non-owning view of the memory of the contiguous container. */
+ [[nodiscard]] inline static SharedBuffer MakeView(const ContiguousRange auto& Container)
+ {
+ std::span Span = Container;
+ return MakeView(Span.data(), Span.size() * sizeof(typename decltype(Span)::element_type));
+ }
+ /** Make a non-owned view of the input */
+ [[nodiscard]] ZENCORE_API static SharedBuffer MakeView(const void* Data, uint64_t Size);
+ /** Make a non-owned view of the input */
+ [[nodiscard]] ZENCORE_API static SharedBuffer MakeView(MemoryView View, SharedBuffer OuterBuffer);
+ /** Make an owned clone of the buffer */
+ [[nodiscard]] ZENCORE_API SharedBuffer Clone();
+ /** Make an owned clone of the memory in the input view */
+ [[nodiscard]] ZENCORE_API static SharedBuffer Clone(MemoryView View);
+
+private:
+ RefPtr<IoBufferCore> m_Buffer;
+};
+
+void sharedbuffer_forcelink();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/stats.h b/src/zencore/include/zencore/stats.h
new file mode 100644
index 000000000..1a0817b99
--- /dev/null
+++ b/src/zencore/include/zencore/stats.h
@@ -0,0 +1,295 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <atomic>
+#include <random>
+
+namespace zen {
+class CbObjectWriter;
+}
+
+namespace zen::metrics {
+
+template<typename T>
+class Gauge
+{
+public:
+ Gauge() : m_Value{0} {}
+
+ T Value() const { return m_Value; }
+ void SetValue(T Value) { m_Value = Value; }
+
+private:
+ std::atomic<T> m_Value;
+};
+
+/** Stats counter
+ *
+ * A counter is modified by adding or subtracting a value from a current value.
+ * This would typically be used to track number of requests in flight, number
+ * of active jobs etc
+ *
+ */
+class Counter
+{
+public:
+ inline void SetValue(uint64_t Value) { m_count = Value; }
+ inline uint64_t Value() const { return m_count; }
+
+ inline void Increment(int64_t AddValue) { m_count.fetch_add(AddValue); }
+ inline void Decrement(int64_t SubValue) { m_count.fetch_sub(SubValue); }
+ inline void Clear() { m_count.store(0, std::memory_order_release); }
+
+private:
+ std::atomic<uint64_t> m_count{0};
+};
+
+/** Exponential Weighted Moving Average
+
+ This is very raw, to use as little state as possible. If we
+ want to use this more broadly in user code we should perhaps
+ add a more user-friendly wrapper
+ */
+
+class RawEWMA
+{
+public:
+ /// <summary>
+ /// Update EWMA with new measure
+ /// </summary>
+ /// <param name="Alpha">Smoothing factor (between 0 and 1)</param>
+ /// <param name="Interval">Elapsed time since last</param>
+ /// <param name="Count">Value</param>
+ /// <param name="IsInitialUpdate">Whether this is the first update or not</param>
+ void Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate);
+ double Rate() const;
+
+private:
+ std::atomic<double> m_Rate = 0;
+};
+
+/// <summary>
+/// Tracks rate of events over time (i.e requests/sec), using
+/// exponential moving averages
+/// </summary>
+class Meter
+{
+public:
+ Meter();
+ ~Meter();
+
+ inline uint64_t Count() const { return m_TotalCount; }
+ double Rate1(); // One-minute rate
+ double Rate5(); // Five-minute rate
+ double Rate15(); // Fifteen-minute rate
+ double MeanRate() const; // Mean rate since instantiation of this meter
+ void Mark(uint64_t Count = 1); // Register one or more events
+
+private:
+ std::atomic<uint64_t> m_TotalCount{0}; // Accumulator counting number of marks since beginning
+ std::atomic<uint64_t> m_PendingCount{0}; // Pending EWMA update accumulator
+ std::atomic<uint64_t> m_StartTick{0}; // Time this was instantiated (for mean)
+ std::atomic<uint64_t> m_LastTick{0}; // Timestamp of last EWMA tick
+ std::atomic<int64_t> m_Remainder{0}; // Tracks the "modulo" of tick time
+ bool m_IsFirstTick = true;
+ RawEWMA m_RateM1;
+ RawEWMA m_RateM5;
+ RawEWMA m_RateM15;
+
+ void TickIfNecessary();
+ void Tick();
+};
+
+/** Moment-in-time snapshot of a distribution
+ */
+class SampleSnapshot
+{
+public:
+ SampleSnapshot(std::vector<double>&& Values);
+ ~SampleSnapshot();
+
+ uint32_t Size() const { return (uint32_t)m_Values.size(); }
+ double GetQuantileValue(double Quantile);
+ double GetMedian() { return GetQuantileValue(0.5); }
+ double Get75Percentile() { return GetQuantileValue(0.75); }
+ double Get95Percentile() { return GetQuantileValue(0.95); }
+ double Get98Percentile() { return GetQuantileValue(0.98); }
+ double Get99Percentile() { return GetQuantileValue(0.99); }
+ double Get999Percentile() { return GetQuantileValue(0.999); }
+ const std::vector<double>& GetValues() const;
+
+private:
+ std::vector<double> m_Values;
+};
+
+/** Randomly selects samples from a stream. Uses Vitter's
+ Algorithm R to produce a statistically representative sample.
+
+ http://www.cs.umd.edu/~samir/498/vitter.pdf - Random Sampling with a Reservoir
+ */
+
+class UniformSample
+{
+public:
+ UniformSample(uint32_t ReservoirSize);
+ ~UniformSample();
+
+ void Clear();
+ uint32_t Size() const;
+ void Update(int64_t Value);
+ SampleSnapshot Snapshot() const;
+
+ template<Invocable<int64_t> T>
+ void IterateValues(T Callback) const
+ {
+ for (const auto& Value : m_Values)
+ {
+ Callback(Value);
+ }
+ }
+
+private:
+ std::atomic<uint64_t> m_SampleCounter{0};
+ std::vector<std::atomic<int64_t>> m_Values;
+};
+
+/** Track (probabilistic) sample distribution along with min/max
+ */
+class Histogram
+{
+public:
+ Histogram(int32_t SampleCount = 1028);
+ ~Histogram();
+
+ void Clear();
+ void Update(int64_t Value);
+ int64_t Max() const;
+ int64_t Min() const;
+ double Mean() const;
+ uint64_t Count() const;
+ SampleSnapshot Snapshot() const { return m_Sample.Snapshot(); }
+
+private:
+ UniformSample m_Sample;
+ std::atomic<int64_t> m_Min{0};
+ std::atomic<int64_t> m_Max{0};
+ std::atomic<int64_t> m_Sum{0};
+ std::atomic<int64_t> m_Count{0};
+};
+
+/** Track timing and frequency of some operation
+
+ Example usage would be to track frequency and duration of network
+ requests, or function calls.
+
+ */
+class OperationTiming
+{
+public:
+ OperationTiming(int32_t SampleCount = 514);
+ ~OperationTiming();
+
+ void Update(int64_t Duration);
+ int64_t Max() const;
+ int64_t Min() const;
+ double Mean() const;
+ uint64_t Count() const;
+ SampleSnapshot Snapshot() const { return m_Histogram.Snapshot(); }
+
+ double Rate1() { return m_Meter.Rate1(); }
+ double Rate5() { return m_Meter.Rate5(); }
+ double Rate15() { return m_Meter.Rate15(); }
+ double MeanRate() const { return m_Meter.MeanRate(); }
+
+ struct Scope
+ {
+ Scope(OperationTiming& Outer);
+ ~Scope();
+
+ void Stop();
+ void Cancel();
+
+ private:
+ OperationTiming& m_Outer;
+ uint64_t m_StartTick;
+ };
+
+private:
+ Meter m_Meter;
+ Histogram m_Histogram;
+};
+
+/** Metrics for network requests
+
+ Aggregates tracking of duration, payload sizes into a single
+ class
+
+ */
+class RequestStats
+{
+public:
+ RequestStats(int32_t SampleCount = 514);
+ ~RequestStats();
+
+ void Update(int64_t Duration, int64_t Bytes);
+ uint64_t Count() const;
+
+ // Timing
+
+ int64_t MaxDuration() const { return m_BytesHistogram.Max(); }
+ int64_t MinDuration() const { return m_BytesHistogram.Min(); }
+ double MeanDuration() const { return m_BytesHistogram.Mean(); }
+ SampleSnapshot DurationSnapshot() const { return m_RequestTimeHistogram.Snapshot(); }
+ double Rate1() { return m_RequestMeter.Rate1(); }
+ double Rate5() { return m_RequestMeter.Rate5(); }
+ double Rate15() { return m_RequestMeter.Rate15(); }
+ double MeanRate() const { return m_RequestMeter.MeanRate(); }
+
+ // Bytes
+
+ int64_t MaxBytes() const { return m_BytesHistogram.Max(); }
+ int64_t MinBytes() const { return m_BytesHistogram.Min(); }
+ double MeanBytes() const { return m_BytesHistogram.Mean(); }
+ SampleSnapshot BytesSnapshot() const { return m_BytesHistogram.Snapshot(); }
+ double ByteRate1() { return m_BytesMeter.Rate1(); }
+ double ByteRate5() { return m_BytesMeter.Rate5(); }
+ double ByteRate15() { return m_BytesMeter.Rate15(); }
+ double ByteMeanRate() const { return m_BytesMeter.MeanRate(); }
+
+ struct Scope
+ {
+ Scope(OperationTiming& Outer);
+ ~Scope();
+
+ void Cancel();
+
+ private:
+ OperationTiming& m_Outer;
+ uint64_t m_StartTick;
+ };
+
+ void EmitSnapshot(std::string_view Tag, CbObjectWriter& Cbo);
+
+private:
+ Meter m_RequestMeter;
+ Meter m_BytesMeter;
+ Histogram m_RequestTimeHistogram;
+ Histogram m_BytesHistogram;
+};
+
+void EmitSnapshot(std::string_view Tag, OperationTiming& Stat, CbObjectWriter& Cbo);
+void EmitSnapshot(std::string_view Tag, const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor);
+void EmitSnapshot(std::string_view Tag, Meter& Stat, CbObjectWriter& Cbo);
+
+void EmitSnapshot(const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor);
+
+} // namespace zen::metrics
+
+namespace zen {
+
+extern void stats_forcelink();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/stream.h b/src/zencore/include/zencore/stream.h
new file mode 100644
index 000000000..9e4996249
--- /dev/null
+++ b/src/zencore/include/zencore/stream.h
@@ -0,0 +1,90 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <zencore/memory.h>
+#include <zencore/thread.h>
+
+#include <vector>
+
+namespace zen {
+
+/**
+ * Binary stream writer
+ */
+
+class BinaryWriter
+{
+public:
+ inline BinaryWriter() = default;
+ ~BinaryWriter() = default;
+
+ inline void Write(const void* DataPtr, size_t ByteCount)
+ {
+ Write(DataPtr, ByteCount, m_Offset);
+ m_Offset += ByteCount;
+ }
+
+ inline void Write(MemoryView Memory) { Write(Memory.GetData(), Memory.GetSize()); }
+ void Write(std::initializer_list<const MemoryView> Buffers);
+
+ inline uint64_t CurrentOffset() const { return m_Offset; }
+
+ inline const uint8_t* Data() const { return m_Buffer.data(); }
+ inline const uint8_t* GetData() const { return m_Buffer.data(); }
+ inline uint64_t Size() const { return m_Buffer.size(); }
+ inline uint64_t GetSize() const { return m_Buffer.size(); }
+ void Reset();
+
+ inline MemoryView GetView() const { return MemoryView(m_Buffer.data(), m_Offset); }
+ inline MutableMemoryView GetMutableView() { return MutableMemoryView(m_Buffer.data(), m_Offset); }
+
+private:
+ std::vector<uint8_t> m_Buffer;
+ uint64_t m_Offset = 0;
+
+ void Write(const void* DataPtr, size_t ByteCount, uint64_t Offset);
+};
+
+inline MemoryView
+MakeMemoryView(const BinaryWriter& Stream)
+{
+ return MemoryView(Stream.Data(), Stream.Size());
+}
+
+/**
+ * Binary stream reader
+ */
+
+class BinaryReader
+{
+public:
+ inline BinaryReader(const void* Buffer, uint64_t Size) : m_BufferBase(reinterpret_cast<const uint8_t*>(Buffer)), m_BufferSize(Size) {}
+ inline BinaryReader(MemoryView Buffer)
+ : m_BufferBase(reinterpret_cast<const uint8_t*>(Buffer.GetData()))
+ , m_BufferSize(Buffer.GetSize())
+ {
+ }
+
+ inline void Read(void* DataPtr, size_t ByteCount)
+ {
+ memcpy(DataPtr, m_BufferBase + m_Offset, ByteCount);
+ m_Offset += ByteCount;
+ }
+
+ inline uint64_t Size() const { return m_BufferSize; }
+ inline uint64_t GetSize() const { return Size(); }
+ inline uint64_t CurrentOffset() const { return m_Offset; }
+ inline void Skip(size_t ByteCount) { m_Offset += ByteCount; };
+
+private:
+ const uint8_t* m_BufferBase;
+ uint64_t m_BufferSize;
+ uint64_t m_Offset = 0;
+};
+
+void stream_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/string.h b/src/zencore/include/zencore/string.h
new file mode 100644
index 000000000..ab111ff81
--- /dev/null
+++ b/src/zencore/include/zencore/string.h
@@ -0,0 +1,1115 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "intmath.h"
+#include "zencore.h"
+
+#include <stdint.h>
+#include <string.h>
+#include <charconv>
+#include <codecvt>
+#include <compare>
+#include <concepts>
+#include <optional>
+#include <span>
+#include <string_view>
+
+#include <type_traits>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+inline bool
+StringEquals(const char8_t* s1, const char* s2)
+{
+ return strcmp(reinterpret_cast<const char*>(s1), s2) == 0;
+}
+
+inline bool
+StringEquals(const char* s1, const char* s2)
+{
+ return strcmp(s1, s2) == 0;
+}
+
+inline size_t
+StringLength(const char* str)
+{
+ return strlen(str);
+}
+
+inline bool
+StringEquals(const wchar_t* s1, const wchar_t* s2)
+{
+ return wcscmp(s1, s2) == 0;
+}
+
+inline size_t
+StringLength(const wchar_t* str)
+{
+ return wcslen(str);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// File name helpers
+//
+
+ZENCORE_API const char* FilepathFindExtension(const std::string_view& path, const char* extensionToMatch = nullptr);
+
+//////////////////////////////////////////////////////////////////////////
+// Text formatting of numbers
+//
+
+ZENCORE_API bool ToString(std::span<char> Buffer, uint64_t Num);
+ZENCORE_API bool ToString(std::span<char> Buffer, int64_t Num);
+
+struct TextNumBase
+{
+ inline const char* c_str() const { return m_Buffer; }
+ inline operator std::string_view() const { return std::string_view(m_Buffer); }
+
+protected:
+ char m_Buffer[24];
+};
+
+struct IntNum : public TextNumBase
+{
+ inline IntNum(UnsignedIntegral auto Number) { ToString(m_Buffer, uint64_t(Number)); }
+ inline IntNum(SignedIntegral auto Number) { ToString(m_Buffer, int64_t(Number)); }
+};
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Quick-and-dirty string builder. Good enough for me, but contains traps
+// and not-quite-ideal behaviour especially when mixing character types etc
+//
+
+template<typename C>
+class StringBuilderImpl
+{
+public:
+ StringBuilderImpl() = default;
+ ZENCORE_API ~StringBuilderImpl();
+
+ StringBuilderImpl(const StringBuilderImpl&) = delete;
+ StringBuilderImpl(const StringBuilderImpl&&) = delete;
+ const StringBuilderImpl& operator=(const StringBuilderImpl&) = delete;
+ const StringBuilderImpl& operator=(const StringBuilderImpl&&) = delete;
+
+ inline size_t AddUninitialized(size_t Count)
+ {
+ EnsureCapacity(Count);
+ const size_t OldCount = Size();
+ m_CurPos += Count;
+ return OldCount;
+ }
+
+ StringBuilderImpl& Append(C OneChar)
+ {
+ EnsureCapacity(1);
+
+ *m_CurPos++ = OneChar;
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& AppendAscii(const std::string_view& String)
+ {
+ const size_t len = String.size();
+
+ EnsureCapacity(len);
+
+ for (size_t i = 0; i < len; ++i)
+ m_CurPos[i] = String[i];
+
+ m_CurPos += len;
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& AppendAscii(const std::u8string_view& String)
+ {
+ const size_t len = String.size();
+
+ EnsureCapacity(len);
+
+ for (size_t i = 0; i < len; ++i)
+ m_CurPos[i] = String[i];
+
+ m_CurPos += len;
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& AppendAscii(const char* NulTerminatedString)
+ {
+ size_t StringLen = StringLength(NulTerminatedString);
+
+ return AppendAscii({NulTerminatedString, StringLen});
+ }
+
+ inline StringBuilderImpl& Append(const char8_t* NulTerminatedString)
+ {
+ // This is super hacky and not fully functional - needs better
+ // solution
+ if constexpr (sizeof(C) == 1)
+ {
+ size_t len = StringLength((const char*)NulTerminatedString);
+
+ EnsureCapacity(len);
+
+ for (size_t i = 0; i < len; ++i)
+ m_CurPos[i] = C(NulTerminatedString[i]);
+
+ m_CurPos += len;
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED();
+ }
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& AppendAsciiRange(const char* BeginString, const char* EndString)
+ {
+ EnsureCapacity(EndString - BeginString);
+
+ while (BeginString != EndString)
+ *m_CurPos++ = *BeginString++;
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& Append(const C* NulTerminatedString)
+ {
+ size_t Len = StringLength(NulTerminatedString);
+
+ EnsureCapacity(Len);
+ memcpy(m_CurPos, NulTerminatedString, Len * sizeof(C));
+ m_CurPos += Len;
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& Append(const C* NulTerminatedString, size_t MaxChars)
+ {
+ size_t len = Min(MaxChars, StringLength(NulTerminatedString));
+
+ EnsureCapacity(len);
+ memcpy(m_CurPos, NulTerminatedString, len * sizeof(C));
+ m_CurPos += len;
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& AppendRange(const C* BeginString, const C* EndString)
+ {
+ size_t Len = EndString - BeginString;
+
+ EnsureCapacity(Len);
+ memcpy(m_CurPos, BeginString, Len * sizeof(C));
+ m_CurPos += Len;
+
+ return *this;
+ }
+
+ inline StringBuilderImpl& Append(const std::basic_string_view<C>& String)
+ {
+ return AppendRange(String.data(), String.data() + String.size());
+ }
+
+ inline StringBuilderImpl& AppendBool(bool v)
+ {
+ // This is a method instead of a << operator overload as the latter can
+ // easily get called with non-bool types like pointers. It is a very
+ // subtle behaviour that can cause bugs.
+ using namespace std::literals;
+ if (v)
+ {
+ return AppendAscii("true"sv);
+ }
+ return AppendAscii("false"sv);
+ }
+
+ inline void RemoveSuffix(uint32_t Count)
+ {
+ ZEN_ASSERT(Count <= Size());
+ m_CurPos -= Count;
+ }
+
+ inline const C* c_str() const
+ {
+ EnsureNulTerminated();
+ return m_Base;
+ }
+
+ inline C* Data()
+ {
+ EnsureNulTerminated();
+ return m_Base;
+ }
+
+ inline const C* Data() const
+ {
+ EnsureNulTerminated();
+ return m_Base;
+ }
+
+ inline size_t Size() const { return m_CurPos - m_Base; }
+ inline bool IsDynamic() const { return m_IsDynamic; }
+ inline void Reset() { m_CurPos = m_Base; }
+
+ inline StringBuilderImpl& operator<<(uint64_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+ inline StringBuilderImpl& operator<<(int64_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+ inline StringBuilderImpl& operator<<(uint32_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+ inline StringBuilderImpl& operator<<(int32_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+ inline StringBuilderImpl& operator<<(uint16_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+ inline StringBuilderImpl& operator<<(int16_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+ inline StringBuilderImpl& operator<<(uint8_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+ inline StringBuilderImpl& operator<<(int8_t n)
+ {
+ IntNum Str(n);
+ return AppendAscii(Str);
+ }
+
+ inline StringBuilderImpl& operator<<(const char* str) { return AppendAscii(str); }
+ inline StringBuilderImpl& operator<<(const std::string_view str) { return AppendAscii(str); }
+ inline StringBuilderImpl& operator<<(const std::u8string_view str) { return AppendAscii(str); }
+
+protected:
+ inline void Init(C* Base, size_t Capacity)
+ {
+ m_Base = m_CurPos = Base;
+ m_End = Base + Capacity;
+ }
+
+ inline void EnsureNulTerminated() const { *m_CurPos = '\0'; }
+
+ inline void EnsureCapacity(size_t ExtraRequired)
+ {
+ // precondition: we know the current buffer has enough capacity
+ // for the existing string including NUL terminator
+
+ if ((m_CurPos + ExtraRequired) < m_End)
+ return;
+
+ Extend(ExtraRequired);
+ }
+
+ ZENCORE_API void Extend(size_t ExtraCapacity);
+ ZENCORE_API void* AllocBuffer(size_t ByteCount);
+ ZENCORE_API void FreeBuffer(void* Buffer, size_t ByteCount);
+
+ ZENCORE_API [[noreturn]] void Fail(const char* FailReason); // note: throws exception
+
+ C* m_Base;
+ C* m_CurPos;
+ C* m_End;
+ bool m_IsDynamic = false;
+ bool m_IsExtendable = false;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+extern template class StringBuilderImpl<char>;
+
+inline StringBuilderImpl<char>&
+operator<<(StringBuilderImpl<char>& Builder, char Char)
+{
+ return Builder.Append(Char);
+}
+
+class StringBuilderBase : public StringBuilderImpl<char>
+{
+public:
+ inline StringBuilderBase(char* bufferPointer, size_t bufferCapacity) { Init(bufferPointer, bufferCapacity); }
+ inline ~StringBuilderBase() = default;
+
+ // Note that we don't need a terminator for the string_view so we avoid calling data() here
+ inline operator std::string_view() const { return std::string_view(m_Base, m_CurPos - m_Base); }
+ inline std::string_view ToView() const { return std::string_view(m_Base, m_CurPos - m_Base); }
+ inline std::string ToString() const { return std::string{Data(), Size()}; }
+
+ inline void AppendCodepoint(uint32_t cp)
+ {
+ if (cp < 0x80) // one octet
+ {
+ Append(static_cast<char8_t>(cp));
+ }
+ else if (cp < 0x800)
+ {
+ EnsureCapacity(2); // two octets
+ m_CurPos[0] = static_cast<char8_t>((cp >> 6) | 0xc0);
+ m_CurPos[1] = static_cast<char8_t>((cp & 0x3f) | 0x80);
+ m_CurPos += 2;
+ }
+ else if (cp < 0x10000)
+ {
+ EnsureCapacity(3); // three octets
+ m_CurPos[0] = static_cast<char8_t>((cp >> 12) | 0xe0);
+ m_CurPos[1] = static_cast<char8_t>(((cp >> 6) & 0x3f) | 0x80);
+ m_CurPos[2] = static_cast<char8_t>((cp & 0x3f) | 0x80);
+ m_CurPos += 3;
+ }
+ else
+ {
+ EnsureCapacity(4); // four octets
+ m_CurPos[0] = static_cast<char8_t>((cp >> 18) | 0xf0);
+ m_CurPos[1] = static_cast<char8_t>(((cp >> 12) & 0x3f) | 0x80);
+ m_CurPos[2] = static_cast<char8_t>(((cp >> 6) & 0x3f) | 0x80);
+ m_CurPos[3] = static_cast<char8_t>((cp & 0x3f) | 0x80);
+ m_CurPos += 4;
+ }
+ }
+};
+
+template<size_t N>
+class StringBuilder : public StringBuilderBase
+{
+public:
+ inline StringBuilder() : StringBuilderBase(m_StringBuffer, sizeof m_StringBuffer) {}
+ inline ~StringBuilder() = default;
+
+private:
+ char m_StringBuffer[N];
+};
+
+template<size_t N>
+class ExtendableStringBuilder : public StringBuilderBase
+{
+public:
+ inline ExtendableStringBuilder() : StringBuilderBase(m_StringBuffer, sizeof m_StringBuffer) { m_IsExtendable = true; }
+ inline ~ExtendableStringBuilder() = default;
+
+private:
+ char m_StringBuffer[N];
+};
+
+template<size_t N>
+class WriteToString : public ExtendableStringBuilder<N>
+{
+public:
+ template<typename... ArgTypes>
+ explicit WriteToString(ArgTypes&&... Args)
+ {
+ (*this << ... << std::forward<ArgTypes>(Args));
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+extern template class StringBuilderImpl<wchar_t>;
+
+class WideStringBuilderBase : public StringBuilderImpl<wchar_t>
+{
+public:
+ inline WideStringBuilderBase(wchar_t* BufferPointer, size_t BufferCapacity) { Init(BufferPointer, BufferCapacity); }
+ inline ~WideStringBuilderBase() = default;
+
+ inline operator std::wstring_view() const { return std::wstring_view{Data(), Size()}; }
+ inline std::wstring_view ToView() const { return std::wstring_view{Data(), Size()}; }
+ inline std::wstring ToString() const { return std::wstring{Data(), Size()}; }
+
+ inline StringBuilderImpl& operator<<(const std::wstring_view str) { return Append((const wchar_t*)str.data(), str.size()); }
+ inline StringBuilderImpl& operator<<(const wchar_t* str) { return Append(str); }
+ using StringBuilderImpl:: operator<<;
+};
+
+template<size_t N>
+class WideStringBuilder : public WideStringBuilderBase
+{
+public:
+ inline WideStringBuilder() : WideStringBuilderBase(m_Buffer, N) {}
+ ~WideStringBuilder() = default;
+
+private:
+ wchar_t m_Buffer[N];
+};
+
+template<size_t N>
+class ExtendableWideStringBuilder : public WideStringBuilderBase
+{
+public:
+ inline ExtendableWideStringBuilder() : WideStringBuilderBase(m_Buffer, N) { m_IsExtendable = true; }
+ ~ExtendableWideStringBuilder() = default;
+
+private:
+ wchar_t m_Buffer[N];
+};
+
+template<size_t N>
+class WriteToWideString : public ExtendableWideStringBuilder<N>
+{
+public:
+ template<typename... ArgTypes>
+ explicit WriteToWideString(ArgTypes&&... Args)
+ {
+ (*this << ... << Forward<ArgTypes>(Args));
+ }
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+void Utf8ToWide(const char8_t* str, WideStringBuilderBase& out);
+void Utf8ToWide(const std::u8string_view& wstr, WideStringBuilderBase& out);
+void Utf8ToWide(const std::string_view& wstr, WideStringBuilderBase& out);
+std::wstring Utf8ToWide(const std::string_view& wstr);
+
+void WideToUtf8(const wchar_t* wstr, StringBuilderBase& out);
+std::string WideToUtf8(const wchar_t* wstr);
+void WideToUtf8(const std::wstring_view& wstr, StringBuilderBase& out);
+std::string WideToUtf8(const std::wstring_view Wstr);
+
+inline uint8_t
+Char2Nibble(char c)
+{
+ if (c >= '0' && c <= '9')
+ {
+ return uint8_t(c - '0');
+ }
+ if (c >= 'a' && c <= 'f')
+ {
+ return uint8_t(c - 'a' + 10);
+ }
+ if (c >= 'A' && c <= 'F')
+ {
+ return uint8_t(c - 'A' + 10);
+ }
+ return uint8_t(0xff);
+};
+
+static constexpr const char HexChars[] = "0123456789abcdef";
+
+/// <summary>
+/// Parse hex string into a byte buffer
+/// </summary>
+/// <param name="string">Input string</param>
+/// <param name="characterCount">Number of characters in string</param>
+/// <param name="outPtr">Pointer to output buffer</param>
+/// <returns>true if the input consisted of all valid hexadecimal characters</returns>
+
+inline bool
+ParseHexBytes(const char* InputString, size_t CharacterCount, uint8_t* OutPtr)
+{
+ ZEN_ASSERT((CharacterCount & 1) == 0);
+
+ uint8_t allBits = 0;
+
+ while (CharacterCount)
+ {
+ uint8_t n0 = Char2Nibble(InputString[0]);
+ uint8_t n1 = Char2Nibble(InputString[1]);
+
+ allBits |= n0 | n1;
+
+ *OutPtr = (n0 << 4) | n1;
+
+ OutPtr += 1;
+ InputString += 2;
+ CharacterCount -= 2;
+ }
+
+ return (allBits & 0x80) == 0;
+}
+
+inline void
+ToHexBytes(const uint8_t* InputData, size_t ByteCount, char* OutString)
+{
+ while (ByteCount--)
+ {
+ uint8_t byte = *InputData++;
+
+ *OutString++ = HexChars[byte >> 4];
+ *OutString++ = HexChars[byte & 15];
+ }
+}
+
+inline bool
+ParseHexNumber(const char* InputString, size_t CharacterCount, uint8_t* OutPtr)
+{
+ ZEN_ASSERT((CharacterCount & 1) == 0);
+
+ uint8_t allBits = 0;
+
+ InputString += CharacterCount;
+ while (CharacterCount)
+ {
+ InputString -= 2;
+ uint8_t n0 = Char2Nibble(InputString[0]);
+ uint8_t n1 = Char2Nibble(InputString[1]);
+
+ allBits |= n0 | n1;
+
+ *OutPtr = (n0 << 4) | n1;
+
+ OutPtr += 1;
+ CharacterCount -= 2;
+ }
+
+ return (allBits & 0x80) == 0;
+}
+
+inline void
+ToHexNumber(const uint8_t* InputData, size_t ByteCount, char* OutString)
+{
+ InputData += ByteCount;
+ while (ByteCount--)
+ {
+ uint8_t byte = *(--InputData);
+
+ *OutString++ = HexChars[byte >> 4];
+ *OutString++ = HexChars[byte & 15];
+ }
+}
+
+/// <summary>
+/// Generates a hex number from a pointer to an integer type, this formats the number in the correct order for a hexadecimal number
+/// </summary>
+/// <param name="Value">Integer value type</param>
+/// <param name="outString">Output buffer where resulting string is written</param>
+void
+ToHexNumber(UnsignedIntegral auto Value, char* OutString)
+{
+ ToHexNumber((const uint8_t*)&Value, sizeof(Value), OutString);
+ OutString[sizeof(Value) * 2] = 0;
+}
+
+/// <summary>
+/// Parse hex number string into a value, this formats the number in the correct order for a hexadecimal number
+/// </summary>
+/// <param name="string">Input string</param>
+/// <param name="characterCount">Number of characters in string</param>
+/// <param name="OutValue">Pointer to output value</param>
+/// <returns>true if the input consisted of all valid hexadecimal characters</returns>
+bool
+ParseHexNumber(const std::string HexString, UnsignedIntegral auto& OutValue)
+{
+ return ParseHexNumber(HexString.c_str(), sizeof(OutValue) * 2, (uint8_t*)&OutValue);
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Format numbers for humans
+//
+
+ZENCORE_API size_t NiceNumToBuffer(uint64_t Num, std::span<char> Buffer);
+ZENCORE_API size_t NiceBytesToBuffer(uint64_t Num, std::span<char> Buffer);
+ZENCORE_API size_t NiceByteRateToBuffer(uint64_t Num, uint64_t ms, std::span<char> Buffer);
+ZENCORE_API size_t NiceLatencyNsToBuffer(uint64_t NanoSeconds, std::span<char> Buffer);
+ZENCORE_API size_t NiceTimeSpanMsToBuffer(uint64_t Milliseconds, std::span<char> Buffer);
+
+struct NiceBase
+{
+ inline const char* c_str() const { return m_Buffer; }
+ inline operator std::string_view() const { return std::string_view(m_Buffer); }
+
+protected:
+ char m_Buffer[16];
+};
+
+struct NiceNum : public NiceBase
+{
+ inline NiceNum(uint64_t Num) { NiceNumToBuffer(Num, m_Buffer); }
+};
+
+struct NiceBytes : public NiceBase
+{
+ inline NiceBytes(uint64_t Num) { NiceBytesToBuffer(Num, m_Buffer); }
+};
+
+struct NiceByteRate : public NiceBase
+{
+ inline NiceByteRate(uint64_t Bytes, uint64_t TimeMilliseconds) { NiceByteRateToBuffer(Bytes, TimeMilliseconds, m_Buffer); }
+};
+
+struct NiceLatencyNs : public NiceBase
+{
+ inline NiceLatencyNs(uint64_t Milliseconds) { NiceLatencyNsToBuffer(Milliseconds, m_Buffer); }
+};
+
+struct NiceTimeSpanMs : public NiceBase
+{
+ inline NiceTimeSpanMs(uint64_t Milliseconds) { NiceTimeSpanMsToBuffer(Milliseconds, m_Buffer); }
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+inline std::string
+NiceRate(uint64_t Num, uint32_t DurationMilliseconds, const char* Unit = "B")
+{
+ char Buffer[32];
+
+ if (DurationMilliseconds)
+ {
+ // Leave a little of 'Buffer' for the "Unit/s" suffix
+ std::span<char> BufferSpan(Buffer, sizeof(Buffer) - 8);
+ NiceNumToBuffer(Num * 1000 / DurationMilliseconds, BufferSpan);
+ }
+ else
+ {
+ strcpy(Buffer, "0");
+ }
+
+ strncat(Buffer, Unit, 4);
+ strcat(Buffer, "/s");
+
+ return Buffer;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+template<Integral T>
+std::optional<T>
+ParseInt(const std::string_view& Input)
+{
+ T Out = 0;
+ const std::from_chars_result Result = std::from_chars(Input.data(), Input.data() + Input.size(), Out);
+ if (Result.ec == std::errc::invalid_argument || Result.ec == std::errc::result_out_of_range)
+ {
+ return std::nullopt;
+ }
+ return Out;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+constexpr uint32_t
+HashStringDjb2(const std::string_view& InString)
+{
+ uint32_t HashValue = 5381;
+
+ for (int CurChar : InString)
+ {
+ HashValue = HashValue * 33 + CurChar;
+ }
+
+ return HashValue;
+}
+
+constexpr uint32_t
+HashStringAsLowerDjb2(const std::string_view& InString)
+{
+ uint32_t HashValue = 5381;
+
+ for (uint8_t CurChar : InString)
+ {
+ CurChar -= ((CurChar - 'A') <= ('Z' - 'A')) * ('A' - 'a'); // this should be compiled into branchless logic
+ HashValue = HashValue * 33 + CurChar;
+ }
+
+ return HashValue;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+inline std::string
+ToLower(const std::string_view& InString)
+{
+ std::string Out(InString);
+
+ for (char& CurChar : Out)
+ {
+ CurChar -= (uint8_t(CurChar - 'A') <= ('Z' - 'A')) * ('A' - 'a'); // this should be compiled into branchless logic
+ }
+
+ return Out;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+template<typename Fn>
+uint32_t
+ForEachStrTok(const std::string_view& Str, char Delim, Fn&& Func)
+{
+ const char* It = Str.data();
+ const char* End = It + Str.length();
+ uint32_t Count = 0;
+
+ while (It != End)
+ {
+ if (*It == Delim)
+ {
+ It++;
+ continue;
+ }
+
+ std::string_view Remaining{It, size_t(ptrdiff_t(End - It))};
+ size_t Idx = Remaining.find(Delim, 0);
+
+ if (Idx == std::string_view::npos)
+ {
+ Idx = Remaining.size();
+ }
+
+ Count++;
+ std::string_view Token{It, Idx};
+ if (!Func(Token))
+ {
+ break;
+ }
+
+ It = It + Idx;
+ }
+
+ return Count;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+inline int32_t
+StrCaseCompare(const char* Lhs, const char* Rhs, int64_t Length = -1)
+{
+ // A helper for cross-platform case-insensitive string comparison.
+#if ZEN_PLATFORM_WINDOWS
+ return (Length < 0) ? _stricmp(Lhs, Rhs) : _strnicmp(Lhs, Rhs, size_t(Length));
+#else
+ return (Length < 0) ? strcasecmp(Lhs, Rhs) : strncasecmp(Lhs, Rhs, size_t(Length));
+#endif
+}
+
+/**
+ * @brief
+ * Helper function to implement case sensitive spaceship operator for strings.
+ * MacOS clang version we use does not implement <=> for std::string
+ * @param Lhs string
+ * @param Rhs string
+ * @return std::strong_ordering indicating relationship between Lhs and Rhs
+ */
+inline auto
+caseSensitiveCompareStrings(const std::string& Lhs, const std::string& Rhs)
+{
+ int r = Lhs.compare(Rhs);
+ return r == 0 ? std::strong_ordering::equal : r < 0 ? std::strong_ordering::less : std::strong_ordering::greater;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+/**
+ * ASCII character bitset useful for fast and readable parsing
+ *
+ * Entirely constexpr. Works with both wide and narrow strings.
+ *
+ * Example use cases:
+ *
+ * constexpr AsciiSet WhitespaceCharacters(" \v\f\t\r\n");
+ * bool bIsWhitespace = WhitespaceCharacters.Test(MyChar);
+ * const char* HelloWorld = AsciiSet::Skip(" \t\tHello world!", WhitespaceCharacters);
+ *
+ * constexpr AsciiSet XmlEscapeChars("&<>\"'");
+ * check(AsciiSet::HasNone(EscapedXmlString, XmlEscapeChars));
+ *
+ * constexpr AsciiSet Delimiters(".:;");
+ * const TCHAR* DelimiterOrEnd = AsciiSet::FindFirstOrEnd(PrefixedName, Delimiters);
+ * FString Prefix(PrefixedName, DelimiterOrEnd - PrefixedName);
+ *
+ * constexpr AsciiSet Slashes("/\\");
+ * const TCHAR* SlashOrEnd = AsciiSet::FindLastOrEnd(PathName, Slashes);
+ * const TCHAR* FileName = *SlashOrEnd ? SlashOrEnd + 1 : PathName;
+ */
+class AsciiSet
+{
+public:
+ template<typename CharType, int N>
+ constexpr AsciiSet(const CharType (&Chars)[N]) : AsciiSet(StringToBitset(Chars))
+ {
+ }
+
+ /** Returns true if a character is part of the set */
+ template<typename CharType>
+ constexpr inline bool Contains(CharType Char) const
+ {
+ using UnsignedCharType = typename std::make_unsigned<CharType>::type;
+
+ return !!TestImpl((UnsignedCharType)Char);
+ }
+
+ /** Returns non-zero if a character is part of the set. Prefer Contains() to avoid VS2019 conversion warnings. */
+ template<typename CharType>
+ constexpr inline uint64_t Test(CharType Char) const
+ {
+ using UnsignedCharType = typename std::make_unsigned<CharType>::type;
+
+ return TestImpl((UnsignedCharType)Char);
+ }
+
+ /** Create new set with specified character in it */
+ constexpr inline AsciiSet operator+(char Char) const
+ {
+ using UnsignedCharType = typename std::make_unsigned<char>::type;
+
+ InitData Bitset = {LoMask, HiMask};
+ SetImpl(Bitset, (UnsignedCharType)Char);
+ return AsciiSet(Bitset);
+ }
+
+ /** Create new set containing inverse set of characters - likely including null-terminator */
+ constexpr inline AsciiSet operator~() const { return AsciiSet(~LoMask, ~HiMask); }
+
+ ////////// Algorithms for C strings //////////
+
+ /** Find first character of string inside set or end pointer. Never returns null. */
+ template<class CharType>
+ static constexpr const CharType* FindFirstOrEnd(const CharType* Str, AsciiSet Set)
+ {
+ for (AsciiSet SetOrNil(Set.LoMask | NilMask, Set.HiMask); !SetOrNil.Test(*Str); ++Str)
+ ;
+
+ return Str;
+ }
+
+ /** Find last character of string inside set or end pointer. Never returns null. */
+ template<class CharType>
+ static constexpr const CharType* FindLastOrEnd(const CharType* Str, AsciiSet Set)
+ {
+ const CharType* Last = FindFirstOrEnd(Str, Set);
+
+ for (const CharType* It = Last; *It; It = FindFirstOrEnd(It + 1, Set))
+ {
+ Last = It;
+ }
+
+ return Last;
+ }
+
+ /** Find first character of string outside of set. Never returns null. */
+ template<typename CharType>
+ static constexpr const CharType* Skip(const CharType* Str, AsciiSet Set)
+ {
+ while (Set.Contains(*Str))
+ {
+ ++Str;
+ }
+
+ return Str;
+ }
+
+ /** Test if string contains any character in set */
+ template<typename CharType>
+ static constexpr bool HasAny(const CharType* Str, AsciiSet Set)
+ {
+ return *FindFirstOrEnd(Str, Set) != '\0';
+ }
+
+ /** Test if string contains no character in set */
+ template<typename CharType>
+ static constexpr bool HasNone(const CharType* Str, AsciiSet Set)
+ {
+ return *FindFirstOrEnd(Str, Set) == '\0';
+ }
+
+ /** Test if string contains any character outside of set */
+ template<typename CharType>
+ static constexpr bool HasOnly(const CharType* Str, AsciiSet Set)
+ {
+ return *Skip(Str, Set) == '\0';
+ }
+
+ ////////// Algorithms for string types like std::string_view and std::string //////////
+
+ /** Get initial substring with all characters in set */
+ template<class StringType>
+ static constexpr StringType FindPrefixWith(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Forward, EInclude::Members, EKeep::Head>(Str, Set);
+ }
+
+ /** Get initial substring with no characters in set */
+ template<class StringType>
+ static constexpr StringType FindPrefixWithout(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Forward, EInclude::NonMembers, EKeep::Head>(Str, Set);
+ }
+
+ /** Trim initial characters in set */
+ template<class StringType>
+ static constexpr StringType TrimPrefixWith(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Forward, EInclude::Members, EKeep::Tail>(Str, Set);
+ }
+
+ /** Trim initial characters not in set */
+ template<class StringType>
+ static constexpr StringType TrimPrefixWithout(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Forward, EInclude::NonMembers, EKeep::Tail>(Str, Set);
+ }
+
+ /** Get trailing substring with all characters in set */
+ template<class StringType>
+ static constexpr StringType FindSuffixWith(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Reverse, EInclude::Members, EKeep::Tail>(Str, Set);
+ }
+
+ /** Get trailing substring with no characters in set */
+ template<class StringType>
+ static constexpr StringType FindSuffixWithout(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Reverse, EInclude::NonMembers, EKeep::Tail>(Str, Set);
+ }
+
+ /** Trim trailing characters in set */
+ template<class StringType>
+ static constexpr StringType TrimSuffixWith(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Reverse, EInclude::Members, EKeep::Head>(Str, Set);
+ }
+
+ /** Trim trailing characters not in set */
+ template<class StringType>
+ static constexpr StringType TrimSuffixWithout(const StringType& Str, AsciiSet Set)
+ {
+ return Scan<EDir::Reverse, EInclude::NonMembers, EKeep::Head>(Str, Set);
+ }
+
+ /** Test if string contains any character in set */
+ template<class StringType>
+ static constexpr bool HasAny(const StringType& Str, AsciiSet Set)
+ {
+ return !HasNone(Str, Set);
+ }
+
+ /** Test if string contains no character in set */
+ template<class StringType>
+ static constexpr bool HasNone(const StringType& Str, AsciiSet Set)
+ {
+ uint64_t Match = 0;
+ for (auto Char : Str)
+ {
+ Match |= Set.Test(Char);
+ }
+ return Match == 0;
+ }
+
+ /** Test if string contains any character outside of set */
+ template<class StringType>
+ static constexpr bool HasOnly(const StringType& Str, AsciiSet Set)
+ {
+ auto End = Str.data() + Str.size();
+ return FindFirst<EInclude::Members>(Set, Str.data(), End) == End;
+ }
+
+private:
+ enum class EDir
+ {
+ Forward,
+ Reverse
+ };
+ enum class EInclude
+ {
+ Members,
+ NonMembers
+ };
+ enum class EKeep
+ {
+ Head,
+ Tail
+ };
+
+ template<EInclude Include, typename CharType>
+ static constexpr const CharType* FindFirst(AsciiSet Set, const CharType* It, const CharType* End)
+ {
+ for (; It != End && (Include == EInclude::Members) == !!Set.Test(*It); ++It)
+ ;
+ return It;
+ }
+
+ template<EInclude Include, typename CharType>
+ static constexpr const CharType* FindLast(AsciiSet Set, const CharType* It, const CharType* End)
+ {
+ for (; It != End && (Include == EInclude::Members) == !!Set.Test(*It); --It)
+ ;
+ return It;
+ }
+
+ template<EDir Dir, EInclude Include, EKeep Keep, class StringType>
+ static constexpr StringType Scan(const StringType& Str, AsciiSet Set)
+ {
+ auto Begin = Str.data();
+ auto End = Begin + Str.size();
+ auto It = Dir == EDir::Forward ? FindFirst<Include>(Set, Begin, End) : FindLast<Include>(Set, End - 1, Begin - 1) + 1;
+
+ return Keep == EKeep::Head ? StringType(Begin, static_cast<int32_t>(It - Begin)) : StringType(It, static_cast<int32_t>(End - It));
+ }
+
+ // Work-around for constexpr limitations
+ struct InitData
+ {
+ uint64_t Lo, Hi;
+ };
+ static constexpr uint64_t NilMask = uint64_t(1) << '\0';
+
+ static constexpr inline void SetImpl(InitData& Bitset, uint32_t Char)
+ {
+ uint64_t IsLo = uint64_t(0) - (Char >> 6 == 0);
+ uint64_t IsHi = uint64_t(0) - (Char >> 6 == 1);
+ uint64_t Bit = uint64_t(1) << uint8_t(Char & 0x3f);
+
+ Bitset.Lo |= Bit & IsLo;
+ Bitset.Hi |= Bit & IsHi;
+ }
+
+ constexpr inline uint64_t TestImpl(uint32_t Char) const
+ {
+ uint64_t IsLo = uint64_t(0) - (Char >> 6 == 0);
+ uint64_t IsHi = uint64_t(0) - (Char >> 6 == 1);
+ uint64_t Bit = uint64_t(1) << (Char & 0x3f);
+
+ return (Bit & IsLo & LoMask) | (Bit & IsHi & HiMask);
+ }
+
+ template<typename CharType, int N>
+ static constexpr InitData StringToBitset(const CharType (&Chars)[N])
+ {
+ using UnsignedCharType = typename std::make_unsigned<CharType>::type;
+
+ InitData Bitset = {0, 0};
+ for (int I = 0; I < N - 1; ++I)
+ {
+ SetImpl(Bitset, UnsignedCharType(Chars[I]));
+ }
+
+ return Bitset;
+ }
+
+ constexpr AsciiSet(InitData Bitset) : LoMask(Bitset.Lo), HiMask(Bitset.Hi) {}
+
+ constexpr AsciiSet(uint64_t Lo, uint64_t Hi) : LoMask(Lo), HiMask(Hi) {}
+
+ uint64_t LoMask, HiMask;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+void string_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/testing.h b/src/zencore/include/zencore/testing.h
new file mode 100644
index 000000000..a00ee3166
--- /dev/null
+++ b/src/zencore/include/zencore/testing.h
@@ -0,0 +1,67 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <memory>
+
+#ifndef ZEN_TEST_WITH_RUNNER
+# define ZEN_TEST_WITH_RUNNER 0
+#endif
+
+#if ZEN_TEST_WITH_RUNNER
+# define DOCTEST_CONFIG_IMPLEMENT
+#endif
+
+#if ZEN_WITH_TESTS
+# include <doctest/doctest.h>
+inline auto
+Approx(auto Value)
+{
+ return doctest::Approx(Value);
+}
+#endif
+
+/**
+ * Test runner helper
+ *
+ * This acts as a thin layer between the test app and the test
+ * framework, which is used to customize configuration logic
+ * and to set up logging.
+ *
+ * If you don't want to implement custom setup then the
+ * ZEN_RUN_TESTS macro can be used instead.
+ */
+
+#if ZEN_WITH_TESTS
+namespace zen::testing {
+
+class TestRunner
+{
+public:
+ TestRunner();
+ ~TestRunner();
+
+ int ApplyCommandLine(int argc, char const* const* argv);
+ int Run();
+
+private:
+ struct Impl;
+
+ std::unique_ptr<Impl> m_Impl;
+};
+
+# define ZEN_RUN_TESTS(argC, argV) \
+ [&] { \
+ zen::testing::TestRunner Runner; \
+ Runner.ApplyCommandLine(argC, argV); \
+ return Runner.Run(); \
+ }()
+
+} // namespace zen::testing
+#endif
+
+#if ZEN_TEST_WITH_RUNNER
+# undef DOCTEST_CONFIG_IMPLEMENT
+#endif
diff --git a/src/zencore/include/zencore/testutils.h b/src/zencore/include/zencore/testutils.h
new file mode 100644
index 000000000..04648c6de
--- /dev/null
+++ b/src/zencore/include/zencore/testutils.h
@@ -0,0 +1,32 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <filesystem>
+
+namespace zen {
+
+std::filesystem::path CreateTemporaryDirectory();
+
+class ScopedTemporaryDirectory
+{
+public:
+ explicit ScopedTemporaryDirectory(std::filesystem::path Directory);
+ ScopedTemporaryDirectory();
+ ~ScopedTemporaryDirectory();
+
+ std::filesystem::path& Path() { return m_RootPath; }
+
+private:
+ std::filesystem::path m_RootPath;
+};
+
+struct ScopedCurrentDirectoryChange
+{
+ std::filesystem::path OldPath{std::filesystem::current_path()};
+
+ ScopedCurrentDirectoryChange() { std::filesystem::current_path(CreateTemporaryDirectory()); }
+ ~ScopedCurrentDirectoryChange() { std::filesystem::current_path(OldPath); }
+};
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/thread.h b/src/zencore/include/zencore/thread.h
new file mode 100644
index 000000000..a9c96d422
--- /dev/null
+++ b/src/zencore/include/zencore/thread.h
@@ -0,0 +1,273 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <atomic>
+#include <filesystem>
+#include <shared_mutex>
+#include <string_view>
+#include <vector>
+
+namespace zen {
+
+void SetCurrentThreadName(std::string_view ThreadName);
+
+/**
+ * Reader-writer lock
+ *
+ * - A single thread may hold an exclusive lock at any given moment
+ *
+ * - Multiple threads may hold shared locks, but only if no thread has
+ * acquired an exclusive lock
+ */
+class RwLock
+{
+public:
+ ZENCORE_API void AcquireShared();
+ ZENCORE_API void ReleaseShared();
+
+ ZENCORE_API void AcquireExclusive();
+ ZENCORE_API void ReleaseExclusive();
+
+ struct SharedLockScope
+ {
+ SharedLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireShared(); }
+ ~SharedLockScope() { ReleaseNow(); }
+
+ void ReleaseNow()
+ {
+ if (m_Lock)
+ {
+ m_Lock->ReleaseShared();
+ m_Lock = nullptr;
+ }
+ }
+
+ private:
+ RwLock* m_Lock;
+ };
+
+ struct ExclusiveLockScope
+ {
+ ExclusiveLockScope(RwLock& Lock) : m_Lock(&Lock) { Lock.AcquireExclusive(); }
+ ~ExclusiveLockScope() { ReleaseNow(); }
+
+ void ReleaseNow()
+ {
+ if (m_Lock)
+ {
+ m_Lock->ReleaseExclusive();
+ m_Lock = nullptr;
+ }
+ }
+
+ private:
+ RwLock* m_Lock;
+ };
+
+private:
+ std::shared_mutex m_Mutex;
+};
+
+/** Basic abstraction of a simple event synchronization mechanism (aka 'binary semaphore')
+ */
+class Event
+{
+public:
+ ZENCORE_API Event();
+ ZENCORE_API ~Event();
+
+ Event(Event&& Rhs) noexcept : m_EventHandle(Rhs.m_EventHandle) { Rhs.m_EventHandle = nullptr; }
+
+ Event(const Event& Rhs) = delete;
+ Event& operator=(const Event& Rhs) = delete;
+
+ inline Event& operator=(Event&& Rhs) noexcept
+ {
+ std::swap(m_EventHandle, Rhs.m_EventHandle);
+ return *this;
+ }
+
+ ZENCORE_API void Set();
+ ZENCORE_API void Reset();
+ ZENCORE_API bool Wait(int TimeoutMs = -1);
+ ZENCORE_API void Close();
+
+protected:
+ explicit Event(void* EventHandle) : m_EventHandle(EventHandle) {}
+
+ void* m_EventHandle = nullptr;
+};
+
+/** Basic abstraction of an IPC mechanism (aka 'binary semaphore')
+ */
+class NamedEvent
+{
+public:
+ NamedEvent() = default;
+ ZENCORE_API explicit NamedEvent(std::string_view EventName);
+ ZENCORE_API ~NamedEvent();
+ ZENCORE_API void Close();
+ ZENCORE_API void Set();
+ ZENCORE_API bool Wait(int TimeoutMs = -1);
+
+ NamedEvent(NamedEvent&& Rhs) noexcept : m_EventHandle(Rhs.m_EventHandle) { Rhs.m_EventHandle = nullptr; }
+
+ inline NamedEvent& operator=(NamedEvent&& Rhs) noexcept
+ {
+ std::swap(m_EventHandle, Rhs.m_EventHandle);
+ return *this;
+ }
+
+protected:
+ void* m_EventHandle = nullptr;
+
+private:
+ NamedEvent(const NamedEvent& Rhs) = delete;
+ NamedEvent& operator=(const NamedEvent& Rhs) = delete;
+};
+
+/** Basic abstraction of a named (system wide) mutex primitive
+ */
+class NamedMutex
+{
+public:
+ ~NamedMutex();
+
+ ZENCORE_API [[nodiscard]] bool Create(std::string_view MutexName);
+
+ ZENCORE_API static bool Exists(std::string_view MutexName);
+
+private:
+ void* m_MutexHandle = nullptr;
+};
+
+/**
+ * Downward counter of type std::ptrdiff_t which can be used to synchronize threads
+ */
+class Latch
+{
+public:
+ Latch(std::ptrdiff_t Count) : Counter(Count) {}
+
+ void CountDown()
+ {
+ std::ptrdiff_t Old = Counter.fetch_sub(1);
+ if (Old == 1)
+ {
+ Complete.Set();
+ }
+ }
+
+ std::ptrdiff_t Remaining() const { return Counter.load(); }
+
+ // If you want to add dynamic count, make sure to set the initial counter to 1
+ // and then do a CountDown() just before wait to not trigger the event causing
+ // false positive completion results.
+ void AddCount(std::ptrdiff_t Count)
+ {
+ std::atomic_ptrdiff_t Old = Counter.fetch_add(Count);
+ ZEN_ASSERT_SLOW(Old > 0);
+ }
+
+ bool Wait(int TimeoutMs = -1)
+ {
+ std::ptrdiff_t Old = Counter.load();
+ if (Old == 0)
+ {
+ return true;
+ }
+ return Complete.Wait(TimeoutMs);
+ }
+
+private:
+ std::atomic_ptrdiff_t Counter;
+ Event Complete;
+};
+
+/** Basic process abstraction
+ */
+class ProcessHandle
+{
+public:
+ ZENCORE_API ProcessHandle();
+
+ ProcessHandle(const ProcessHandle&) = delete;
+ ProcessHandle& operator=(const ProcessHandle&) = delete;
+
+ ZENCORE_API ~ProcessHandle();
+
+ ZENCORE_API void Initialize(int Pid);
+ ZENCORE_API void Initialize(void* ProcessHandle); /// Initialize with an existing handle - takes ownership of the handle
+ ZENCORE_API [[nodiscard]] bool IsRunning() const;
+ ZENCORE_API [[nodiscard]] bool IsValid() const;
+ ZENCORE_API bool Wait(int TimeoutMs = -1);
+ ZENCORE_API void Terminate(int ExitCode);
+ ZENCORE_API void Reset();
+ [[nodiscard]] inline int Pid() const { return m_Pid; }
+
+private:
+ void* m_ProcessHandle = nullptr;
+ int m_Pid = 0;
+};
+
+/** Basic process creation
+ */
+struct CreateProcOptions
+{
+ enum
+ {
+ Flag_NewConsole = 1 << 0,
+ Flag_Elevated = 1 << 1,
+ Flag_Unelevated = 1 << 2,
+ };
+
+ const std::filesystem::path* WorkingDirectory = nullptr;
+ uint32_t Flags = 0;
+};
+
+#if ZEN_PLATFORM_WINDOWS
+using CreateProcResult = void*; // handle to the process
+#else
+using CreateProcResult = int32_t; // pid
+#endif
+
+ZENCORE_API CreateProcResult CreateProc(const std::filesystem::path& Executable,
+ std::string_view CommandLine, // should also include arg[0] (executable name)
+ const CreateProcOptions& Options = {});
+
+/** Process monitor - monitors a list of running processes via polling
+
+ Intended to be used to monitor a set of "sponsor" processes, where
+ we need to determine when none of them remain alive
+
+ */
+
+class ProcessMonitor
+{
+public:
+ ProcessMonitor();
+ ~ProcessMonitor();
+
+ ZENCORE_API bool IsRunning();
+ ZENCORE_API void AddPid(int Pid);
+ ZENCORE_API bool IsActive() const;
+
+private:
+ using HandleType = void*;
+
+ mutable RwLock m_Lock;
+ std::vector<HandleType> m_ProcessHandles;
+};
+
+ZENCORE_API bool IsProcessRunning(int pid);
+ZENCORE_API int GetCurrentProcessId();
+ZENCORE_API int GetCurrentThreadId();
+
+ZENCORE_API void Sleep(int ms);
+
+void thread_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/timer.h b/src/zencore/include/zencore/timer.h
new file mode 100644
index 000000000..e4ddc3505
--- /dev/null
+++ b/src/zencore/include/zencore/timer.h
@@ -0,0 +1,58 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#if ZEN_COMPILER_MSC
+# include <intrin.h>
+#elif ZEN_ARCH_X64
+# include <x86intrin.h>
+#endif
+
+#include <stdint.h>
+
+namespace zen {
+
+// High frequency timers
+
+ZENCORE_API uint64_t GetHifreqTimerValue();
+ZENCORE_API uint64_t GetHifreqTimerFrequency();
+ZENCORE_API double GetHifreqTimerToSeconds();
+ZENCORE_API uint64_t GetHifreqTimerFrequencySafe(); // May be used during static init
+
+class Stopwatch
+{
+public:
+ inline Stopwatch() : m_StartValue(GetHifreqTimerValue()) {}
+
+ inline uint64_t GetElapsedTimeMs() const { return (GetHifreqTimerValue() - m_StartValue) * 1'000 / GetHifreqTimerFrequency(); }
+ inline uint64_t GetElapsedTimeUs() const { return (GetHifreqTimerValue() - m_StartValue) * 1'000'000 / GetHifreqTimerFrequency(); }
+ inline uint64_t GetElapsedTicks() const { return GetHifreqTimerValue() - m_StartValue; }
+ inline void Reset() { m_StartValue = GetHifreqTimerValue(); }
+
+ static inline uint64_t GetElapsedTimeMs(uint64_t Ticks) { return Ticks * 1'000 / GetHifreqTimerFrequency(); }
+ static inline uint64_t GetElapsedTimeUs(uint64_t Ticks) { return Ticks * 1'000'000 / GetHifreqTimerFrequency(); }
+
+private:
+ uint64_t m_StartValue;
+};
+
+// Low frequency timers
+
+namespace detail {
+ extern ZENCORE_API uint64_t g_LofreqTimerValue;
+} // namespace detail
+
+inline uint64_t
+GetLofreqTimerValue()
+{
+ return detail::g_LofreqTimerValue;
+}
+
+ZENCORE_API void UpdateLofreqTimerValue();
+ZENCORE_API uint64_t GetLofreqTimerFrequency();
+
+void timer_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/trace.h b/src/zencore/include/zencore/trace.h
new file mode 100644
index 000000000..0af490f23
--- /dev/null
+++ b/src/zencore/include/zencore/trace.h
@@ -0,0 +1,36 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+/* clang-format off */
+
+#include <zencore/zencore.h>
+
+#if ZEN_WITH_TRACE
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#if !defined(TRACE_IMPLEMENT)
+# define TRACE_IMPLEMENT 0
+#endif
+#include <trace.h>
+#undef TRACE_IMPLEMENT
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#define ZEN_TRACE_CPU(x) TRACE_CPU_SCOPE(x)
+
+enum class TraceType
+{
+ File,
+ Network,
+ None
+};
+
+void TraceInit(const char* HostOrPath, TraceType Type);
+
+#else
+
+#define ZEN_TRACE_CPU(x)
+
+#endif // ZEN_WITH_TRACE
+
+/* clang-format on */
diff --git a/src/zencore/include/zencore/uid.h b/src/zencore/include/zencore/uid.h
new file mode 100644
index 000000000..9659f5893
--- /dev/null
+++ b/src/zencore/include/zencore/uid.h
@@ -0,0 +1,87 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+#include <compare>
+
+namespace zen {
+
+class StringBuilderBase;
+
+/** Object identifier
+
+ Can be used as a GUID essentially, but is more compact (12 bytes) and as such
+ is more susceptible to collisions than a 16-byte GUID but also I don't expect
+ the population to be large so in practice the risk should be minimal due to
+ how the identifiers work.
+
+ Similar in spirit to MongoDB ObjectId
+
+ When serialized, object identifiers generated in a given session in sequence
+ will sort in chronological order since the timestamp is in the MSB in big
+ endian format. This makes it suitable as a database key since most indexing
+ structures work better when keys are inserted in lexicographically
+ increasing order.
+
+ The current layout is basically:
+
+ |----------------|----------------|----------------|
+ | timestamp | serial # | run id |
+ |----------------|----------------|----------------|
+ MSB LSB
+
+ - Timestamp is a unsigned 32-bit value (seconds since 00:00:00 Jan 1 2021)
+ - Serial # is another unsigned 32-bit value which is assigned a (strong)
+ random number at initialization time which is incremented when a new Oid
+ is generated
+ - The run id is generated from a strong random number generator
+ at initialization time and stays fixed for the duration of the program
+
+ Timestamp and serial are stored in memory in such a way that they can be
+ ordered lexicographically. I.e they are in big-endian byte order.
+
+ NOTE: The information above is only meant to explain the properties of
+ the identifiers. Client code should simply treat the identifier as an
+ opaque value and may not make any assumptions on the structure, as there
+ may be other ways of generating the identifiers in the future if an
+ application benefits.
+
+ */
+
+struct Oid
+{
+ static const int StringLength = 24;
+ typedef char String_t[StringLength + 1];
+
+ static void Initialize();
+ [[nodiscard]] static Oid NewOid();
+
+ const Oid& Generate();
+ [[nodiscard]] static Oid FromHexString(const std::string_view String);
+ StringBuilderBase& ToString(StringBuilderBase& OutString) const;
+ void ToString(char OutString[StringLength]);
+ [[nodiscard]] static Oid FromMemory(const void* Ptr);
+
+ auto operator<=>(const Oid& rhs) const = default;
+ [[nodiscard]] inline operator bool() const { return *this != Zero; }
+
+ static const Oid Zero; // Min (can be used to signify a "null" value, or for open range queries)
+ static const Oid Max; // Max (can be used for open range queries)
+
+ struct Hasher
+ {
+ size_t operator()(const Oid& id) const
+ {
+ const size_t seed = id.OidBits[0];
+ return ((seed << 6) + (seed >> 2) + 0x9e3779b9 + uint64_t(id.OidBits[1])) | (uint64_t(id.OidBits[2]) << 32);
+ }
+ };
+
+ // You should not assume anything about these words
+ uint32_t OidBits[3];
+};
+
+extern void uid_forcelink();
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/varint.h b/src/zencore/include/zencore/varint.h
new file mode 100644
index 000000000..e57e1d497
--- /dev/null
+++ b/src/zencore/include/zencore/varint.h
@@ -0,0 +1,277 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "intmath.h"
+
+#include <algorithm>
+
+namespace zen {
+
+// Variable-Length Integer Encoding
+//
+// ZigZag encoding is used to convert signed integers into unsigned integers in a way that allows
+// integers with a small magnitude to have a smaller encoded representation.
+//
+// An unsigned integer is encoded into 1-9 bytes based on its magnitude. The first byte indicates
+// how many additional bytes are used by the number of leading 1-bits that it has. The additional
+// bytes are stored in big endian order, and the most significant bits of the value are stored in
+// the remaining bits in the first byte. The encoding of the first byte allows the reader to skip
+// over the encoded integer without consuming its bytes individually.
+//
+// Encoded unsigned integers sort the same in a byte-wise comparison as when their decoded values
+// are compared. The same property does not hold for signed integers due to ZigZag encoding.
+//
+// 32-bit inputs encode to 1-5 bytes.
+// 64-bit inputs encode to 1-9 bytes.
+//
+// 0x0000'0000'0000'0000 - 0x0000'0000'0000'007f : 0b0_______ 1 byte
+// 0x0000'0000'0000'0080 - 0x0000'0000'0000'3fff : 0b10______ 2 bytes
+// 0x0000'0000'0000'4000 - 0x0000'0000'001f'ffff : 0b110_____ 3 bytes
+// 0x0000'0000'0020'0000 - 0x0000'0000'0fff'ffff : 0b1110____ 4 bytes
+// 0x0000'0000'1000'0000 - 0x0000'0007'ffff'ffff : 0b11110___ 5 bytes
+// 0x0000'0008'0000'0000 - 0x0000'03ff'ffff'ffff : 0b111110__ 6 bytes
+// 0x0000'0400'0000'0000 - 0x0001'ffff'ffff'ffff : 0b1111110_ 7 bytes
+// 0x0002'0000'0000'0000 - 0x00ff'ffff'ffff'ffff : 0b11111110 8 bytes
+// 0x0100'0000'0000'0000 - 0xffff'ffff'ffff'ffff : 0b11111111 9 bytes
+//
+// Encoding Examples
+// -42 => ZigZag => 0x53 => 0x53
+// 42 => ZigZag => 0x54 => 0x54
+// 0x1 => 0x01
+// 0x12 => 0x12
+// 0x123 => 0x81 0x23
+// 0x1234 => 0x92 0x34
+// 0x12345 => 0xc1 0x23 0x45
+// 0x123456 => 0xd2 0x34 0x56
+// 0x1234567 => 0xe1 0x23 0x45 0x67
+// 0x12345678 => 0xf0 0x12 0x34 0x56 0x78
+// 0x123456789 => 0xf1 0x23 0x45 0x67 0x89
+// 0x123456789a => 0xf8 0x12 0x34 0x56 0x78 0x9a
+// 0x123456789ab => 0xfb 0x23 0x45 0x67 0x89 0xab
+// 0x123456789abc => 0xfc 0x12 0x34 0x56 0x78 0x9a 0xbc
+// 0x123456789abcd => 0xfd 0x23 0x45 0x67 0x89 0xab 0xcd
+// 0x123456789abcde => 0xfe 0x12 0x34 0x56 0x78 0x9a 0xbc 0xde
+// 0x123456789abcdef => 0xff 0x01 0x23 0x45 0x67 0x89 0xab 0xcd 0xef
+// 0x123456789abcdef0 => 0xff 0x12 0x34 0x56 0x78 0x9a 0xbc 0xde 0xf0
+
+/**
+ * Measure the length in bytes (1-9) of an encoded variable-length integer.
+ *
+ * @param InData A variable-length encoding of an (signed or unsigned) integer.
+ * @return The number of bytes used to encode the integer, in the range 1-9.
+ */
+inline uint32_t
+MeasureVarUInt(const void* InData)
+{
+ return CountLeadingZeros(uint8_t(~*static_cast<const uint8_t*>(InData))) - 23;
+}
+
+/** Measure the length in bytes (1-9) of an encoded variable-length integer. \see \ref MeasureVarUInt */
+inline uint32_t
+MeasureVarInt(const void* InData)
+{
+ return MeasureVarUInt(InData);
+}
+
+/** Measure the number of bytes (1-5) required to encode the 32-bit input. */
+inline uint32_t
+MeasureVarUInt(uint32_t InValue)
+{
+ return uint32_t(int32_t(FloorLog2(InValue)) / 7 + 1);
+}
+
+/** Measure the number of bytes (1-9) required to encode the 64-bit input. */
+inline uint32_t
+MeasureVarUInt(uint64_t InValue)
+{
+ return uint32_t(std::min(int32_t(FloorLog2_64(InValue)) / 7 + 1, 9));
+}
+
+/** Measure the number of bytes (1-5) required to encode the 32-bit input. \see \ref MeasureVarUInt */
+inline uint32_t
+MeasureVarInt(int32_t InValue)
+{
+ return MeasureVarUInt(uint32_t((InValue >> 31) ^ (InValue << 1)));
+}
+
+/** Measure the number of bytes (1-9) required to encode the 64-bit input. \see \ref MeasureVarUInt */
+inline uint32_t
+MeasureVarInt(int64_t InValue)
+{
+ return MeasureVarUInt(uint64_t((InValue >> 63) ^ (InValue << 1)));
+}
+
+/**
+ * Read a variable-length unsigned integer.
+ *
+ * @param InData A variable-length encoding of an unsigned integer.
+ * @param OutByteCount The number of bytes consumed from the input.
+ * @return An unsigned integer.
+ */
+inline uint64_t
+ReadVarUInt(const void* InData, uint32_t& OutByteCount)
+{
+ const uint32_t ByteCount = MeasureVarUInt(InData);
+ OutByteCount = ByteCount;
+
+ const uint8_t* InBytes = static_cast<const uint8_t*>(InData);
+ uint64_t Value = *InBytes++ & uint8_t(0xff >> ByteCount);
+ switch (ByteCount - 1)
+ {
+ case 8:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ case 7:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ case 6:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ case 5:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ case 4:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ case 3:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ case 2:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ case 1:
+ Value <<= 8;
+ Value |= *InBytes++;
+ [[fallthrough]];
+ default:
+ return Value;
+ }
+}
+
+/**
+ * Read a variable-length signed integer.
+ *
+ * @param InData A variable-length encoding of a signed integer.
+ * @param OutByteCount The number of bytes consumed from the input.
+ * @return A signed integer.
+ */
+inline int64_t
+ReadVarInt(const void* InData, uint32_t& OutByteCount)
+{
+ const uint64_t Value = ReadVarUInt(InData, OutByteCount);
+ return -int64_t(Value & 1) ^ int64_t(Value >> 1);
+}
+
+/**
+ * Write a variable-length unsigned integer.
+ *
+ * @param InValue An unsigned integer to encode.
+ * @param OutData A buffer of at least 5 bytes to write the output to.
+ * @return The number of bytes used in the output.
+ */
+inline uint32_t
+WriteVarUInt(uint32_t InValue, void* OutData)
+{
+ const uint32_t ByteCount = MeasureVarUInt(InValue);
+ uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1;
+ switch (ByteCount - 1)
+ {
+ case 4:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 3:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 2:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 1:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ default:
+ break;
+ }
+ *OutBytes = uint8_t(0xff << (9 - ByteCount)) | uint8_t(InValue);
+ return ByteCount;
+}
+
+/**
+ * Write a variable-length unsigned integer.
+ *
+ * @param InValue An unsigned integer to encode.
+ * @param OutData A buffer of at least 9 bytes to write the output to.
+ * @return The number of bytes used in the output.
+ */
+inline uint32_t
+WriteVarUInt(uint64_t InValue, void* OutData)
+{
+ const uint32_t ByteCount = MeasureVarUInt(InValue);
+ uint8_t* OutBytes = static_cast<uint8_t*>(OutData) + ByteCount - 1;
+ switch (ByteCount - 1)
+ {
+ case 8:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 7:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 6:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 5:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 4:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 3:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 2:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ case 1:
+ *OutBytes-- = uint8_t(InValue);
+ InValue >>= 8;
+ [[fallthrough]];
+ default:
+ break;
+ }
+ *OutBytes = uint8_t(0xff << (9 - ByteCount)) | uint8_t(InValue);
+ return ByteCount;
+}
+
+/** Write a variable-length signed integer. \see \ref WriteVarUInt */
+inline uint32_t
+WriteVarInt(int32_t InValue, void* OutData)
+{
+ const uint32_t Value = uint32_t((InValue >> 31) ^ (InValue << 1));
+ return WriteVarUInt(Value, OutData);
+}
+
+/** Write a variable-length signed integer. \see \ref WriteVarUInt */
+inline uint32_t
+WriteVarInt(int64_t InValue, void* OutData)
+{
+ const uint64_t Value = uint64_t((InValue >> 63) ^ (InValue << 1));
+ return WriteVarUInt(Value, OutData);
+}
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/windows.h b/src/zencore/include/zencore/windows.h
new file mode 100644
index 000000000..91828f0ec
--- /dev/null
+++ b/src/zencore/include/zencore/windows.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+
+struct IUnknown; // Workaround for "combaseapi.h(229): error C2187: syntax error: 'identifier' was unexpected here" when using /permissive-
+#ifndef NOMINMAX
+# define NOMINMAX // We don't want your min/max macros
+#endif
+#ifndef NOGDI
+# define NOGDI // We don't want your GetObject define
+#endif
+#ifndef WIN32_LEAN_AND_MEAN
+# define WIN32_LEAN_AND_MEAN
+#endif
+#ifndef _WIN32_WINNT
+# define _WIN32_WINNT 0x0A00
+#endif
+#include <windows.h>
+#undef GetObject
+
+ZEN_THIRD_PARTY_INCLUDES_END
diff --git a/src/zencore/include/zencore/workthreadpool.h b/src/zencore/include/zencore/workthreadpool.h
new file mode 100644
index 000000000..0ddc65298
--- /dev/null
+++ b/src/zencore/include/zencore/workthreadpool.h
@@ -0,0 +1,48 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <zencore/blockingqueue.h>
+#include <zencore/refcount.h>
+
+#include <exception>
+#include <functional>
+#include <system_error>
+#include <thread>
+#include <vector>
+
+namespace zen {
+
+struct IWork : public RefCounted
+{
+ virtual void Execute() = 0;
+
+ inline std::exception_ptr GetException() { return m_Exception; }
+
+private:
+ std::exception_ptr m_Exception;
+
+ friend class WorkerThreadPool;
+};
+
+class WorkerThreadPool
+{
+public:
+ WorkerThreadPool(int InThreadCount);
+ ~WorkerThreadPool();
+
+ void ScheduleWork(Ref<IWork> Work);
+ void ScheduleWork(std::function<void()>&& Work);
+
+ [[nodiscard]] size_t PendingWork() const;
+
+private:
+ void WorkerThreadFunction();
+
+ std::vector<std::thread> m_WorkerThreads;
+ BlockingQueue<Ref<IWork>> m_WorkQueue;
+};
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/xxhash.h b/src/zencore/include/zencore/xxhash.h
new file mode 100644
index 000000000..04872f4c3
--- /dev/null
+++ b/src/zencore/include/zencore/xxhash.h
@@ -0,0 +1,89 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zencore.h"
+
+#include <zencore/memory.h>
+
+#include <xxh3.h>
+
+#include <compare>
+#include <string_view>
+
+namespace zen {
+
+class StringBuilderBase;
+
+/**
+ * XXH3 hash
+ */
+struct XXH3_128
+{
+ uint8_t Hash[16];
+
+ static XXH3_128 MakeFrom(const void* data /* 16 bytes */)
+ {
+ XXH3_128 Xx;
+ memcpy(Xx.Hash, data, sizeof Xx);
+ return Xx;
+ }
+
+ static inline XXH3_128 HashMemory(const void* data, size_t byteCount)
+ {
+ XXH3_128 Hash;
+ XXH128_canonicalFromHash((XXH128_canonical_t*)Hash.Hash, XXH3_128bits(data, byteCount));
+ return Hash;
+ }
+ static XXH3_128 HashMemory(MemoryView Data) { return HashMemory(Data.GetData(), Data.GetSize()); }
+ static XXH3_128 FromHexString(const char* string);
+ static XXH3_128 FromHexString(const std::string_view string);
+ const char* ToHexString(char* outString /* 32 characters + NUL terminator */) const;
+ StringBuilderBase& ToHexString(StringBuilderBase& outBuilder) const;
+
+ static const int StringLength = 32;
+ typedef char String_t[StringLength + 1];
+
+ static XXH3_128 Zero; // Initialized to all zeros
+
+ inline auto operator<=>(const XXH3_128& rhs) const = default;
+
+ struct Hasher
+ {
+ size_t operator()(const XXH3_128& v) const
+ {
+ size_t h;
+ memcpy(&h, v.Hash, sizeof h);
+ return h;
+ }
+ };
+};
+
+struct XXH3_128Stream
+{
+ /// Begin streaming hash compute (not needed on freshly constructed instance)
+ void Reset() { memset(&m_State, 0, sizeof m_State); }
+
+ /// Append another chunk
+ XXH3_128Stream& Append(const void* Data, size_t ByteCount)
+ {
+ XXH3_128bits_update(&m_State, Data, ByteCount);
+ return *this;
+ }
+
+ /// Append another chunk
+ XXH3_128Stream& Append(MemoryView Data) { return Append(Data.GetData(), Data.GetSize()); }
+
+ /// Obtain final hash. If you wish to reuse the instance call reset()
+ XXH3_128 GetHash()
+ {
+ XXH3_128 Hash;
+ XXH128_canonicalFromHash((XXH128_canonical_t*)Hash.Hash, XXH3_128bits_digest(&m_State));
+ return Hash;
+ }
+
+private:
+ XXH3_state_s m_State{};
+};
+
+} // namespace zen
diff --git a/src/zencore/include/zencore/zencore.h b/src/zencore/include/zencore/zencore.h
new file mode 100644
index 000000000..5bcd77239
--- /dev/null
+++ b/src/zencore/include/zencore/zencore.h
@@ -0,0 +1,383 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <cinttypes>
+#include <stdexcept>
+#include <string>
+#include <version>
+
+#ifndef ZEN_WITH_TESTS
+# define ZEN_WITH_TESTS 1
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Platform
+//
+
+#define ZEN_PLATFORM_WINDOWS 0
+#define ZEN_PLATFORM_LINUX 0
+#define ZEN_PLATFORM_MAC 0
+
+#ifdef _WIN32
+# undef ZEN_PLATFORM_WINDOWS
+# define ZEN_PLATFORM_WINDOWS 1
+#elif defined(__linux__)
+# undef ZEN_PLATFORM_LINUX
+# define ZEN_PLATFORM_LINUX 1
+#elif defined(__APPLE__)
+# undef ZEN_PLATFORM_MAC
+# define ZEN_PLATFORM_MAC 1
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+# if !defined(NOMINMAX)
+# define NOMINMAX // stops Windows.h from defining 'min/max' macros
+# endif
+# if !defined(NOGDI)
+# define NOGDI
+# endif
+# if !defined(WIN32_LEAN_AND_MEAN)
+# define WIN32_LEAN_AND_MEAN // cut-down what Windows.h defines
+# endif
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Compiler
+//
+
+#define ZEN_COMPILER_CLANG 0
+#define ZEN_COMPILER_MSC 0
+#define ZEN_COMPILER_GCC 0
+
+// Clang can define __GNUC__ and/or _MSC_VER so we check for Clang first
+#ifdef __clang__
+# undef ZEN_COMPILER_CLANG
+# define ZEN_COMPILER_CLANG 1
+#elif defined(_MSC_VER)
+# undef ZEN_COMPILER_MSC
+# define ZEN_COMPILER_MSC 1
+#elif defined(__GNUC__)
+# undef ZEN_COMPILER_GCC
+# define ZEN_COMPILER_GCC 1
+#else
+# error Unknown compiler
+#endif
+
+#if ZEN_COMPILER_MSC
+# pragma warning(disable : 4324) // warning C4324: '<type>': structure was padded due to alignment specifier
+# pragma warning(default : 4668) // warning C4668: 'symbol' is not defined as a preprocessor macro, replacing with '0' for 'directives'
+# pragma warning(default : 4100) // warning C4100: 'identifier' : unreferenced formal parameter
+#endif
+
+#ifndef ZEN_THIRD_PARTY_INCLUDES_START
+# if ZEN_COMPILER_MSC
+# define ZEN_THIRD_PARTY_INCLUDES_START \
+ __pragma(warning(push)) __pragma(warning(disable : 4668)) /* C4668: use of undefined preprocessor macro */ \
+ __pragma(warning(disable : 4305)) /* C4305: 'if': truncation from 'uint32' to 'bool' */ \
+ __pragma(warning(disable : 4267)) /* C4267: '=': conversion from 'size_t' to 'US' */ \
+ __pragma(warning(disable : 4127)) /* C4127: conditional expression is constant */ \
+ __pragma(warning(disable : 4189)) /* C4189: local variable is initialized but not referenced */
+# elif ZEN_COMPILER_CLANG
+# define ZEN_THIRD_PARTY_INCLUDES_START \
+ _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wundef\"") \
+ _Pragma("clang diagnostic ignored \"-Wunused-parameter\"") _Pragma("clang diagnostic ignored \"-Wunused-variable\"")
+# elif ZEN_COMPILER_GCC
+# define ZEN_THIRD_PARTY_INCLUDES_START \
+ _Pragma("GCC diagnostic push") /* NB. ignoring -Wundef doesn't work with GCC */ \
+ _Pragma("GCC diagnostic ignored \"-Wunused-parameter\"") _Pragma("GCC diagnostic ignored \"-Wunused-variable\"")
+# endif
+#endif
+
+#ifndef ZEN_THIRD_PARTY_INCLUDES_END
+# if ZEN_COMPILER_MSC
+# define ZEN_THIRD_PARTY_INCLUDES_END __pragma(warning(pop))
+# elif ZEN_COMPILER_CLANG
+# define ZEN_THIRD_PARTY_INCLUDES_END _Pragma("clang diagnostic pop")
+# elif ZEN_COMPILER_GCC
+# define ZEN_THIRD_PARTY_INCLUDES_END _Pragma("GCC diagnostic pop")
+# endif
+#endif
+
+#if ZEN_COMPILER_MSC
+# define ZEN_DEBUG_BREAK() \
+ do \
+ { \
+ __debugbreak(); \
+ } while (0)
+#else
+# define ZEN_DEBUG_BREAK() \
+ do \
+ { \
+ __builtin_trap(); \
+ } while (0)
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// C++20 support
+//
+
+// Clang
+#if ZEN_COMPILER_CLANG && __clang_major__ < 12
+# error clang-12 onwards is required for C++20 support
+#endif
+
+// GCC
+#if ZEN_COMPILER_GCC && __GNUC__ < 11
+# error GCC-11 onwards is required for C++20 support
+#endif
+
+// GNU libstdc++
+#if defined(_GLIBCXX_RELEASE) && _GLIBCXX_RELEASE < 11
+# error GNU libstdc++-11 onwards is required for C++20 support
+#endif
+
+// LLVM libc++
+#if defined(_LIBCPP_VERSION) && _LIBCPP_VERSION < 12000
+# error LLVM libc++-12 onwards is required for C++20 support
+#endif
+
+// At the time of writing only ver >= 13 of LLVM's libc++ has an implementation
+// of std::integral. Some platforms like Ubuntu and Mac OS are still on 12.
+#if defined(__cpp_lib_concepts)
+# include <concepts>
+template<class T>
+concept Integral = std::integral<T>;
+template<class T>
+concept SignedIntegral = std::signed_integral<T>;
+template<class T>
+concept UnsignedIntegral = std::unsigned_integral<T>;
+template<class F, class... A>
+concept Invocable = std::invocable<F, A...>;
+template<class D, class B>
+concept DerivedFrom = std::derived_from<D, B>;
+#else
+template<class T>
+concept Integral = std::is_integral_v<T>;
+template<class T>
+concept SignedIntegral = Integral<T> && std::is_signed_v<T>;
+template<class T>
+concept UnsignedIntegral = Integral<T> && !std::is_signed_v<T>;
+template<class F, class... A>
+concept Invocable = requires(F&& f, A&&... a)
+{
+ std::invoke(std::forward<F>(f), std::forward<A>(a)...);
+};
+template<class D, class B>
+concept DerivedFrom = std::is_base_of_v<B, D> && std::is_convertible_v<const volatile D*, const volatile B*>;
+#endif
+
+#if defined(__cpp_lib_ranges)
+template<typename T>
+concept ContiguousRange = std::ranges::contiguous_range<T>;
+#else
+template<typename T>
+concept ContiguousRange = true;
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Architecture
+//
+
+#if defined(__amd64__) || defined(_M_X64)
+# define ZEN_ARCH_X64 1
+# define ZEN_ARCH_ARM64 0
+#elif defined(__arm64__) || defined(_M_ARM64)
+# define ZEN_ARCH_X64 0
+# define ZEN_ARCH_ARM64 1
+#else
+# error Unknown architecture
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Build flavor
+//
+
+#ifdef NDEBUG
+# define ZEN_BUILD_DEBUG 0
+# define ZEN_BUILD_RELEASE 1
+#else
+# define ZEN_BUILD_DEBUG 1
+# define ZEN_BUILD_RELEASE 0
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+#define ZEN_PLATFORM_SUPPORTS_UNALIGNED_LOADS 1
+
+#if defined(__SIZEOF_WCHAR_T__) && __SIZEOF_WCHAR_T__ == 4
+# define ZEN_SIZEOF_WCHAR_T 4
+#else
+static_assert(sizeof(wchar_t) == 2, "wchar_t is expected to be two bytes in size");
+# define ZEN_SIZEOF_WCHAR_T 2
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Assert
+//
+
+#if ZEN_PLATFORM_WINDOWS
+// Tells the compiler to put the decorated function in a certain section (aka. segment) of the executable.
+# define ZEN_CODE_SECTION(Name) __declspec(code_seg(Name))
+# define ZEN_FORCENOINLINE __declspec(noinline) /* Force code to NOT be inline */
+# define LINE_TERMINATOR_ANSI "\r\n"
+#else
+# define ZEN_CODE_SECTION(Name)
+# define ZEN_FORCENOINLINE
+# define LINE_TERMINATOR_ANSI "\n"
+#endif
+
+#if ZEN_ARCH_ARM64
+// On ARM we can't do this because the executable will require jumps larger
+// than the branch instruction can handle. Clang will only generate
+// the trampolines in the .text segment of the binary. If the zcold segment
+// is present it will generate code that it cannot link.
+# define ZEN_DEBUG_SECTION
+#else
+// We'll put all assert implementation code into a separate section in the linked
+// executable. This code should never execute so using a separate section keeps
+// it well off the hot path and hopefully out of the instruction cache. It also
+// facilitates reasoning about the makeup of a compiled/linked binary.
+# define ZEN_DEBUG_SECTION ZEN_CODE_SECTION(".zcold")
+#endif
+
+namespace zen
+{
+ class AssertException : public std::logic_error
+ {
+ public:
+ AssertException(const char* Msg) : std::logic_error(Msg) {}
+ };
+
+ struct AssertImpl
+ {
+ static void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ExecAssert
+ [[noreturn]] (const char* Filename, int LineNumber, const char* FunctionName, const char* Msg)
+ {
+ CurrentAssertImpl->OnAssert(Filename, LineNumber, FunctionName, Msg);
+ throw AssertException{Msg};
+ }
+
+ protected:
+ virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename,
+ int LineNumber,
+ const char* FunctionName,
+ const char* Msg)
+ {
+ (void(Filename));
+ (void(LineNumber));
+ (void(FunctionName));
+ (void(Msg));
+ }
+ static AssertImpl DefaultAssertImpl;
+ static AssertImpl* CurrentAssertImpl;
+ };
+
+} // namespace zen
+
+#define ZEN_ASSERT(x, ...) \
+ do \
+ { \
+ if (x) [[unlikely]] \
+ break; \
+ zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \
+ } while (false)
+
+#ifndef NDEBUG
+# define ZEN_ASSERT_SLOW(x, ...) \
+ do \
+ { \
+ if (x) [[unlikely]] \
+ break; \
+ zen::AssertImpl::ExecAssert(__FILE__, __LINE__, __FUNCTION__, #x); \
+ } while (false)
+#else
+# define ZEN_ASSERT_SLOW(x, ...)
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+#ifdef __clang__
+template<typename T>
+auto ZenArrayCountHelper(T& t) -> typename std::enable_if<__is_array(T), char (&)[sizeof(t) / sizeof(t[0]) + 1]>::type;
+#else
+template<typename T, uint32_t N>
+char (&ZenArrayCountHelper(const T (&)[N]))[N + 1];
+#endif
+
+#define ZEN_ARRAY_COUNT(array) (sizeof(ZenArrayCountHelper(array)) - 1)
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_COMPILER_MSC
+# define ZEN_NOINLINE __declspec(noinline)
+#else
+# define ZEN_NOINLINE __attribute__((noinline))
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+# define ZEN_EXE_SUFFIX_LITERAL ".exe"
+#else
+# define ZEN_EXE_SUFFIX_LITERAL
+#endif
+
+#define ZEN_UNUSED(...) ((void)__VA_ARGS__)
+#define ZEN_NOT_IMPLEMENTED(...) ZEN_ASSERT(false, __VA_ARGS__)
+#define ZENCORE_API // Placeholder to allow DLL configs in the future (maybe)
+
+namespace zen {
+
+ZENCORE_API bool IsApplicationExitRequested();
+ZENCORE_API void RequestApplicationExit(int ExitCode);
+ZENCORE_API bool IsDebuggerPresent();
+ZENCORE_API void SetIsInteractiveSession(bool Value);
+ZENCORE_API bool IsInteractiveSession();
+
+ZENCORE_API void zencore_forcelinktests();
+
+} // namespace zen
+
+//////////////////////////////////////////////////////////////////////////
+
+#ifndef ZEN_USE_MIMALLOC
+# if ZEN_ARCH_ARM64
+ // The vcpkg mimalloc port doesn't support Arm targets
+# define ZEN_USE_MIMALLOC 0
+# else
+# define ZEN_USE_MIMALLOC 1
+# endif
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_COMPILER_MSC
+# define ZEN_DISABLE_OPTIMIZATION_ACTUAL __pragma(optimize("", off))
+# define ZEN_ENABLE_OPTIMIZATION_ACTUAL __pragma(optimize("", on))
+#elif ZEN_COMPILER_GCC
+# define ZEN_DISABLE_OPTIMIZATION_ACTUAL _Pragma("GCC push_options") _Pragma("GCC optimize (\"O0\")")
+# define ZEN_ENABLE_OPTIMIZATION_ACTUAL _Pragma("GCC pop_options")
+#elif ZEN_COMPILER_CLANG
+# define ZEN_DISABLE_OPTIMIZATION_ACTUAL _Pragma("clang optimize off")
+# define ZEN_ENABLE_OPTIMIZATION_ACTUAL _Pragma("clang optimize on")
+#endif
+
+// Set up optimization control macros, now that we have both the build settings and the platform macros
+#define ZEN_DISABLE_OPTIMIZATION ZEN_DISABLE_OPTIMIZATION_ACTUAL
+
+#if ZEN_BUILD_DEBUG
+# define ZEN_ENABLE_OPTIMIZATION ZEN_DISABLE_OPTIMIZATION_ACTUAL
+#else
+# define ZEN_ENABLE_OPTIMIZATION ZEN_ENABLE_OPTIMIZATION_ACTUAL
+#endif
+
+#define ZEN_ENABLE_OPTIMIZATION_ALWAYS ZEN_ENABLE_OPTIMIZATION_ACTUAL
+
+//////////////////////////////////////////////////////////////////////////
+
+#ifndef ZEN_WITH_TRACE
+# define ZEN_WITH_TRACE 0
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+using ThreadId_t = uint32_t;
diff --git a/src/zencore/intmath.cpp b/src/zencore/intmath.cpp
new file mode 100644
index 000000000..5a686dc8e
--- /dev/null
+++ b/src/zencore/intmath.cpp
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/endian.h>
+#include <zencore/intmath.h>
+
+#include <zencore/testing.h>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+void
+intmath_forcelink()
+{
+}
+
+TEST_CASE("intmath")
+{
+ CHECK(FloorLog2(0x00) == 0);
+ CHECK(FloorLog2(0x01) == 0);
+ CHECK(FloorLog2(0x0f) == 3);
+ CHECK(FloorLog2(0x10) == 4);
+ CHECK(FloorLog2(0x11) == 4);
+ CHECK(FloorLog2(0x12) == 4);
+ CHECK(FloorLog2(0x22) == 5);
+ CHECK(FloorLog2(0x0001'0000) == 16);
+ CHECK(FloorLog2(0x0001'000f) == 16);
+ CHECK(FloorLog2(0x8000'0000) == 31);
+
+ CHECK(FloorLog2_64(0x00ull) == 0);
+ CHECK(FloorLog2_64(0x01ull) == 0);
+ CHECK(FloorLog2_64(0x0full) == 3);
+ CHECK(FloorLog2_64(0x10ull) == 4);
+ CHECK(FloorLog2_64(0x11ull) == 4);
+ CHECK(FloorLog2_64(0x0001'0000ull) == 16);
+ CHECK(FloorLog2_64(0x0001'000full) == 16);
+ CHECK(FloorLog2_64(0x8000'0000ull) == 31);
+ CHECK(FloorLog2_64(0x0000'0001'0000'0000ull) == 32);
+ CHECK(FloorLog2_64(0x8000'0000'0000'0000ull) == 63);
+
+ CHECK(CountLeadingZeros64(0x8000'0000'0000'0000ull) == 0);
+ CHECK(CountLeadingZeros64(0x0000'0000'0000'0000ull) == 64);
+ CHECK(CountLeadingZeros64(0x0000'0000'0000'0001ull) == 63);
+ CHECK(CountLeadingZeros64(0x0000'0000'8000'0000ull) == 32);
+ CHECK(CountLeadingZeros64(0x0000'0001'0000'0000ull) == 31);
+
+ CHECK(CountTrailingZeros64(0x8000'0000'0000'0000ull) == 63);
+ CHECK(CountTrailingZeros64(0x0000'0000'0000'0000ull) == 64);
+ CHECK(CountTrailingZeros64(0x0000'0000'0000'0001ull) == 0);
+ CHECK(CountTrailingZeros64(0x0000'0000'8000'0000ull) == 31);
+ CHECK(CountTrailingZeros64(0x0000'0001'0000'0000ull) == 32);
+
+ CHECK(ByteSwap(uint16_t(0x6d72)) == 0x726d);
+ CHECK(ByteSwap(uint32_t(0x2741'3965)) == 0x6539'4127);
+ CHECK(ByteSwap(uint64_t(0x214d'6172'7469'6e21ull)) == 0x216e'6974'7261'4d21ull);
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/iobuffer.cpp b/src/zencore/iobuffer.cpp
new file mode 100644
index 000000000..1d7d47695
--- /dev/null
+++ b/src/zencore/iobuffer.cpp
@@ -0,0 +1,653 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/iobuffer.h>
+
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/testing.h>
+#include <zencore/thread.h>
+
+#include <memory.h>
+#include <system_error>
+
+#if ZEN_USE_MIMALLOC
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <mimalloc.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+# include <atlfile.h>
+#else
+# include <sys/stat.h>
+# include <sys/mman.h>
+#endif
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+IoBufferCore::AllocateBuffer(size_t InSize, size_t Alignment) const
+{
+#if ZEN_PLATFORM_WINDOWS
+ if (((InSize & 0xffFF) == 0) && (Alignment == 0x10000))
+ {
+ m_Flags.fetch_or(kLowLevelAlloc, std::memory_order_relaxed);
+ m_DataPtr = VirtualAlloc(nullptr, InSize, MEM_COMMIT, PAGE_READWRITE);
+
+ return;
+ }
+#endif // ZEN_PLATFORM_WINDOWS
+
+#if ZEN_USE_MIMALLOC
+ void* Ptr = mi_aligned_alloc(Alignment, RoundUp(InSize, Alignment));
+ m_Flags.fetch_or(kIoBufferAlloc, std::memory_order_relaxed);
+#else
+ void* Ptr = Memory::Alloc(InSize, Alignment);
+#endif
+
+ ZEN_ASSERT(Ptr);
+
+ m_DataPtr = Ptr;
+}
+
+void
+IoBufferCore::FreeBuffer()
+{
+ if (!m_DataPtr)
+ {
+ return;
+ }
+
+ const uint32_t LocalFlags = m_Flags.load(std::memory_order_relaxed);
+#if ZEN_PLATFORM_WINDOWS
+ if (LocalFlags & kLowLevelAlloc)
+ {
+ VirtualFree(const_cast<void*>(m_DataPtr), 0, MEM_DECOMMIT);
+
+ return;
+ }
+#endif // ZEN_PLATFORM_WINDOWS
+
+#if ZEN_USE_MIMALLOC
+ if (LocalFlags & kIoBufferAlloc)
+ {
+ return mi_free(const_cast<void*>(m_DataPtr));
+ }
+#endif
+
+ ZEN_UNUSED(LocalFlags);
+ return Memory::Free(const_cast<void*>(m_DataPtr));
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static_assert(sizeof(IoBufferCore) == 32);
+
+IoBufferCore::IoBufferCore(size_t InSize)
+{
+ ZEN_ASSERT(InSize);
+
+ AllocateBuffer(InSize, sizeof(void*));
+ m_DataBytes = InSize;
+
+ SetIsOwnedByThis(true);
+}
+
+IoBufferCore::IoBufferCore(size_t InSize, size_t Alignment)
+{
+ ZEN_ASSERT(InSize);
+
+ AllocateBuffer(InSize, Alignment);
+ m_DataBytes = InSize;
+
+ SetIsOwnedByThis(true);
+}
+
+IoBufferCore::~IoBufferCore()
+{
+ if (IsOwnedByThis() && m_DataPtr)
+ {
+ FreeBuffer();
+ m_DataPtr = nullptr;
+ }
+}
+
+void
+IoBufferCore::DeleteThis() const
+{
+ // We do this just to avoid paying for the cost of a vtable
+ if (const IoBufferExtendedCore* _ = ExtendedCore())
+ {
+ delete _;
+ }
+ else
+ {
+ delete this;
+ }
+}
+
+void
+IoBufferCore::Materialize() const
+{
+ if (const IoBufferExtendedCore* _ = ExtendedCore())
+ {
+ _->Materialize();
+ }
+}
+
+void
+IoBufferCore::MakeOwned(bool Immutable)
+{
+ if (!IsOwned())
+ {
+ const void* OldDataPtr = m_DataPtr;
+ AllocateBuffer(m_DataBytes, sizeof(void*));
+ memcpy(const_cast<void*>(m_DataPtr), OldDataPtr, m_DataBytes);
+ SetIsOwnedByThis(true);
+ }
+
+ SetIsImmutable(Immutable);
+}
+
+void*
+IoBufferCore::MutableDataPointer() const
+{
+ EnsureDataValid();
+ ZEN_ASSERT(!IsImmutable());
+ return const_cast<void*>(m_DataPtr);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+IoBufferExtendedCore::IoBufferExtendedCore(void* FileHandle, uint64_t Offset, uint64_t Size, bool TransferHandleOwnership)
+: IoBufferCore(nullptr, Size)
+, m_FileHandle(FileHandle)
+, m_FileOffset(Offset)
+{
+ uint32_t NewFlags = kIsOwnedByThis | kIsExtended;
+
+ if (TransferHandleOwnership)
+ {
+ NewFlags |= kOwnsFile;
+ }
+ m_Flags.fetch_or(NewFlags, std::memory_order_relaxed);
+}
+
+IoBufferExtendedCore::IoBufferExtendedCore(const IoBufferExtendedCore* Outer, uint64_t Offset, uint64_t Size)
+: IoBufferCore(Outer, nullptr, Size)
+, m_FileHandle(Outer->m_FileHandle)
+, m_FileOffset(Outer->m_FileOffset + Offset)
+{
+ m_Flags.fetch_or(kIsExtended, std::memory_order_relaxed);
+}
+
+IoBufferExtendedCore::~IoBufferExtendedCore()
+{
+ if (m_MappedPointer)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ UnmapViewOfFile(m_MappedPointer);
+#else
+ uint64_t MapSize = ~uint64_t(uintptr_t(m_MmapHandle));
+ munmap(m_MappedPointer, MapSize);
+#endif
+ }
+
+ const uint32_t LocalFlags = m_Flags.load(std::memory_order_relaxed);
+#if ZEN_PLATFORM_WINDOWS
+ if (LocalFlags & kOwnsMmap)
+ {
+ CloseHandle(m_MmapHandle);
+ }
+#endif
+
+ if (LocalFlags & kOwnsFile)
+ {
+ if (m_DeleteOnClose)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ // Mark file for deletion when final handle is closed
+ FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE};
+
+ SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi);
+#else
+ std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle);
+ unlink(FilePath.c_str());
+#endif
+ }
+#if ZEN_PLATFORM_WINDOWS
+ BOOL Success = CloseHandle(m_FileHandle);
+#else
+ int Fd = int(uintptr_t(m_FileHandle));
+ bool Success = (close(Fd) == 0);
+#endif
+ if (!Success)
+ {
+ ZEN_WARN("Error reported on file handle close, reason '{}'", GetLastErrorAsString());
+ }
+ }
+
+ m_DataPtr = nullptr;
+}
+
+static constexpr size_t MappingLockCount = 128;
+static_assert(IsPow2(MappingLockCount), "MappingLockCount must be power of two");
+
+static RwLock g_MappingLocks[MappingLockCount];
+
+static RwLock&
+MappingLockForInstance(const IoBufferExtendedCore* instance)
+{
+ intptr_t base = (intptr_t)instance;
+ size_t lock_index = ((base >> 5) ^ (base >> 13)) & (MappingLockCount - 1u);
+ return g_MappingLocks[lock_index];
+}
+
+void
+IoBufferExtendedCore::Materialize() const
+{
+ // The synchronization scheme here is very primitive, if we end up with
+ // a lot of contention we can make it more fine-grained
+
+ if (m_Flags.load(std::memory_order_acquire) & kIsMaterialized)
+ return;
+
+ RwLock::ExclusiveLockScope _(MappingLockForInstance(this));
+
+ // Someone could have gotten here first
+ // We can use memory_order_relaxed on this load because the mutex has already provided the fence
+ if (m_Flags.load(std::memory_order_relaxed) & kIsMaterialized)
+ return;
+
+ uint32_t NewFlags = kIsMaterialized;
+
+ if (m_DataBytes == 0)
+ {
+ // Fake a "valid" pointer, nobody should read this as size is zero
+ m_DataPtr = reinterpret_cast<uint8_t*>(&m_MmapHandle);
+ m_Flags.fetch_or(NewFlags, std::memory_order_release);
+ return;
+ }
+
+ const size_t DisableMMapSizeLimit = 0x1000ull;
+
+ if (m_DataBytes < DisableMMapSizeLimit)
+ {
+ AllocateBuffer(m_DataBytes, sizeof(void*));
+ NewFlags |= kIsOwnedByThis;
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ Ovl.Offset = DWORD(m_FileOffset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(m_FileOffset >> 32);
+
+ DWORD dwNumberOfBytesRead = 0;
+ BOOL Success = ::ReadFile(m_FileHandle, (void*)m_DataPtr, DWORD(m_DataBytes), &dwNumberOfBytesRead, &Ovl);
+
+ ZEN_ASSERT(Success);
+ ZEN_ASSERT(dwNumberOfBytesRead == m_DataBytes);
+#else
+ static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files");
+ int Fd = int(uintptr_t(m_FileHandle));
+ int BytesRead = pread(Fd, (void*)m_DataPtr, m_DataBytes, m_FileOffset);
+ bool Success = (BytesRead > 0);
+#endif // ZEN_PLATFORM_WINDOWS
+
+ m_Flags.fetch_or(NewFlags, std::memory_order_release);
+ return;
+ }
+
+ void* NewMmapHandle;
+
+ const uint64_t MapOffset = m_FileOffset & ~0xffffull;
+ const uint64_t MappedOffsetDisplacement = m_FileOffset - MapOffset;
+ const uint64_t MapSize = m_DataBytes + MappedOffsetDisplacement;
+
+ ZEN_ASSERT(MapSize > 0);
+
+#if ZEN_PLATFORM_WINDOWS
+ NewMmapHandle = CreateFileMapping(m_FileHandle,
+ /* lpFileMappingAttributes */ nullptr,
+ /* flProtect */ PAGE_READONLY,
+ /* dwMaximumSizeLow */ 0,
+ /* dwMaximumSizeHigh */ 0,
+ /* lpName */ nullptr);
+
+ if (NewMmapHandle == nullptr)
+ {
+ int32_t Error = zen::GetLastError();
+ ZEN_ERROR("CreateFileMapping failed on file '{}', {}", zen::PathFromHandle(m_FileHandle), GetSystemErrorAsString(Error));
+ throw std::system_error(std::error_code(Error, std::system_category()),
+ fmt::format("CreateFileMapping failed on file '{}'", zen::PathFromHandle(m_FileHandle)));
+ }
+
+ NewFlags |= kOwnsMmap;
+
+ void* MappedBase = MapViewOfFile(NewMmapHandle,
+ /* dwDesiredAccess */ FILE_MAP_READ,
+ /* FileOffsetHigh */ uint32_t(MapOffset >> 32),
+ /* FileOffsetLow */ uint32_t(MapOffset & 0xffFFffFFu),
+ /* dwNumberOfBytesToMap */ MapSize);
+#else
+ NewMmapHandle = (void*)uintptr_t(~MapSize); // ~ so it's never null (assuming MapSize >= 0)
+ NewFlags |= kOwnsMmap;
+
+ void* MappedBase = mmap(
+ /* addr */ nullptr,
+ /* length */ MapSize,
+ /* prot */ PROT_READ,
+ /* flags */ MAP_SHARED | MAP_NORESERVE,
+ /* fd */ int(uintptr_t(m_FileHandle)),
+ /* offset */ MapOffset);
+#endif // ZEN_PLATFORM_WINDOWS
+
+ if (MappedBase == nullptr)
+ {
+ int32_t Error = zen::GetLastError();
+#if ZEN_PLATFORM_WINDOWS
+ CloseHandle(NewMmapHandle);
+#endif // ZEN_PLATFORM_WINDOWS
+ ZEN_ERROR("MapViewOfFile failed (offset {:#x}, size {:#x}) file: '{}', {}",
+ MapOffset,
+ MapSize,
+ zen::PathFromHandle(m_FileHandle),
+ GetSystemErrorAsString(Error));
+ throw std::system_error(std::error_code(Error, std::system_category()),
+ fmt::format("MapViewOfFile failed (offset {:#x}, size {:#x}) file: '{}'",
+ MapOffset,
+ MapSize,
+ zen::PathFromHandle(m_FileHandle)));
+ }
+
+ m_MappedPointer = MappedBase;
+ m_DataPtr = reinterpret_cast<uint8_t*>(MappedBase) + MappedOffsetDisplacement;
+ m_MmapHandle = NewMmapHandle;
+
+ m_Flags.fetch_or(NewFlags, std::memory_order_release);
+}
+
+bool
+IoBufferExtendedCore::GetFileReference(IoBufferFileReference& OutRef) const
+{
+ if (m_FileHandle == nullptr)
+ {
+ return false;
+ }
+
+ OutRef.FileHandle = m_FileHandle;
+ OutRef.FileChunkOffset = m_FileOffset;
+ OutRef.FileChunkSize = m_DataBytes;
+
+ return true;
+}
+
+void
+IoBufferExtendedCore::MarkAsDeleteOnClose()
+{
+ m_DeleteOnClose = true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+IoBuffer::IoBuffer(size_t InSize) : m_Core(new IoBufferCore(InSize))
+{
+ m_Core->SetIsImmutable(false);
+}
+
+IoBuffer::IoBuffer(size_t InSize, uint64_t InAlignment) : m_Core(new IoBufferCore(InSize, InAlignment))
+{
+ m_Core->SetIsImmutable(false);
+}
+
+IoBuffer::IoBuffer(const IoBuffer& OuterBuffer, size_t Offset, size_t Size)
+{
+ if (Size == ~(0ull))
+ {
+ Size = std::clamp<size_t>(Size, 0, OuterBuffer.Size() - Offset);
+ }
+
+ ZEN_ASSERT(Offset <= OuterBuffer.Size());
+ ZEN_ASSERT((Offset + Size) <= OuterBuffer.Size());
+
+ if (IoBufferExtendedCore* Extended = OuterBuffer.m_Core->ExtendedCore())
+ {
+ m_Core = new IoBufferExtendedCore(Extended, Offset, Size);
+ }
+ else
+ {
+ m_Core = new IoBufferCore(OuterBuffer.m_Core, reinterpret_cast<const uint8_t*>(OuterBuffer.Data()) + Offset, Size);
+ }
+}
+
+IoBuffer::IoBuffer(EFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize)
+: m_Core(new IoBufferExtendedCore(FileHandle, ChunkFileOffset, ChunkSize, /* owned */ true))
+{
+}
+
+IoBuffer::IoBuffer(EBorrowedFileTag, void* FileHandle, uint64_t ChunkFileOffset, uint64_t ChunkSize)
+: m_Core(new IoBufferExtendedCore(FileHandle, ChunkFileOffset, ChunkSize, /* owned */ false))
+{
+}
+
+bool
+IoBuffer::GetFileReference(IoBufferFileReference& OutRef) const
+{
+ if (IoBufferExtendedCore* ExtCore = m_Core->ExtendedCore())
+ {
+ if (ExtCore->GetFileReference(OutRef))
+ {
+ return true;
+ }
+ }
+
+ // Not a file reference
+
+ OutRef.FileHandle = 0;
+ OutRef.FileChunkOffset = ~0ull;
+ OutRef.FileChunkSize = 0;
+
+ return false;
+}
+
+void
+IoBuffer::MarkAsDeleteOnClose()
+{
+ if (IoBufferExtendedCore* ExtCore = m_Core->ExtendedCore())
+ {
+ ExtCore->MarkAsDeleteOnClose();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+IoBuffer
+IoBufferBuilder::ReadFromFileMaybe(IoBuffer& InBuffer)
+{
+ IoBufferFileReference FileRef;
+ if (InBuffer.GetFileReference(/* out */ FileRef))
+ {
+ IoBuffer OutBuffer(FileRef.FileChunkSize);
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ const uint64_t NumberOfBytesToRead = FileRef.FileChunkSize;
+ const uint64_t& FileOffset = FileRef.FileChunkOffset;
+
+ Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(FileOffset >> 32);
+
+ DWORD dwNumberOfBytesRead = 0;
+ BOOL Success = ::ReadFile(FileRef.FileHandle, OutBuffer.MutableData(), DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl);
+#else
+ int Fd = int(intptr_t(FileRef.FileHandle));
+ int Result = pread(Fd, OutBuffer.MutableData(), size_t(FileRef.FileChunkSize), off_t(FileRef.FileChunkOffset));
+ bool Success = (Result < 0);
+
+ uint32_t dwNumberOfBytesRead = uint32_t(Result);
+#endif
+
+ if (!Success)
+ {
+ ThrowLastError("ReadFile failed in IoBufferBuilder::ReadFromFileMaybe");
+ }
+
+ ZEN_ASSERT(dwNumberOfBytesRead == FileRef.FileChunkSize);
+
+ return OutBuffer;
+ }
+ else
+ {
+ return InBuffer;
+ }
+}
+
+IoBuffer
+IoBufferBuilder::MakeFromFileHandle(void* FileHandle, uint64_t Offset, uint64_t Size)
+{
+ return IoBuffer(IoBuffer::BorrowedFile, FileHandle, Offset, Size);
+}
+
+IoBuffer
+IoBufferBuilder::MakeFromFile(const std::filesystem::path& FileName, uint64_t Offset, uint64_t Size)
+{
+ uint64_t FileSize;
+
+#if ZEN_PLATFORM_WINDOWS
+ CAtlFile DataFile;
+
+ DWORD ShareOptions = FILE_SHARE_DELETE | FILE_SHARE_WRITE | FILE_SHARE_DELETE | FILE_SHARE_READ;
+ HRESULT hRes = DataFile.Create(FileName.c_str(), GENERIC_READ, ShareOptions, OPEN_EXISTING);
+
+ if (FAILED(hRes))
+ {
+ return {};
+ }
+
+ DataFile.GetSize((ULONGLONG&)FileSize);
+#else
+ int Flags = O_RDONLY | O_CLOEXEC;
+ int Fd = open(FileName.c_str(), Flags);
+ if (Fd < 0)
+ {
+ return {};
+ }
+
+ static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
+ struct stat Stat;
+ fstat(Fd, &Stat);
+ FileSize = Stat.st_size;
+#endif // ZEN_PLATFORM_WINDOWS
+
+ // TODO: should validate that offset is in range
+
+ if (Size == ~0ull)
+ {
+ Size = FileSize - Offset;
+ }
+ else
+ {
+ // Clamp size
+ if ((Offset + Size) > FileSize)
+ {
+ Size = FileSize - Offset;
+ }
+ }
+
+ if (Size)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ void* Fd = DataFile.Detach();
+#endif
+ IoBuffer Iob(IoBuffer::File, (void*)uintptr_t(Fd), Offset, Size);
+ Iob.m_Core->SetIsWholeFile(Offset == 0 && Size == FileSize);
+ return Iob;
+ }
+
+#if !ZEN_PLATFORM_WINDOWS
+ close(Fd);
+#endif
+
+ // For an empty file, we may as well just return an empty memory IoBuffer
+ return IoBuffer(IoBuffer::Wrap, "", 0);
+}
+
+IoBuffer
+IoBufferBuilder::MakeFromTemporaryFile(const std::filesystem::path& FileName)
+{
+ uint64_t FileSize;
+ void* Handle;
+
+#if ZEN_PLATFORM_WINDOWS
+ CAtlFile DataFile;
+
+ // We need to open with DELETE since this is used for the case
+ // when a file has been written to a staging directory, and is going
+ // to be moved in place
+
+ HRESULT hRes = DataFile.Create(FileName.native().c_str(), GENERIC_READ | DELETE, FILE_SHARE_READ | FILE_SHARE_DELETE, OPEN_EXISTING);
+
+ if (FAILED(hRes))
+ {
+ return {};
+ }
+
+ DataFile.GetSize((ULONGLONG&)FileSize);
+
+ Handle = DataFile.Detach();
+#else
+ int Fd = open(FileName.native().c_str(), O_RDONLY);
+ if (Fd < 0)
+ {
+ return {};
+ }
+
+ static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
+ struct stat Stat;
+ fstat(Fd, &Stat);
+ FileSize = Stat.st_size;
+
+ Handle = (void*)uintptr_t(Fd);
+#endif // ZEN_PLATFORM_WINDOWS
+
+ IoBuffer Iob(IoBuffer::File, Handle, 0, FileSize);
+ Iob.m_Core->SetIsWholeFile(true);
+
+ return Iob;
+}
+
+IoHash
+HashBuffer(IoBuffer& Buffer)
+{
+ // TODO: handle disk buffers with special path
+ return IoHash::HashBuffer(Buffer.Data(), Buffer.Size());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+void
+iobuffer_forcelink()
+{
+}
+
+TEST_CASE("IoBuffer")
+{
+ zen::IoBuffer buffer1;
+ zen::IoBuffer buffer2(16384);
+ zen::IoBuffer buffer3(buffer2, 0, buffer2.Size());
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/iohash.cpp b/src/zencore/iohash.cpp
new file mode 100644
index 000000000..77076c133
--- /dev/null
+++ b/src/zencore/iohash.cpp
@@ -0,0 +1,87 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/iohash.h>
+
+#include <zencore/blake3.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+const IoHash IoHash::Zero{}; // Initialized to all zeros
+
+IoHash
+IoHash::HashBuffer(const void* data, size_t byteCount)
+{
+ BLAKE3 b3 = BLAKE3::HashMemory(data, byteCount);
+
+ IoHash io;
+ memcpy(io.Hash, b3.Hash, sizeof io.Hash);
+
+ return io;
+}
+
+IoHash
+IoHash::HashBuffer(const CompositeBuffer& Buffer)
+{
+ IoHashStream Hasher;
+
+ for (const SharedBuffer& Segment : Buffer.GetSegments())
+ {
+ Hasher.Append(Segment.GetData(), Segment.GetSize());
+ }
+
+ return Hasher.GetHash();
+}
+
+IoHash
+IoHash::FromHexString(const char* string)
+{
+ return FromHexString({string, sizeof(IoHash::Hash) * 2});
+}
+
+IoHash
+IoHash::FromHexString(std::string_view string)
+{
+ ZEN_ASSERT(string.size() == 2 * sizeof(IoHash::Hash));
+
+ IoHash io;
+
+ ParseHexBytes(string.data(), string.size(), io.Hash);
+
+ return io;
+}
+
+const char*
+IoHash::ToHexString(char* outString /* 40 characters + NUL terminator */) const
+{
+ ToHexBytes(Hash, sizeof(IoHash), outString);
+ outString[2 * sizeof(IoHash)] = '\0';
+
+ return outString;
+}
+
+StringBuilderBase&
+IoHash::ToHexString(StringBuilderBase& outBuilder) const
+{
+ String_t Str;
+ ToHexString(Str);
+
+ outBuilder.AppendRange(Str, &Str[StringLength]);
+
+ return outBuilder;
+}
+
+std::string
+IoHash::ToHexString() const
+{
+ String_t Str;
+ ToHexString(Str);
+
+ return Str;
+}
+
+} // namespace zen
diff --git a/src/zencore/logging.cpp b/src/zencore/logging.cpp
new file mode 100644
index 000000000..a6423e2dc
--- /dev/null
+++ b/src/zencore/logging.cpp
@@ -0,0 +1,85 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/logging.h"
+
+#include <spdlog/sinks/stdout_color_sinks.h>
+
+namespace zen {
+
+// We shadow the underlying spdlog default logger, in order to avoid a bunch of overhead
+spdlog::logger* TheDefaultLogger;
+
+} // namespace zen
+
+namespace zen::logging {
+
+spdlog::logger&
+Default()
+{
+ return *TheDefaultLogger;
+}
+
+void
+SetDefault(std::shared_ptr<spdlog::logger> NewDefaultLogger)
+{
+ spdlog::set_default_logger(NewDefaultLogger);
+ TheDefaultLogger = spdlog::default_logger_raw();
+}
+
+spdlog::logger&
+Get(std::string_view Name)
+{
+ std::shared_ptr<spdlog::logger> Logger = spdlog::get(std::string(Name));
+
+ if (!Logger)
+ {
+ Logger = Default().clone(std::string(Name));
+ spdlog::register_logger(Logger);
+ }
+
+ return *Logger;
+}
+
+std::once_flag ConsoleInitFlag;
+std::shared_ptr<spdlog::logger> ConLogger;
+
+spdlog::logger&
+ConsoleLog()
+{
+ std::call_once(ConsoleInitFlag, [&] {
+ ConLogger = spdlog::stdout_color_mt("console");
+
+ ConLogger->set_pattern("%v");
+ });
+
+ return *ConLogger;
+}
+
+std::shared_ptr<spdlog::logger> TheErrorLogger;
+
+spdlog::logger*
+ErrorLog()
+{
+ return TheErrorLogger.get();
+}
+
+void
+SetErrorLog(std::shared_ptr<spdlog::logger>&& NewErrorLogger)
+{
+ TheErrorLogger = std::move(NewErrorLogger);
+}
+
+void
+InitializeLogging()
+{
+ TheDefaultLogger = spdlog::default_logger_raw();
+}
+
+void
+ShutdownLogging()
+{
+ spdlog::drop_all();
+ spdlog::shutdown();
+}
+
+} // namespace zen::logging
diff --git a/src/zencore/md5.cpp b/src/zencore/md5.cpp
new file mode 100644
index 000000000..4ec145697
--- /dev/null
+++ b/src/zencore/md5.cpp
@@ -0,0 +1,463 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/md5.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/zencore.h>
+
+#include <string.h>
+#include <string_view>
+
+/*
+ **********************************************************************
+ ** md5.h -- Header file for implementation of MD5 **
+ ** RSA Data Security, Inc. MD5 Message Digest Algorithm **
+ ** Created: 2/17/90 RLR **
+ ** Revised: 12/27/90 SRD,AJ,BSK,JT Reference C version **
+ ** Revised (for MD5): RLR 4/27/91 **
+ ** -- G modified to have y&~z instead of y&z **
+ ** -- FF, GG, HH modified to add in last register done **
+ ** -- Access pattern: round 2 works mod 5, round 3 works mod 3 **
+ ** -- distinct additive constant for each step **
+ ** -- round 4 added, working mod 7 **
+ **********************************************************************
+ */
+
+/*
+ **********************************************************************
+ ** Copyright (C) 1990, RSA Data Security, Inc. All rights reserved. **
+ ** **
+ ** License to copy and use this software is granted provided that **
+ ** it is identified as the "RSA Data Security, Inc. MD5 Message **
+ ** Digest Algorithm" in all material mentioning or referencing this **
+ ** software or this function. **
+ ** **
+ ** License is also granted to make and use derivative works **
+ ** provided that such works are identified as "derived from the RSA **
+ ** Data Security, Inc. MD5 Message Digest Algorithm" in all **
+ ** material mentioning or referencing the derived work. **
+ ** **
+ ** RSA Data Security, Inc. makes no representations concerning **
+ ** either the merchantability of this software or the suitability **
+ ** of this software for any particular purpose. It is provided "as **
+ ** is" without express or implied warranty of any kind. **
+ ** **
+ ** These notices must be retained in any copies of any part of this **
+ ** documentation and/or software. **
+ **********************************************************************
+ */
+
+/* Data structure for MD5 (Message Digest) computation */
+struct MD5_CTX
+{
+ uint32_t i[2]; /* number of _bits_ handled mod 2^64 */
+ uint32_t buf[4]; /* scratch buffer */
+ unsigned char in[64]; /* input buffer */
+ unsigned char digest[16]; /* actual digest after MD5Final call */
+};
+
+void MD5Init();
+void MD5Update();
+void MD5Final();
+
+/*
+ **********************************************************************
+ ** End of md5.h **
+ ******************************* (cut) ********************************
+ */
+
+/*
+ **********************************************************************
+ ** md5.c **
+ ** RSA Data Security, Inc. MD5 Message Digest Algorithm **
+ ** Created: 2/17/90 RLR **
+ ** Revised: 1/91 SRD,AJ,BSK,JT Reference C Version **
+ **********************************************************************
+ */
+
+/*
+ **********************************************************************
+ ** Copyright (C) 1990, RSA Data Security, Inc. All rights reserved. **
+ ** **
+ ** License to copy and use this software is granted provided that **
+ ** it is identified as the "RSA Data Security, Inc. MD5 Message **
+ ** Digest Algorithm" in all material mentioning or referencing this **
+ ** software or this function. **
+ ** **
+ ** License is also granted to make and use derivative works **
+ ** provided that such works are identified as "derived from the RSA **
+ ** Data Security, Inc. MD5 Message Digest Algorithm" in all **
+ ** material mentioning or referencing the derived work. **
+ ** **
+ ** RSA Data Security, Inc. makes no representations concerning **
+ ** either the merchantability of this software or the suitability **
+ ** of this software for any particular purpose. It is provided "as **
+ ** is" without express or implied warranty of any kind. **
+ ** **
+ ** These notices must be retained in any copies of any part of this **
+ ** documentation and/or software. **
+ **********************************************************************
+ */
+
+/* -- include the following line if the md5.h header file is separate -- */
+/* #include "md5.h" */
+
+/* forward declaration */
+static void Transform(uint32_t* buf, uint32_t* in);
+
+static unsigned char PADDING[64] = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+
+/* F, G and H are basic MD5 functions: selection, majority, parity */
+#define F(x, y, z) (((x) & (y)) | ((~x) & (z)))
+#define G(x, y, z) (((x) & (z)) | ((y) & (~z)))
+#define H(x, y, z) ((x) ^ (y) ^ (z))
+#define I(x, y, z) ((y) ^ ((x) | (~z)))
+
+/* ROTATE_LEFT rotates x left n bits */
+#define ROTATE_LEFT(x, n) (((x) << (n)) | ((x) >> (32 - (n))))
+
+/* FF, GG, HH, and II transformations for rounds 1, 2, 3, and 4 */
+/* Rotation is separate from addition to prevent recomputation */
+#define FF(a, b, c, d, x, s, ac) \
+ { \
+ (a) += F((b), (c), (d)) + (x) + (uint32_t)(ac); \
+ (a) = ROTATE_LEFT((a), (s)); \
+ (a) += (b); \
+ }
+#define GG(a, b, c, d, x, s, ac) \
+ { \
+ (a) += G((b), (c), (d)) + (x) + (uint32_t)(ac); \
+ (a) = ROTATE_LEFT((a), (s)); \
+ (a) += (b); \
+ }
+#define HH(a, b, c, d, x, s, ac) \
+ { \
+ (a) += H((b), (c), (d)) + (x) + (uint32_t)(ac); \
+ (a) = ROTATE_LEFT((a), (s)); \
+ (a) += (b); \
+ }
+#define II(a, b, c, d, x, s, ac) \
+ { \
+ (a) += I((b), (c), (d)) + (x) + (uint32_t)(ac); \
+ (a) = ROTATE_LEFT((a), (s)); \
+ (a) += (b); \
+ }
+
+void
+MD5Init(MD5_CTX* mdContext)
+{
+ mdContext->i[0] = mdContext->i[1] = (uint32_t)0;
+
+ /* Load magic initialization constants.
+ */
+ mdContext->buf[0] = (uint32_t)0x67452301;
+ mdContext->buf[1] = (uint32_t)0xefcdab89;
+ mdContext->buf[2] = (uint32_t)0x98badcfe;
+ mdContext->buf[3] = (uint32_t)0x10325476;
+}
+
+void
+MD5Update(MD5_CTX* mdContext, unsigned char* inBuf, unsigned int inLen)
+{
+ uint32_t in[16];
+ int mdi;
+ unsigned int i, ii;
+
+ /* compute number of bytes mod 64 */
+ mdi = (int)((mdContext->i[0] >> 3) & 0x3F);
+
+ /* update number of bits */
+ if ((mdContext->i[0] + ((uint32_t)inLen << 3)) < mdContext->i[0])
+ mdContext->i[1]++;
+ mdContext->i[0] += ((uint32_t)inLen << 3);
+ mdContext->i[1] += ((uint32_t)inLen >> 29);
+
+ while (inLen--)
+ {
+ /* add new character to buffer, increment mdi */
+ mdContext->in[mdi++] = *inBuf++;
+
+ /* transform if necessary */
+ if (mdi == 0x40)
+ {
+ for (i = 0, ii = 0; i < 16; i++, ii += 4)
+ in[i] = (((uint32_t)mdContext->in[ii + 3]) << 24) | (((uint32_t)mdContext->in[ii + 2]) << 16) |
+ (((uint32_t)mdContext->in[ii + 1]) << 8) | ((uint32_t)mdContext->in[ii]);
+ Transform(mdContext->buf, in);
+ mdi = 0;
+ }
+ }
+}
+
+void
+MD5Final(MD5_CTX* mdContext)
+{
+ uint32_t in[16];
+ int mdi;
+ unsigned int i, ii;
+ unsigned int padLen;
+
+ /* save number of bits */
+ in[14] = mdContext->i[0];
+ in[15] = mdContext->i[1];
+
+ /* compute number of bytes mod 64 */
+ mdi = (int)((mdContext->i[0] >> 3) & 0x3F);
+
+ /* pad out to 56 mod 64 */
+ padLen = (mdi < 56) ? (56 - mdi) : (120 - mdi);
+ MD5Update(mdContext, PADDING, padLen);
+
+ /* append length in bits and transform */
+ for (i = 0, ii = 0; i < 14; i++, ii += 4)
+ in[i] = (((uint32_t)mdContext->in[ii + 3]) << 24) | (((uint32_t)mdContext->in[ii + 2]) << 16) |
+ (((uint32_t)mdContext->in[ii + 1]) << 8) | ((uint32_t)mdContext->in[ii]);
+ Transform(mdContext->buf, in);
+
+ /* store buffer in digest */
+ for (i = 0, ii = 0; i < 4; i++, ii += 4)
+ {
+ mdContext->digest[ii] = (unsigned char)(mdContext->buf[i] & 0xFF);
+ mdContext->digest[ii + 1] = (unsigned char)((mdContext->buf[i] >> 8) & 0xFF);
+ mdContext->digest[ii + 2] = (unsigned char)((mdContext->buf[i] >> 16) & 0xFF);
+ mdContext->digest[ii + 3] = (unsigned char)((mdContext->buf[i] >> 24) & 0xFF);
+ }
+}
+
+/* Basic MD5 step. Transform buf based on in.
+ */
+static void
+Transform(uint32_t* buf, uint32_t* in)
+{
+ uint32_t a = buf[0], b = buf[1], c = buf[2], d = buf[3];
+
+ /* Round 1 */
+#define S11 7
+#define S12 12
+#define S13 17
+#define S14 22
+ FF(a, b, c, d, in[0], S11, 3614090360); /* 1 */
+ FF(d, a, b, c, in[1], S12, 3905402710); /* 2 */
+ FF(c, d, a, b, in[2], S13, 606105819); /* 3 */
+ FF(b, c, d, a, in[3], S14, 3250441966); /* 4 */
+ FF(a, b, c, d, in[4], S11, 4118548399); /* 5 */
+ FF(d, a, b, c, in[5], S12, 1200080426); /* 6 */
+ FF(c, d, a, b, in[6], S13, 2821735955); /* 7 */
+ FF(b, c, d, a, in[7], S14, 4249261313); /* 8 */
+ FF(a, b, c, d, in[8], S11, 1770035416); /* 9 */
+ FF(d, a, b, c, in[9], S12, 2336552879); /* 10 */
+ FF(c, d, a, b, in[10], S13, 4294925233); /* 11 */
+ FF(b, c, d, a, in[11], S14, 2304563134); /* 12 */
+ FF(a, b, c, d, in[12], S11, 1804603682); /* 13 */
+ FF(d, a, b, c, in[13], S12, 4254626195); /* 14 */
+ FF(c, d, a, b, in[14], S13, 2792965006); /* 15 */
+ FF(b, c, d, a, in[15], S14, 1236535329); /* 16 */
+
+ /* Round 2 */
+#define S21 5
+#define S22 9
+#define S23 14
+#define S24 20
+ GG(a, b, c, d, in[1], S21, 4129170786); /* 17 */
+ GG(d, a, b, c, in[6], S22, 3225465664); /* 18 */
+ GG(c, d, a, b, in[11], S23, 643717713); /* 19 */
+ GG(b, c, d, a, in[0], S24, 3921069994); /* 20 */
+ GG(a, b, c, d, in[5], S21, 3593408605); /* 21 */
+ GG(d, a, b, c, in[10], S22, 38016083); /* 22 */
+ GG(c, d, a, b, in[15], S23, 3634488961); /* 23 */
+ GG(b, c, d, a, in[4], S24, 3889429448); /* 24 */
+ GG(a, b, c, d, in[9], S21, 568446438); /* 25 */
+ GG(d, a, b, c, in[14], S22, 3275163606); /* 26 */
+ GG(c, d, a, b, in[3], S23, 4107603335); /* 27 */
+ GG(b, c, d, a, in[8], S24, 1163531501); /* 28 */
+ GG(a, b, c, d, in[13], S21, 2850285829); /* 29 */
+ GG(d, a, b, c, in[2], S22, 4243563512); /* 30 */
+ GG(c, d, a, b, in[7], S23, 1735328473); /* 31 */
+ GG(b, c, d, a, in[12], S24, 2368359562); /* 32 */
+
+ /* Round 3 */
+#define S31 4
+#define S32 11
+#define S33 16
+#define S34 23
+ HH(a, b, c, d, in[5], S31, 4294588738); /* 33 */
+ HH(d, a, b, c, in[8], S32, 2272392833); /* 34 */
+ HH(c, d, a, b, in[11], S33, 1839030562); /* 35 */
+ HH(b, c, d, a, in[14], S34, 4259657740); /* 36 */
+ HH(a, b, c, d, in[1], S31, 2763975236); /* 37 */
+ HH(d, a, b, c, in[4], S32, 1272893353); /* 38 */
+ HH(c, d, a, b, in[7], S33, 4139469664); /* 39 */
+ HH(b, c, d, a, in[10], S34, 3200236656); /* 40 */
+ HH(a, b, c, d, in[13], S31, 681279174); /* 41 */
+ HH(d, a, b, c, in[0], S32, 3936430074); /* 42 */
+ HH(c, d, a, b, in[3], S33, 3572445317); /* 43 */
+ HH(b, c, d, a, in[6], S34, 76029189); /* 44 */
+ HH(a, b, c, d, in[9], S31, 3654602809); /* 45 */
+ HH(d, a, b, c, in[12], S32, 3873151461); /* 46 */
+ HH(c, d, a, b, in[15], S33, 530742520); /* 47 */
+ HH(b, c, d, a, in[2], S34, 3299628645); /* 48 */
+
+ /* Round 4 */
+#define S41 6
+#define S42 10
+#define S43 15
+#define S44 21
+ II(a, b, c, d, in[0], S41, 4096336452); /* 49 */
+ II(d, a, b, c, in[7], S42, 1126891415); /* 50 */
+ II(c, d, a, b, in[14], S43, 2878612391); /* 51 */
+ II(b, c, d, a, in[5], S44, 4237533241); /* 52 */
+ II(a, b, c, d, in[12], S41, 1700485571); /* 53 */
+ II(d, a, b, c, in[3], S42, 2399980690); /* 54 */
+ II(c, d, a, b, in[10], S43, 4293915773); /* 55 */
+ II(b, c, d, a, in[1], S44, 2240044497); /* 56 */
+ II(a, b, c, d, in[8], S41, 1873313359); /* 57 */
+ II(d, a, b, c, in[15], S42, 4264355552); /* 58 */
+ II(c, d, a, b, in[6], S43, 2734768916); /* 59 */
+ II(b, c, d, a, in[13], S44, 1309151649); /* 60 */
+ II(a, b, c, d, in[4], S41, 4149444226); /* 61 */
+ II(d, a, b, c, in[11], S42, 3174756917); /* 62 */
+ II(c, d, a, b, in[2], S43, 718787259); /* 63 */
+ II(b, c, d, a, in[9], S44, 3951481745); /* 64 */
+
+ buf[0] += a;
+ buf[1] += b;
+ buf[2] += c;
+ buf[3] += d;
+}
+
+/*
+ **********************************************************************
+ ** End of md5.c **
+ ******************************* (cut) ********************************
+ */
+
+#undef FF
+#undef GG
+#undef HH
+#undef II
+#undef F
+#undef G
+#undef H
+#undef I
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+MD5 MD5::Zero; // Initialized to all zeroes
+
+//////////////////////////////////////////////////////////////////////////
+
+MD5Stream::MD5Stream()
+{
+ Reset();
+}
+
+void
+MD5Stream::Reset()
+{
+}
+
+MD5Stream&
+MD5Stream::Append(const void* Data, size_t ByteCount)
+{
+ ZEN_UNUSED(Data);
+ ZEN_UNUSED(ByteCount);
+
+ return *this;
+}
+
+MD5
+MD5Stream::GetHash()
+{
+ MD5 md5{};
+
+ return md5;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+MD5
+MD5::HashMemory(const void* data, size_t byteCount)
+{
+ return MD5Stream().Append(data, byteCount).GetHash();
+}
+
+MD5
+MD5::FromHexString(const char* string)
+{
+ MD5 md5;
+
+ ParseHexBytes(string, 40, md5.Hash);
+
+ return md5;
+}
+
+const char*
+MD5::ToHexString(char* outString /* 32 characters + NUL terminator */) const
+{
+ ToHexBytes(Hash, sizeof(MD5), outString);
+ outString[2 * sizeof(MD5)] = '\0';
+
+ return outString;
+}
+
+StringBuilderBase&
+MD5::ToHexString(StringBuilderBase& outBuilder) const
+{
+ char str[41];
+ ToHexString(str);
+
+ outBuilder.AppendRange(str, &str[40]);
+
+ return outBuilder;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+void
+md5_forcelink()
+{
+}
+
+// doctest::String
+// toString(const MD5& value)
+// {
+// char md5text[2 * sizeof(MD5) + 1];
+// value.ToHexString(md5text);
+
+// return md5text;
+// }
+
+TEST_CASE("MD5")
+{
+ using namespace std::literals;
+
+ auto Input = "jumblesmcgee"sv;
+ auto Output = "28f2200a59c60b75947099d750c2cc50"sv;
+
+ MD5Stream Stream;
+ Stream.Append(Input.data(), Input.length());
+ MD5 Result = Stream.GetHash();
+
+ MD5::String_t Buffer;
+ Result.ToHexString(Buffer);
+
+ CHECK(Output.compare(Buffer));
+
+ MD5 Reresult = MD5::FromHexString(Buffer);
+ Reresult.ToHexString(Buffer);
+ CHECK(Output.compare(Buffer));
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/memory.cpp b/src/zencore/memory.cpp
new file mode 100644
index 000000000..1f148cede
--- /dev/null
+++ b/src/zencore/memory.cpp
@@ -0,0 +1,211 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/intmath.h>
+#include <zencore/memory.h>
+#include <zencore/testing.h>
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <malloc.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <mimalloc.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+#else
+# include <cstdlib>
+#endif
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+static void*
+AlignedAllocImpl(size_t Size, size_t Alignment)
+{
+#if ZEN_PLATFORM_WINDOWS
+# if ZEN_USE_MIMALLOC && 0 /* this path is not functional */
+ return mi_aligned_alloc(Alignment, Size);
+# else
+ return _aligned_malloc(Size, Alignment);
+# endif
+#else
+ // aligned_alloc() states that size must be a multiple of alignment. Some
+ // platforms return null if this requirement isn't met.
+ Size = (Size + Alignment - 1) & ~(Alignment - 1);
+ return std::aligned_alloc(Alignment, Size);
+#endif
+}
+
+void
+AlignedFreeImpl(void* ptr)
+{
+ if (ptr == nullptr)
+ return;
+
+#if ZEN_PLATFORM_WINDOWS
+# if ZEN_USE_MIMALLOC && 0 /* this path is not functional */
+ return mi_free(ptr);
+# else
+ _aligned_free(ptr);
+# endif
+#else
+ std::free(ptr);
+#endif
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+MemoryArena::MemoryArena()
+{
+}
+
+MemoryArena::~MemoryArena()
+{
+}
+
+void*
+MemoryArena::Alloc(size_t Size, size_t Alignment)
+{
+ return AlignedAllocImpl(Size, Alignment);
+}
+
+void
+MemoryArena::Free(void* ptr)
+{
+ AlignedFreeImpl(ptr);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void*
+Memory::Alloc(size_t Size, size_t Alignment)
+{
+ return AlignedAllocImpl(Size, Alignment);
+}
+
+void
+Memory::Free(void* ptr)
+{
+ AlignedFreeImpl(ptr);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ChunkingLinearAllocator::ChunkingLinearAllocator(uint64_t ChunkSize, uint64_t ChunkAlignment)
+: m_ChunkSize(ChunkSize)
+, m_ChunkAlignment(ChunkAlignment)
+{
+}
+
+ChunkingLinearAllocator::~ChunkingLinearAllocator()
+{
+ Reset();
+}
+
+void
+ChunkingLinearAllocator::Reset()
+{
+ for (void* ChunkEntry : m_ChunkList)
+ {
+ Memory::Free(ChunkEntry);
+ }
+ m_ChunkList.clear();
+
+ m_ChunkCursor = nullptr;
+ m_ChunkBytesRemain = 0;
+}
+
+void*
+ChunkingLinearAllocator::Alloc(size_t Size, size_t Alignment)
+{
+ ZEN_ASSERT_SLOW(zen::IsPow2(Alignment));
+
+ // This could be improved in a bunch of ways
+ //
+ // * We pessimistically allocate memory even though there may be enough memory available for a single allocation due to the way we take
+ // alignment into account below
+ // * The block allocation size could be chosen to minimize slack for the case when multiple oversize allocations are made rather than
+ // minimizing the number of chunks
+ // * ...
+
+ const uint64_t AllocationSize = zen::RoundUp(Size, Alignment);
+
+ if (m_ChunkBytesRemain < (AllocationSize + Alignment - 1))
+ {
+ const uint64_t ChunkSize = zen::RoundUp(zen::Max(m_ChunkSize, Size), m_ChunkSize);
+ void* ChunkPtr = Memory::Alloc(ChunkSize, m_ChunkAlignment);
+ m_ChunkCursor = reinterpret_cast<uint8_t*>(ChunkPtr);
+ m_ChunkBytesRemain = ChunkSize;
+ m_ChunkList.push_back(ChunkPtr);
+ }
+
+ const uint64_t AlignFixup = (Alignment - reinterpret_cast<uintptr_t>(m_ChunkCursor)) & (Alignment - 1);
+ void* ReturnPtr = m_ChunkCursor + AlignFixup;
+ const uint64_t Delta = AlignFixup + AllocationSize;
+
+ ZEN_ASSERT_SLOW(m_ChunkBytesRemain >= Delta);
+
+ m_ChunkCursor += Delta;
+ m_ChunkBytesRemain -= Delta;
+
+ ZEN_ASSERT_SLOW(IsPointerAligned(ReturnPtr, Alignment));
+
+ return ReturnPtr;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Unit tests
+//
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("ChunkingLinearAllocator")
+{
+ ChunkingLinearAllocator Allocator(4096);
+
+ void* p1 = Allocator.Alloc(1, 1);
+ void* p2 = Allocator.Alloc(1, 1);
+
+ CHECK(p1 != p2);
+
+ void* p3 = Allocator.Alloc(1, 4);
+ CHECK(IsPointerAligned(p3, 4));
+
+ void* p3_2 = Allocator.Alloc(1, 4);
+ CHECK(IsPointerAligned(p3_2, 4));
+
+ void* p4 = Allocator.Alloc(1, 8);
+ CHECK(IsPointerAligned(p4, 8));
+
+ for (int i = 0; i < 100; ++i)
+ {
+ void* p0 = Allocator.Alloc(64);
+ ZEN_UNUSED(p0);
+ }
+}
+
+TEST_CASE("MemoryView")
+{
+ {
+ uint8_t Array1[16] = {};
+ MemoryView View1 = MakeMemoryView(Array1);
+ CHECK(View1.GetSize() == 16);
+ }
+
+ {
+ uint32_t Array2[16] = {};
+ MemoryView View2 = MakeMemoryView(Array2);
+ CHECK(View2.GetSize() == 64);
+ }
+
+ CHECK(MakeMemoryView<float>({1.0f, 1.2f}).GetSize() == 8);
+}
+
+void
+memory_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/mpscqueue.cpp b/src/zencore/mpscqueue.cpp
new file mode 100644
index 000000000..29c76c3ca
--- /dev/null
+++ b/src/zencore/mpscqueue.cpp
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/mpscqueue.h>
+
+#include <zencore/testing.h>
+#include <string>
+
+namespace zen {
+
+#if ZEN_WITH_TESTS && 0
+TEST_CASE("mpsc")
+{
+ MpscQueue<std::string> Queue;
+ Queue.Enqueue("hello");
+ std::optional<std::string> Value = Queue.Dequeue();
+ CHECK_EQ(Value, "hello");
+}
+#endif
+
+void
+mpscqueue_forcelink()
+{
+}
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zencore/refcount.cpp b/src/zencore/refcount.cpp
new file mode 100644
index 000000000..c6c47b04d
--- /dev/null
+++ b/src/zencore/refcount.cpp
@@ -0,0 +1,65 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/refcount.h>
+
+#include <zencore/testing.h>
+
+#include <functional>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+struct TestRefClass : public RefCounted
+{
+ ~TestRefClass()
+ {
+ if (OnDestroy)
+ OnDestroy();
+ }
+
+ using RefCounted::RefCount;
+
+ std::function<void()> OnDestroy;
+};
+
+void
+refcount_forcelink()
+{
+}
+
+TEST_CASE("RefPtr")
+{
+ RefPtr<TestRefClass> Ref;
+ Ref = new TestRefClass;
+
+ bool IsDestroyed = false;
+ Ref->OnDestroy = [&] { IsDestroyed = true; };
+
+ CHECK(IsDestroyed == false);
+ CHECK(Ref->RefCount() == 1);
+
+ RefPtr<TestRefClass> Ref2;
+ Ref2 = Ref;
+
+ CHECK(IsDestroyed == false);
+ CHECK(Ref->RefCount() == 2);
+
+ RefPtr<TestRefClass> Ref3;
+ Ref2 = Ref3;
+
+ CHECK(IsDestroyed == false);
+ CHECK(Ref->RefCount() == 1);
+ Ref = Ref3;
+
+ CHECK(IsDestroyed == true);
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/session.cpp b/src/zencore/session.cpp
new file mode 100644
index 000000000..ce4bfae1b
--- /dev/null
+++ b/src/zencore/session.cpp
@@ -0,0 +1,35 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/session.h"
+
+#include <zencore/uid.h>
+
+#include <mutex>
+
+namespace zen {
+
+static Oid GlobalSessionId;
+static char GlobalSessionString[Oid::StringLength];
+static std::once_flag SessionInitFlag;
+
+Oid
+GetSessionId()
+{
+ std::call_once(SessionInitFlag, [&] {
+ GlobalSessionId.Generate();
+ GlobalSessionId.ToString(GlobalSessionString);
+ });
+
+ return GlobalSessionId;
+}
+
+std::string_view
+GetSessionIdString()
+{
+ // Ensure we actually have a generated session identifier
+ std::ignore = GetSessionId();
+
+ return std::string_view(GlobalSessionString, Oid::StringLength);
+}
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zencore/sha1.cpp b/src/zencore/sha1.cpp
new file mode 100644
index 000000000..3ee74d7d8
--- /dev/null
+++ b/src/zencore/sha1.cpp
@@ -0,0 +1,443 @@
+// //////////////////////////////////////////////////////////
+// sha1.cpp
+// Copyright (c) 2014,2015 Stephan Brumme. All rights reserved.
+// see http://create.stephan-brumme.com/disclaimer.html
+//
+
+#include <zencore/sha1.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/zencore.h>
+
+#include <string.h>
+
+// big endian architectures need #define __BYTE_ORDER __BIG_ENDIAN
+#if ZEN_PLATFORM_LINUX
+# include <endian.h>
+#endif
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+SHA1 SHA1::Zero; // Initialized to all zeroes
+
+//////////////////////////////////////////////////////////////////////////
+
+SHA1Stream::SHA1Stream()
+{
+ Reset();
+}
+
+void
+SHA1Stream::Reset()
+{
+ m_NumBytes = 0;
+ m_BufferSize = 0;
+
+ // according to RFC 1321
+ m_Hash[0] = 0x67452301;
+ m_Hash[1] = 0xefcdab89;
+ m_Hash[2] = 0x98badcfe;
+ m_Hash[3] = 0x10325476;
+ m_Hash[4] = 0xc3d2e1f0;
+}
+
+namespace {
+ // mix functions for processBlock()
+ inline uint32_t f1(uint32_t b, uint32_t c, uint32_t d)
+ {
+ return d ^ (b & (c ^ d)); // original: f = (b & c) | ((~b) & d);
+ }
+
+ inline uint32_t f2(uint32_t b, uint32_t c, uint32_t d) { return b ^ c ^ d; }
+
+ inline uint32_t f3(uint32_t b, uint32_t c, uint32_t d) { return (b & c) | (b & d) | (c & d); }
+
+ inline uint32_t rotate(uint32_t a, uint32_t c) { return (a << c) | (a >> (32 - c)); }
+
+ inline uint32_t swap(uint32_t x)
+ {
+#if defined(__GNUC__) || defined(__clang__)
+ return __builtin_bswap32(x);
+#endif
+#ifdef MSC_VER
+ return _byteswap_ulong(x);
+#endif
+
+ return (x >> 24) | ((x >> 8) & 0x0000FF00) | ((x << 8) & 0x00FF0000) | (x << 24);
+ }
+} // namespace
+
+/// process 64 bytes
+void
+SHA1Stream::ProcessBlock(const void* data)
+{
+ // get last hash
+ uint32_t a = m_Hash[0];
+ uint32_t b = m_Hash[1];
+ uint32_t c = m_Hash[2];
+ uint32_t d = m_Hash[3];
+ uint32_t e = m_Hash[4];
+
+ // data represented as 16x 32-bit words
+ const uint32_t* input = (uint32_t*)data;
+ // convert to big endian
+ uint32_t words[80];
+ for (int i = 0; i < 16; i++)
+#if defined(__BYTE_ORDER) && (__BYTE_ORDER != 0) && (__BYTE_ORDER == __BIG_ENDIAN)
+ words[i] = input[i];
+#else
+ words[i] = swap(input[i]);
+#endif
+
+ // extend to 80 words
+ for (int i = 16; i < 80; i++)
+ words[i] = rotate(words[i - 3] ^ words[i - 8] ^ words[i - 14] ^ words[i - 16], 1);
+
+ // first round
+ for (int i = 0; i < 4; i++)
+ {
+ int offset = 5 * i;
+ e += rotate(a, 5) + f1(b, c, d) + words[offset] + 0x5a827999;
+ b = rotate(b, 30);
+ d += rotate(e, 5) + f1(a, b, c) + words[offset + 1] + 0x5a827999;
+ a = rotate(a, 30);
+ c += rotate(d, 5) + f1(e, a, b) + words[offset + 2] + 0x5a827999;
+ e = rotate(e, 30);
+ b += rotate(c, 5) + f1(d, e, a) + words[offset + 3] + 0x5a827999;
+ d = rotate(d, 30);
+ a += rotate(b, 5) + f1(c, d, e) + words[offset + 4] + 0x5a827999;
+ c = rotate(c, 30);
+ }
+
+ // second round
+ for (int i = 4; i < 8; i++)
+ {
+ int offset = 5 * i;
+ e += rotate(a, 5) + f2(b, c, d) + words[offset] + 0x6ed9eba1;
+ b = rotate(b, 30);
+ d += rotate(e, 5) + f2(a, b, c) + words[offset + 1] + 0x6ed9eba1;
+ a = rotate(a, 30);
+ c += rotate(d, 5) + f2(e, a, b) + words[offset + 2] + 0x6ed9eba1;
+ e = rotate(e, 30);
+ b += rotate(c, 5) + f2(d, e, a) + words[offset + 3] + 0x6ed9eba1;
+ d = rotate(d, 30);
+ a += rotate(b, 5) + f2(c, d, e) + words[offset + 4] + 0x6ed9eba1;
+ c = rotate(c, 30);
+ }
+
+ // third round
+ for (int i = 8; i < 12; i++)
+ {
+ int offset = 5 * i;
+ e += rotate(a, 5) + f3(b, c, d) + words[offset] + 0x8f1bbcdc;
+ b = rotate(b, 30);
+ d += rotate(e, 5) + f3(a, b, c) + words[offset + 1] + 0x8f1bbcdc;
+ a = rotate(a, 30);
+ c += rotate(d, 5) + f3(e, a, b) + words[offset + 2] + 0x8f1bbcdc;
+ e = rotate(e, 30);
+ b += rotate(c, 5) + f3(d, e, a) + words[offset + 3] + 0x8f1bbcdc;
+ d = rotate(d, 30);
+ a += rotate(b, 5) + f3(c, d, e) + words[offset + 4] + 0x8f1bbcdc;
+ c = rotate(c, 30);
+ }
+
+ // fourth round
+ for (int i = 12; i < 16; i++)
+ {
+ int offset = 5 * i;
+ e += rotate(a, 5) + f2(b, c, d) + words[offset] + 0xca62c1d6;
+ b = rotate(b, 30);
+ d += rotate(e, 5) + f2(a, b, c) + words[offset + 1] + 0xca62c1d6;
+ a = rotate(a, 30);
+ c += rotate(d, 5) + f2(e, a, b) + words[offset + 2] + 0xca62c1d6;
+ e = rotate(e, 30);
+ b += rotate(c, 5) + f2(d, e, a) + words[offset + 3] + 0xca62c1d6;
+ d = rotate(d, 30);
+ a += rotate(b, 5) + f2(c, d, e) + words[offset + 4] + 0xca62c1d6;
+ c = rotate(c, 30);
+ }
+
+ // update hash
+ m_Hash[0] += a;
+ m_Hash[1] += b;
+ m_Hash[2] += c;
+ m_Hash[3] += d;
+ m_Hash[4] += e;
+}
+
+/// add arbitrary number of bytes
+SHA1Stream&
+SHA1Stream::Append(const void* data, size_t byteCount)
+{
+ const uint8_t* current = (const uint8_t*)data;
+
+ if (m_BufferSize > 0)
+ {
+ while (byteCount > 0 && m_BufferSize < BlockSize)
+ {
+ m_Buffer[m_BufferSize++] = *current++;
+ byteCount--;
+ }
+ }
+
+ // full buffer
+ if (m_BufferSize == BlockSize)
+ {
+ ProcessBlock((void*)m_Buffer);
+ m_NumBytes += BlockSize;
+ m_BufferSize = 0;
+ }
+
+ // no more data ?
+ if (byteCount == 0)
+ return *this;
+
+ // process full blocks
+ while (byteCount >= BlockSize)
+ {
+ ProcessBlock(current);
+ current += BlockSize;
+ m_NumBytes += BlockSize;
+ byteCount -= BlockSize;
+ }
+
+ // keep remaining bytes in buffer
+ while (byteCount > 0)
+ {
+ m_Buffer[m_BufferSize++] = *current++;
+ byteCount--;
+ }
+
+ return *this;
+}
+
+/// process final block, less than 64 bytes
+void
+SHA1Stream::ProcessBuffer()
+{
+ // the input bytes are considered as bits strings, where the first bit is the most significant bit of the byte
+
+ // - append "1" bit to message
+ // - append "0" bits until message length in bit mod 512 is 448
+ // - append length as 64 bit integer
+
+ // number of bits
+ size_t paddedLength = m_BufferSize * 8;
+
+ // plus one bit set to 1 (always appended)
+ paddedLength++;
+
+ // number of bits must be (numBits % 512) = 448
+ size_t lower11Bits = paddedLength & 511;
+ if (lower11Bits <= 448)
+ paddedLength += 448 - lower11Bits;
+ else
+ paddedLength += 512 + 448 - lower11Bits;
+ // convert from bits to bytes
+ paddedLength /= 8;
+
+ // only needed if additional data flows over into a second block
+ unsigned char extra[BlockSize];
+
+ // append a "1" bit, 128 => binary 10000000
+ if (m_BufferSize < BlockSize)
+ m_Buffer[m_BufferSize] = 128;
+ else
+ extra[0] = 128;
+
+ size_t i;
+ for (i = m_BufferSize + 1; i < BlockSize; i++)
+ m_Buffer[i] = 0;
+ for (; i < paddedLength; i++)
+ extra[i - BlockSize] = 0;
+
+ // add message length in bits as 64 bit number
+ uint64_t msgBits = 8 * (m_NumBytes + m_BufferSize);
+ // find right position
+ unsigned char* addLength;
+ if (paddedLength < BlockSize)
+ addLength = m_Buffer + paddedLength;
+ else
+ addLength = extra + paddedLength - BlockSize;
+
+ // must be big endian
+ *addLength++ = (unsigned char)((msgBits >> 56) & 0xFF);
+ *addLength++ = (unsigned char)((msgBits >> 48) & 0xFF);
+ *addLength++ = (unsigned char)((msgBits >> 40) & 0xFF);
+ *addLength++ = (unsigned char)((msgBits >> 32) & 0xFF);
+ *addLength++ = (unsigned char)((msgBits >> 24) & 0xFF);
+ *addLength++ = (unsigned char)((msgBits >> 16) & 0xFF);
+ *addLength++ = (unsigned char)((msgBits >> 8) & 0xFF);
+ *addLength = (unsigned char)(msgBits & 0xFF);
+
+ // process blocks
+ ProcessBlock(m_Buffer);
+ // flowed over into a second block ?
+ if (paddedLength > BlockSize)
+ ProcessBlock(extra);
+}
+
+/// return latest hash as bytes
+SHA1
+SHA1Stream::GetHash()
+{
+ SHA1 sha1;
+ // save old hash if buffer is partially filled
+ uint32_t oldHash[HashValues];
+ for (int i = 0; i < HashValues; i++)
+ oldHash[i] = m_Hash[i];
+
+ // process remaining bytes
+ ProcessBuffer();
+
+ unsigned char* current = sha1.Hash;
+ for (int i = 0; i < HashValues; i++)
+ {
+ *current++ = (m_Hash[i] >> 24) & 0xFF;
+ *current++ = (m_Hash[i] >> 16) & 0xFF;
+ *current++ = (m_Hash[i] >> 8) & 0xFF;
+ *current++ = m_Hash[i] & 0xFF;
+
+ // restore old hash
+ m_Hash[i] = oldHash[i];
+ }
+
+ return sha1;
+}
+
+/// compute SHA1 of a memory block
+SHA1
+SHA1Stream::Compute(const void* data, size_t byteCount)
+{
+ Reset();
+ Append(data, byteCount);
+ return GetHash();
+}
+
+SHA1
+SHA1::HashMemory(const void* data, size_t byteCount)
+{
+ return SHA1Stream().Append(data, byteCount).GetHash();
+}
+
+SHA1
+SHA1::FromHexString(const char* string)
+{
+ SHA1 sha1;
+
+ ParseHexBytes(string, 40, sha1.Hash);
+
+ return sha1;
+}
+
+const char*
+SHA1::ToHexString(char* outString /* 40 characters + NUL terminator */) const
+{
+ ToHexBytes(Hash, sizeof(SHA1), outString);
+ outString[2 * sizeof(SHA1)] = '\0';
+
+ return outString;
+}
+
+StringBuilderBase&
+SHA1::ToHexString(StringBuilderBase& outBuilder) const
+{
+ char str[41];
+ ToHexString(str);
+
+ outBuilder.AppendRange(str, &str[40]);
+
+ return outBuilder;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+void
+sha1_forcelink()
+{
+}
+
+// doctest::String
+// toString(const SHA1& value)
+// {
+// char sha1text[2 * sizeof(SHA1) + 1];
+// value.ToHexString(sha1text);
+
+// return sha1text;
+// }
+
+TEST_CASE("SHA1")
+{
+ uint8_t sha1_empty[20] = {0xda, 0x39, 0xa3, 0xee, 0x5e, 0x6b, 0x4b, 0x0d, 0x32, 0x55,
+ 0xbf, 0xef, 0x95, 0x60, 0x18, 0x90, 0xaf, 0xd8, 0x07, 0x09};
+ SHA1 sha1z;
+ memcpy(sha1z.Hash, sha1_empty, sizeof sha1z.Hash);
+
+ SUBCASE("Empty string")
+ {
+ SHA1 sha1 = SHA1::HashMemory(nullptr, 0);
+
+ CHECK(sha1 == sha1z);
+ }
+
+ SUBCASE("Empty stream")
+ {
+ SHA1Stream sha1s;
+ sha1s.Append(nullptr, 0);
+ sha1s.Append(nullptr, 0);
+ sha1s.Append(nullptr, 0);
+ CHECK(sha1s.GetHash() == sha1z);
+ }
+
+ SUBCASE("SHA1 from string")
+ {
+ const SHA1 sha1empty = SHA1::FromHexString("da39a3ee5e6b4b0d3255bfef95601890afd80709");
+
+ CHECK(sha1z == sha1empty);
+ }
+
+ SUBCASE("SHA1 to string")
+ {
+ char sha1str[41];
+ sha1z.ToHexString(sha1str);
+
+ CHECK(StringEquals(sha1str, "da39a3ee5e6b4b0d3255bfef95601890afd80709"));
+ }
+
+ SUBCASE("Hash ABC")
+ {
+ const SHA1 sha1abc = SHA1::FromHexString("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8");
+
+ SHA1Stream sha1s;
+
+ sha1s.Append("A", 1);
+ sha1s.Append("B", 1);
+ sha1s.Append("C", 1);
+ CHECK(sha1s.GetHash() == sha1abc);
+
+ sha1s.Reset();
+ sha1s.Append("AB", 2);
+ sha1s.Append("C", 1);
+ CHECK(sha1s.GetHash() == sha1abc);
+
+ sha1s.Reset();
+ sha1s.Append("ABC", 3);
+ CHECK(sha1s.GetHash() == sha1abc);
+
+ sha1s.Reset();
+ sha1s.Append("A", 1);
+ sha1s.Append("BC", 2);
+ CHECK(sha1s.GetHash() == sha1abc);
+ }
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/sharedbuffer.cpp b/src/zencore/sharedbuffer.cpp
new file mode 100644
index 000000000..200e06972
--- /dev/null
+++ b/src/zencore/sharedbuffer.cpp
@@ -0,0 +1,146 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/sharedbuffer.h>
+
+#include <zencore/testing.h>
+
+#include <memory.h>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+UniqueBuffer
+UniqueBuffer::Alloc(uint64_t Size)
+{
+ void* Buffer = Memory::Alloc(Size, 16);
+ IoBufferCore* Owner = new IoBufferCore(Buffer, Size);
+ Owner->SetIsOwnedByThis(true);
+ Owner->SetIsImmutable(false);
+
+ return UniqueBuffer(Owner);
+}
+
+UniqueBuffer
+UniqueBuffer::MakeMutableView(void* DataPtr, uint64_t Size)
+{
+ IoBufferCore* Owner = new IoBufferCore(DataPtr, Size);
+ Owner->SetIsImmutable(false);
+ return UniqueBuffer(Owner);
+}
+
+UniqueBuffer::UniqueBuffer(IoBufferCore* Owner) : m_Buffer(Owner)
+{
+}
+
+SharedBuffer
+UniqueBuffer::MoveToShared()
+{
+ return SharedBuffer(std::move(m_Buffer));
+}
+
+void
+UniqueBuffer::Reset()
+{
+ m_Buffer = nullptr;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+SharedBuffer::SharedBuffer(UniqueBuffer&& InBuffer) : m_Buffer(std::move(InBuffer.m_Buffer))
+{
+}
+
+SharedBuffer
+SharedBuffer::MakeOwned() const&
+{
+ if (IsOwned() || !m_Buffer)
+ {
+ return *this;
+ }
+ else
+ {
+ return Clone(GetView());
+ }
+}
+
+SharedBuffer
+SharedBuffer::MakeOwned() &&
+{
+ if (IsOwned())
+ {
+ return std::move(*this);
+ }
+ else
+ {
+ return Clone(GetView());
+ }
+}
+
+SharedBuffer
+SharedBuffer::MakeView(MemoryView View, SharedBuffer OuterBuffer)
+{
+ if (OuterBuffer)
+ {
+ ZEN_ASSERT(OuterBuffer.GetView().Contains(View));
+ }
+
+ if (View == OuterBuffer.GetView())
+ {
+ // Reference to the full buffer contents, so just return the "outer"
+ return OuterBuffer;
+ }
+
+ IoBufferCore* NewCore = new IoBufferCore(OuterBuffer.m_Buffer, View.GetData(), View.GetSize());
+ NewCore->SetIsImmutable(true);
+ return SharedBuffer(NewCore);
+}
+
+SharedBuffer
+SharedBuffer::MakeView(const void* Data, uint64_t Size)
+{
+ return SharedBuffer(new IoBufferCore(const_cast<void*>(Data), Size));
+}
+
+SharedBuffer
+SharedBuffer::Clone()
+{
+ const uint64_t Size = GetSize();
+ void* Buffer = Memory::Alloc(Size, 16);
+ auto NewOwner = new IoBufferCore(Buffer, Size);
+ NewOwner->SetIsOwnedByThis(true);
+ memcpy(Buffer, m_Buffer->DataPointer(), Size);
+
+ return SharedBuffer(NewOwner);
+}
+
+SharedBuffer
+SharedBuffer::Clone(MemoryView View)
+{
+ const uint64_t Size = View.GetSize();
+ void* Buffer = Memory::Alloc(Size, 16);
+ auto NewOwner = new IoBufferCore(Buffer, Size);
+ NewOwner->SetIsOwnedByThis(true);
+ memcpy(Buffer, View.GetData(), Size);
+
+ return SharedBuffer(NewOwner);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+void
+sharedbuffer_forcelink()
+{
+}
+
+TEST_CASE("SharedBuffer")
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/stats.cpp b/src/zencore/stats.cpp
new file mode 100644
index 000000000..372bc42f8
--- /dev/null
+++ b/src/zencore/stats.cpp
@@ -0,0 +1,715 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/stats.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include "zencore/intmath.h"
+#include "zencore/thread.h"
+#include "zencore/timer.h"
+
+#include <cmath>
+#include <gsl/gsl-lite.hpp>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+#endif
+
+//
+// Derived from https://github.com/dln/medida/blob/master/src/medida/stats/ewma.cc
+//
+
+namespace zen::metrics {
+
+static constinit int kTickIntervalInSeconds = 5;
+static constinit double kSecondsPerMinute = 60.0;
+static constinit int kOneMinute = 1;
+static constinit int kFiveMinutes = 5;
+static constinit int kFifteenMinutes = 15;
+
+static const double kM1_ALPHA = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kOneMinute);
+static const double kM5_ALPHA = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kFiveMinutes);
+static const double kM15_ALPHA = 1.0 - std::exp(-kTickIntervalInSeconds / kSecondsPerMinute / kFifteenMinutes);
+
+static const uint64_t CountPerTick = GetHifreqTimerFrequencySafe() * kTickIntervalInSeconds;
+static const uint64_t CountPerSecond = GetHifreqTimerFrequencySafe();
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+RawEWMA::Tick(double Alpha, uint64_t Interval, uint64_t Count, bool IsInitialUpdate)
+{
+ const double InstantRate = double(Count) / Interval;
+
+ if (IsInitialUpdate)
+ {
+ m_Rate.store(InstantRate, std::memory_order_release);
+ }
+ else
+ {
+ double Delta = Alpha * (InstantRate - m_Rate);
+
+#if defined(__cpp_lib_atomic_float)
+ m_Rate.fetch_add(Delta);
+#else
+ double Value = m_Rate.load(std::memory_order_acquire);
+ double Next;
+ do
+ {
+ Next = Value + Delta;
+ } while (!m_Rate.compare_exchange_weak(Value, Next, std::memory_order_relaxed));
+#endif
+ }
+}
+
+double
+RawEWMA::Rate() const
+{
+ return m_Rate.load(std::memory_order_relaxed) * CountPerSecond;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+Meter::Meter() : m_StartTick{GetHifreqTimerValue()}, m_LastTick(m_StartTick.load())
+{
+}
+
+Meter::~Meter()
+{
+}
+
+void
+Meter::TickIfNecessary()
+{
+ uint64_t OldTick = m_LastTick.load();
+ const uint64_t NewTick = GetHifreqTimerValue();
+ const uint64_t Age = NewTick - OldTick;
+
+ if (Age > CountPerTick)
+ {
+ // Ensure only one thread at a time updates the time. This
+ // works because our tick interval should be sufficiently
+ // long to ensure two threads don't end up inside this block
+
+ if (m_LastTick.compare_exchange_strong(OldTick, NewTick))
+ {
+ m_Remainder.fetch_add(Age);
+
+ do
+ {
+ int64_t Remain = m_Remainder.load(std::memory_order_relaxed);
+
+ if (Remain < 0)
+ {
+ return;
+ }
+
+ if (m_Remainder.compare_exchange_strong(Remain, Remain - CountPerTick))
+ {
+ Tick();
+ }
+ } while (true);
+ }
+ }
+}
+
+void
+Meter::Tick()
+{
+ const uint64_t PendingCount = m_PendingCount.exchange(0);
+ const bool IsFirstTick = m_IsFirstTick;
+
+ if (IsFirstTick)
+ {
+ m_IsFirstTick = false;
+ }
+
+ m_RateM1.Tick(kM1_ALPHA, CountPerTick, PendingCount, IsFirstTick);
+ m_RateM5.Tick(kM5_ALPHA, CountPerTick, PendingCount, IsFirstTick);
+ m_RateM15.Tick(kM15_ALPHA, CountPerTick, PendingCount, IsFirstTick);
+}
+
+double
+Meter::Rate1()
+{
+ TickIfNecessary();
+
+ return m_RateM1.Rate();
+}
+
+double
+Meter::Rate5()
+{
+ TickIfNecessary();
+
+ return m_RateM5.Rate();
+}
+
+double
+Meter::Rate15()
+{
+ TickIfNecessary();
+
+ return m_RateM15.Rate();
+}
+
+double
+Meter::MeanRate() const
+{
+ const uint64_t Count = m_TotalCount.load(std::memory_order_relaxed);
+
+ if (Count == 0)
+ {
+ return 0.0;
+ }
+
+ const uint64_t Elapsed = GetHifreqTimerValue() - m_StartTick;
+
+ return (double(Count) * GetHifreqTimerFrequency()) / Elapsed;
+}
+
+void
+Meter::Mark(uint64_t Count)
+{
+ TickIfNecessary();
+
+ m_TotalCount.fetch_add(Count);
+ m_PendingCount.fetch_add(Count);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+// TODO: should consider a cheaper RNG here, this will run for every thread
+// that gets created
+
+thread_local std::mt19937_64 ThreadLocalRng;
+
+UniformSample::UniformSample(uint32_t ReservoirSize) : m_Values(ReservoirSize)
+{
+}
+
+UniformSample::~UniformSample()
+{
+}
+
+void
+UniformSample::Clear()
+{
+ for (auto& Value : m_Values)
+ {
+ Value.store(0);
+ }
+ m_SampleCounter = 0;
+}
+
+uint32_t
+UniformSample::Size() const
+{
+ return gsl::narrow_cast<uint32_t>(Min(m_SampleCounter.load(), m_Values.size()));
+}
+
+void
+UniformSample::Update(int64_t Value)
+{
+ const uint64_t Count = m_SampleCounter++;
+ const uint64_t Size = m_Values.size();
+
+ if (Count < Size)
+ {
+ m_Values[Count] = Value;
+ }
+ else
+ {
+ // Randomly choose an old entry to potentially replace (the probability
+ // of replacing an entry diminishes with time)
+
+ std::uniform_int_distribution<uint64_t> UniformDist(0, Count);
+ uint64_t SampleIndex = UniformDist(ThreadLocalRng);
+
+ if (SampleIndex < Size)
+ {
+ m_Values[SampleIndex].store(Value, std::memory_order_release);
+ }
+ }
+}
+
+SampleSnapshot
+UniformSample::Snapshot() const
+{
+ uint64_t ValuesSize = Size();
+ std::vector<double> Values(ValuesSize);
+
+ for (int i = 0, n = int(ValuesSize); i < n; ++i)
+ {
+ Values[i] = double(m_Values[i]);
+ }
+
+ return SampleSnapshot(std::move(Values));
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+Histogram::Histogram(int32_t SampleCount) : m_Sample(SampleCount)
+{
+}
+
+Histogram::~Histogram()
+{
+}
+
+void
+Histogram::Clear()
+{
+ m_Min = m_Max = m_Sum = m_Count = 0;
+ m_Sample.Clear();
+}
+
+void
+Histogram::Update(int64_t Value)
+{
+ m_Sample.Update(Value);
+
+ if (m_Count == 0)
+ {
+ m_Min = m_Max = Value;
+ }
+ else
+ {
+ int64_t CurrentMax = m_Max.load(std::memory_order_relaxed);
+
+ while ((CurrentMax < Value) && !m_Max.compare_exchange_weak(CurrentMax, Value))
+ {
+ }
+
+ int64_t CurrentMin = m_Min.load(std::memory_order_relaxed);
+
+ while ((CurrentMin > Value) && !m_Min.compare_exchange_weak(CurrentMin, Value))
+ {
+ }
+ }
+
+ m_Sum += Value;
+ ++m_Count;
+}
+
+int64_t
+Histogram::Max() const
+{
+ return m_Max.load(std::memory_order_relaxed);
+}
+
+int64_t
+Histogram::Min() const
+{
+ return m_Min.load(std::memory_order_relaxed);
+}
+
+double
+Histogram::Mean() const
+{
+ if (m_Count)
+ {
+ return double(m_Sum.load(std::memory_order_relaxed)) / m_Count;
+ }
+ else
+ {
+ return 0.0;
+ }
+}
+
+uint64_t
+Histogram::Count() const
+{
+ return m_Count.load(std::memory_order_relaxed);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+SampleSnapshot::SampleSnapshot(std::vector<double>&& Values) : m_Values(std::move(Values))
+{
+ std::sort(begin(m_Values), end(m_Values));
+}
+
+SampleSnapshot::~SampleSnapshot()
+{
+}
+
+double
+SampleSnapshot::GetQuantileValue(double Quantile)
+{
+ ZEN_ASSERT((Quantile >= 0.0) && (Quantile <= 1.0));
+
+ if (m_Values.empty())
+ {
+ return 0.0;
+ }
+
+ const double Pos = Quantile * (m_Values.size() + 1);
+
+ if (Pos < 1)
+ {
+ return m_Values.front();
+ }
+
+ if (Pos >= m_Values.size())
+ {
+ return m_Values.back();
+ }
+
+ const int32_t Index = (int32_t)Pos;
+ const double Lower = m_Values[Index - 1];
+ const double Upper = m_Values[Index];
+
+ // Lerp
+ return Lower + (Pos - std::floor(Pos)) * (Upper - Lower);
+}
+
+const std::vector<double>&
+SampleSnapshot::GetValues() const
+{
+ return m_Values;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+OperationTiming::OperationTiming(int32_t SampleCount) : m_Histogram{SampleCount}
+{
+}
+
+OperationTiming::~OperationTiming()
+{
+}
+
+void
+OperationTiming::Update(int64_t Duration)
+{
+ m_Meter.Mark(1);
+ m_Histogram.Update(Duration);
+}
+
+int64_t
+OperationTiming::Max() const
+{
+ return m_Histogram.Max();
+}
+
+int64_t
+OperationTiming::Min() const
+{
+ return m_Histogram.Min();
+}
+
+double
+OperationTiming::Mean() const
+{
+ return m_Histogram.Mean();
+}
+
+uint64_t
+OperationTiming::Count() const
+{
+ return m_Meter.Count();
+}
+
+OperationTiming::Scope::Scope(OperationTiming& Outer) : m_Outer(Outer), m_StartTick(GetHifreqTimerValue())
+{
+}
+
+OperationTiming::Scope::~Scope()
+{
+ Stop();
+}
+
+void
+OperationTiming::Scope::Stop()
+{
+ if (m_StartTick != 0)
+ {
+ m_Outer.Update(GetHifreqTimerValue() - m_StartTick);
+ m_StartTick = 0;
+ }
+}
+
+void
+OperationTiming::Scope::Cancel()
+{
+ m_StartTick = 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+RequestStats::RequestStats(int32_t SampleCount) : m_RequestTimeHistogram{SampleCount}, m_BytesHistogram{SampleCount}
+{
+}
+
+RequestStats::~RequestStats()
+{
+}
+
+void
+RequestStats::Update(int64_t Duration, int64_t Bytes)
+{
+ m_RequestMeter.Mark(1);
+ m_RequestTimeHistogram.Update(Duration);
+
+ m_BytesMeter.Mark(Bytes);
+ m_BytesHistogram.Update(Bytes);
+}
+
+uint64_t
+RequestStats::Count() const
+{
+ return m_RequestMeter.Count();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+EmitSnapshot(Meter& Stat, CbObjectWriter& Cbo)
+{
+ Cbo << "count" << Stat.Count();
+ Cbo << "rate_mean" << Stat.MeanRate();
+ Cbo << "rate_1" << Stat.Rate1() << "rate_5" << Stat.Rate5() << "rate_15" << Stat.Rate15();
+}
+
+void
+RequestStats::EmitSnapshot(std::string_view Tag, CbObjectWriter& Cbo)
+{
+ Cbo.BeginObject(Tag);
+
+ Cbo.BeginObject("requests");
+ metrics::EmitSnapshot(m_RequestMeter, Cbo);
+ metrics::EmitSnapshot(m_RequestTimeHistogram, Cbo, GetHifreqTimerToSeconds());
+ Cbo.EndObject();
+
+ Cbo.BeginObject("bytes");
+ metrics::EmitSnapshot(m_BytesMeter, Cbo);
+ metrics::EmitSnapshot(m_BytesHistogram, Cbo, 1.0);
+ Cbo.EndObject();
+
+ Cbo.EndObject();
+}
+
+void
+EmitSnapshot(std::string_view Tag, OperationTiming& Stat, CbObjectWriter& Cbo)
+{
+ Cbo.BeginObject(Tag);
+
+ SampleSnapshot Snap = Stat.Snapshot();
+
+ Cbo << "count" << Stat.Count();
+ Cbo << "rate_mean" << Stat.MeanRate();
+ Cbo << "rate_1" << Stat.Rate1() << "rate_5" << Stat.Rate5() << "rate_15" << Stat.Rate15();
+
+ const double ToSeconds = GetHifreqTimerToSeconds();
+
+ Cbo << "t_avg" << Stat.Mean() * ToSeconds;
+ Cbo << "t_min" << Stat.Min() * ToSeconds << "t_max" << Stat.Max() * ToSeconds;
+ Cbo << "t_p75" << Snap.Get75Percentile() * ToSeconds << "t_p95" << Snap.Get95Percentile() * ToSeconds << "t_p99"
+ << Snap.Get99Percentile() * ToSeconds << "t_p999" << Snap.Get999Percentile() * ToSeconds;
+
+ Cbo.EndObject();
+}
+
+void
+EmitSnapshot(std::string_view Tag, const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor)
+{
+ Cbo.BeginObject(Tag);
+ EmitSnapshot(Stat, Cbo, ConversionFactor);
+ Cbo.EndObject();
+}
+
+void
+EmitSnapshot(const Histogram& Stat, CbObjectWriter& Cbo, double ConversionFactor)
+{
+ SampleSnapshot Snap = Stat.Snapshot();
+
+ Cbo << "count" << Stat.Count() * ConversionFactor << "avg" << Stat.Mean() * ConversionFactor;
+ Cbo << "min" << Stat.Min() * ConversionFactor << "max" << Stat.Max() * ConversionFactor;
+ Cbo << "p75" << Snap.Get75Percentile() * ConversionFactor << "p95" << Snap.Get95Percentile() * ConversionFactor << "p99"
+ << Snap.Get99Percentile() * ConversionFactor << "p999" << Snap.Get999Percentile() * ConversionFactor;
+}
+
+void
+EmitSnapshot(std::string_view Tag, Meter& Stat, CbObjectWriter& Cbo)
+{
+ Cbo.BeginObject(Tag);
+
+ Cbo << "count" << Stat.Count() << "rate_mean" << Stat.MeanRate();
+ Cbo << "rate_1" << Stat.Rate1() << "rate_5" << Stat.Rate5() << "rate_15" << Stat.Rate15();
+
+ Cbo.EndObject();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("Core.Stats.Histogram")
+{
+ Histogram Histo{258};
+
+ SampleSnapshot Snap1 = Histo.Snapshot();
+ CHECK_EQ(Snap1.Size(), 0);
+ CHECK_EQ(Snap1.GetMedian(), 0);
+
+ Histo.Update(1);
+ CHECK_EQ(Histo.Min(), 1);
+ CHECK_EQ(Histo.Max(), 1);
+
+ SampleSnapshot Snap2 = Histo.Snapshot();
+ CHECK_EQ(Snap2.Size(), 1);
+
+ Histo.Update(2);
+ CHECK_EQ(Histo.Min(), 1);
+ CHECK_EQ(Histo.Max(), 2);
+
+ SampleSnapshot Snap3 = Histo.Snapshot();
+ CHECK_EQ(Snap3.Size(), 2);
+
+ Histo.Update(-2);
+ CHECK_EQ(Histo.Min(), -2);
+ CHECK_EQ(Histo.Max(), 2);
+ CHECK_EQ(Histo.Mean(), 1 / 3.0);
+
+ SampleSnapshot Snap4 = Histo.Snapshot();
+ CHECK_EQ(Snap4.Size(), 3);
+ CHECK_EQ(Snap4.GetMedian(), 1);
+ CHECK_EQ(Snap4.Get999Percentile(), 2);
+ CHECK_EQ(Snap4.GetQuantileValue(0), -2);
+}
+
+TEST_CASE("Core.Stats.UniformSample")
+{
+ UniformSample Sample1{100};
+
+ for (int i = 0; i < 100; ++i)
+ {
+ for (int j = 1; j <= 100; ++j)
+ {
+ Sample1.Update(j);
+ }
+ }
+
+ int64_t Sum = 0;
+ int64_t Count = 0;
+
+ Sample1.IterateValues([&](int64_t Value) {
+ ++Count;
+ Sum += Value;
+ });
+
+ double Average = double(Sum) / Count;
+
+ CHECK(fabs(Average - 50) < 10); // What's the right test here? The result could vary massively and still be technically correct
+}
+
+TEST_CASE("Core.Stats.EWMA")
+{
+ SUBCASE("Simple_1")
+ {
+ RawEWMA Ewma1;
+ Ewma1.Tick(kM1_ALPHA, CountPerSecond, 5, true);
+
+ CHECK(fabs(Ewma1.Rate() - 5) < 0.1);
+
+ for (int i = 0; i < 60; ++i)
+ {
+ Ewma1.Tick(kM1_ALPHA, CountPerSecond, 10, false);
+ }
+
+ CHECK(fabs(Ewma1.Rate() - 10) < 0.1);
+
+ for (int i = 0; i < 60; ++i)
+ {
+ Ewma1.Tick(kM1_ALPHA, CountPerSecond, 20, false);
+ }
+
+ CHECK(fabs(Ewma1.Rate() - 20) < 0.1);
+ }
+
+ SUBCASE("Simple_10")
+ {
+ RawEWMA Ewma1;
+ RawEWMA Ewma5;
+ RawEWMA Ewma15;
+ Ewma1.Tick(kM1_ALPHA, CountPerSecond, 5, true);
+ Ewma5.Tick(kM5_ALPHA, CountPerSecond, 5, true);
+ Ewma15.Tick(kM15_ALPHA, CountPerSecond, 5, true);
+
+ CHECK(fabs(Ewma1.Rate() - 5) < 0.1);
+ CHECK(fabs(Ewma5.Rate() - 5) < 0.1);
+ CHECK(fabs(Ewma15.Rate() - 5) < 0.1);
+
+ auto Tick1 = [&Ewma1](auto Value) { Ewma1.Tick(kM1_ALPHA, CountPerSecond, Value, false); };
+ auto Tick5 = [&Ewma5](auto Value) { Ewma5.Tick(kM5_ALPHA, CountPerSecond, Value, false); };
+ auto Tick15 = [&Ewma15](auto Value) { Ewma15.Tick(kM15_ALPHA, CountPerSecond, Value, false); };
+
+ for (int i = 0; i < 60; ++i)
+ {
+ Tick1(10);
+ Tick5(10);
+ Tick15(10);
+ }
+
+ CHECK(fabs(Ewma1.Rate() - 10) < 0.1);
+
+ for (int i = 0; i < 5 * 60; ++i)
+ {
+ Tick1(20);
+ Tick5(20);
+ Tick15(20);
+ }
+
+ CHECK(fabs(Ewma1.Rate() - 20) < 0.1);
+ CHECK(fabs(Ewma5.Rate() - 20) < 0.1);
+
+ for (int i = 0; i < 16 * 60; ++i)
+ {
+ Tick1(100);
+ Tick5(100);
+ Tick15(100);
+ }
+
+ CHECK(fabs(Ewma1.Rate() - 100) < 0.1);
+ CHECK(fabs(Ewma5.Rate() - 100) < 0.1);
+ CHECK(fabs(Ewma15.Rate() - 100) < 0.5);
+ }
+}
+
+# if 0 // This is not really a unit test, but mildly useful to exercise some code
+TEST_CASE("Meter")
+{
+ Meter Meter1;
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ Meter1.Mark(1);
+ Sleep(1000);
+ [[maybe_unused]] double Rate = Meter1.MeanRate();
+}
+# endif
+}
+
+namespace zen {
+
+void
+stats_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen::metrics
diff --git a/src/zencore/stream.cpp b/src/zencore/stream.cpp
new file mode 100644
index 000000000..3402e51be
--- /dev/null
+++ b/src/zencore/stream.cpp
@@ -0,0 +1,79 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <stdarg.h>
+#include <zencore/memory.h>
+#include <zencore/stream.h>
+#include <zencore/testing.h>
+
+#include <algorithm>
+#include <stdexcept>
+
+namespace zen {
+
+void
+BinaryWriter::Write(std::initializer_list<const MemoryView> Buffers)
+{
+ size_t TotalByteCount = 0;
+ for (const MemoryView& View : Buffers)
+ {
+ TotalByteCount += View.GetSize();
+ }
+ const size_t NeedEnd = m_Offset + TotalByteCount;
+ if (NeedEnd > m_Buffer.size())
+ {
+ m_Buffer.resize(NeedEnd);
+ }
+ for (const MemoryView& View : Buffers)
+ {
+ memcpy(m_Buffer.data() + m_Offset, View.GetData(), View.GetSize());
+ m_Offset += View.GetSize();
+ }
+}
+
+void
+BinaryWriter::Write(const void* data, size_t ByteCount, uint64_t Offset)
+{
+ const size_t NeedEnd = Offset + ByteCount;
+
+ if (NeedEnd > m_Buffer.size())
+ {
+ m_Buffer.resize(NeedEnd);
+ }
+
+ memcpy(m_Buffer.data() + Offset, data, ByteCount);
+}
+
+void
+BinaryWriter::Reset()
+{
+ m_Buffer.clear();
+ m_Offset = 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("binary.writer.span")
+{
+ BinaryWriter Writer;
+ const MemoryView View1("apa", 3);
+ const MemoryView View2(" ", 1);
+ const MemoryView View3("banan", 5);
+ Writer.Write({View1, View2, View3});
+ MemoryView Result = Writer.GetView();
+ CHECK(Result.GetSize() == 9);
+ CHECK(memcmp(Result.GetData(), "apa banan", 9) == 0);
+}
+
+void
+stream_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/string.cpp b/src/zencore/string.cpp
new file mode 100644
index 000000000..ad6ee78fc
--- /dev/null
+++ b/src/zencore/string.cpp
@@ -0,0 +1,1004 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/memory.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#include <inttypes.h>
+#include <math.h>
+#include <stdio.h>
+#include <exception>
+#include <ostream>
+#include <stdexcept>
+
+#include <utf8.h>
+
+template<typename u16bit_iterator>
+void
+utf16to8_impl(u16bit_iterator StartIt, u16bit_iterator EndIt, ::zen::StringBuilderBase& OutString)
+{
+ while (StartIt != EndIt)
+ {
+ uint32_t cp = utf8::internal::mask16(*StartIt++);
+ // Take care of surrogate pairs first
+ if (utf8::internal::is_lead_surrogate(cp))
+ {
+ uint32_t trail_surrogate = utf8::internal::mask16(*StartIt++);
+ cp = (cp << 10) + trail_surrogate + utf8::internal::SURROGATE_OFFSET;
+ }
+ OutString.AppendCodepoint(cp);
+ }
+}
+
+template<typename u32bit_iterator>
+void
+utf32to8_impl(u32bit_iterator StartIt, u32bit_iterator EndIt, ::zen::StringBuilderBase& OutString)
+{
+ for (; StartIt != EndIt; ++StartIt)
+ {
+ wchar_t cp = *StartIt;
+ OutString.AppendCodepoint(cp);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+bool
+ToString(std::span<char> Buffer, uint64_t Num)
+{
+ snprintf(Buffer.data(), Buffer.size(), "%" PRIu64, Num);
+
+ return true;
+}
+bool
+ToString(std::span<char> Buffer, int64_t Num)
+{
+ snprintf(Buffer.data(), Buffer.size(), "%" PRId64, Num);
+
+ return true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+const char*
+FilepathFindExtension(const std::string_view& Path, const char* ExtensionToMatch)
+{
+ const size_t PathLen = Path.size();
+
+ if (ExtensionToMatch)
+ {
+ size_t ExtLen = strlen(ExtensionToMatch);
+
+ if (ExtLen > PathLen)
+ return nullptr;
+
+ const char* PathExtension = Path.data() + PathLen - ExtLen;
+
+ if (StringEquals(PathExtension, ExtensionToMatch))
+ return PathExtension;
+
+ return nullptr;
+ }
+
+ if (PathLen == 0)
+ return nullptr;
+
+ // Look for extension introducer ('.')
+
+ for (int64_t i = PathLen - 1; i >= 0; --i)
+ {
+ if (Path[i] == '.')
+ return Path.data() + i;
+ }
+
+ return nullptr;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+Utf8ToWide(const char8_t* Str8, WideStringBuilderBase& OutString)
+{
+ Utf8ToWide(std::u8string_view(Str8), OutString);
+}
+
+void
+Utf8ToWide(const std::string_view& Str8, WideStringBuilderBase& OutString)
+{
+ Utf8ToWide(std::u8string_view{reinterpret_cast<const char8_t*>(Str8.data()), Str8.size()}, OutString);
+}
+
+std::wstring
+Utf8ToWide(const std::string_view& Wstr)
+{
+ ExtendableWideStringBuilder<128> String;
+ Utf8ToWide(Wstr, String);
+
+ return String.c_str();
+}
+
+void
+Utf8ToWide(const std::u8string_view& Str8, WideStringBuilderBase& OutString)
+{
+ const char* str = (const char*)Str8.data();
+ const size_t strLen = Str8.size();
+
+ const char* endStr = str + strLen;
+ size_t ByteCount = 0;
+ size_t CurrentOutChar = 0;
+
+ for (; str != endStr; ++str)
+ {
+ unsigned char Data = static_cast<unsigned char>(*str);
+
+ if (!(Data & 0x80))
+ {
+ // ASCII
+ OutString.Append(wchar_t(Data));
+ continue;
+ }
+ else if (!ByteCount)
+ {
+ // Start of multi-byte sequence. Figure out how
+ // many bytes we're going to consume
+
+ size_t Count = 0;
+
+ for (size_t Temp = Data; Temp & 0x80; Temp <<= 1)
+ ++Count;
+
+ ByteCount = Count - 1;
+ CurrentOutChar = Data & (0xff >> (Count + 1));
+ }
+ else
+ {
+ --ByteCount;
+
+ if ((Data & 0xc0) != 0x80)
+ {
+ break;
+ }
+
+ CurrentOutChar = (CurrentOutChar << 6) | (Data & 0x3f);
+
+ if (!ByteCount)
+ {
+ OutString.Append(wchar_t(CurrentOutChar));
+ CurrentOutChar = 0;
+ }
+ }
+ }
+}
+
+void
+WideToUtf8(const wchar_t* Wstr, StringBuilderBase& OutString)
+{
+ WideToUtf8(std::wstring_view{Wstr}, OutString);
+}
+
+void
+WideToUtf8(const std::wstring_view& Wstr, StringBuilderBase& OutString)
+{
+#if ZEN_SIZEOF_WCHAR_T == 2
+ utf16to8_impl(begin(Wstr), end(Wstr), OutString);
+#else
+ utf32to8_impl(begin(Wstr), end(Wstr), OutString);
+#endif
+}
+
+std::string
+WideToUtf8(const wchar_t* Wstr)
+{
+ ExtendableStringBuilder<128> String;
+ WideToUtf8(std::wstring_view{Wstr}, String);
+
+ return String.c_str();
+}
+
+std::string
+WideToUtf8(const std::wstring_view Wstr)
+{
+ ExtendableStringBuilder<128> String;
+ WideToUtf8(std::wstring_view{Wstr.data(), Wstr.size()}, String);
+
+ return String.c_str();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+enum NicenumFormat
+{
+ kNicenum1024 = 0, // Print kilo, mega, tera, peta, exa..
+ kNicenumBytes = 1, // Print single bytes ("13B"), kilo, mega, tera...
+ kNicenumTime = 2, // Print nanosecs, microsecs, millisecs, seconds...
+ kNicenumRaw = 3, // Print the raw number without any formatting
+ kNicenumRawTime = 4 // Same as RAW, but print dashes ('-') for zero.
+};
+
+namespace {
+ static const char* UnitStrings[3][7] = {
+ /* kNicenum1024 */ {"", "K", "M", "G", "T", "P", "E"},
+ /* kNicenumBytes */ {"B", "K", "M", "G", "T", "P", "E"},
+ /* kNicenumTime */ {"ns", "us", "ms", "s", "?", "?", "?"}};
+
+ static const int UnitsLen[] = {
+ /* kNicenum1024 */ 6,
+ /* kNicenumBytes */ 6,
+ /* kNicenumTime */ 3};
+
+ static const uint64_t KiloUnit[] = {
+ /* kNicenum1024 */ 1024,
+ /* kNicenumBytes */ 1024,
+ /* kNicenumTime */ 1000};
+} // namespace
+
+/*
+ * Convert a number to an appropriately human-readable output.
+ */
+int
+NiceNumGeneral(uint64_t Num, std::span<char> Buffer, NicenumFormat Format)
+{
+ switch (Format)
+ {
+ case kNicenumRaw:
+ return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64, (uint64_t)Num);
+
+ case kNicenumRawTime:
+ if (Num > 0)
+ {
+ return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64, (uint64_t)Num);
+ }
+ else
+ {
+ return snprintf(Buffer.data(), Buffer.size(), "%s", "-");
+ }
+ break;
+
+ case kNicenum1024:
+ case kNicenumBytes:
+ case kNicenumTime:
+ default:
+ break;
+ }
+
+ // Bring into range and select unit
+
+ int Index = 0;
+ uint64_t n = Num;
+
+ {
+ const uint64_t Unit = KiloUnit[Format];
+ const int maxIndex = UnitsLen[Format];
+
+ while (n >= Unit && Index < maxIndex)
+ {
+ n /= Unit;
+ Index++;
+ }
+ }
+
+ const char* u = UnitStrings[Format][Index];
+
+ if ((Index == 0) || ((Num % (uint64_t)powl((int)KiloUnit[Format], Index)) == 0))
+ {
+ /*
+ * If this is an even multiple of the base, always display
+ * without any decimal precision.
+ */
+ return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "%s", (uint64_t)n, u);
+ }
+ else
+ {
+ /*
+ * We want to choose a precision that reflects the best choice
+ * for fitting in 5 characters. This can get rather tricky when
+ * we have numbers that are very close to an order of magnitude.
+ * For example, when displaying 10239 (which is really 9.999K),
+ * we want only a single place of precision for 10.0K. We could
+ * develop some complex heuristics for this, but it's much
+ * easier just to try each combination in turn.
+ */
+
+ int StrLen = 0;
+
+ for (int i = 2; i >= 0; i--)
+ {
+ double Value = (double)Num / (uint64_t)powl((int)KiloUnit[Format], Index);
+
+ /*
+ * Don't print floating point values for time. Note,
+ * we use floor() instead of round() here, since
+ * round can result in undesirable results. For
+ * example, if "num" is in the range of
+ * 999500-999999, it will print out "1000us". This
+ * doesn't happen if we use floor().
+ */
+ if (Format == kNicenumTime)
+ {
+ StrLen = snprintf(Buffer.data(), Buffer.size(), "%d%s", (unsigned int)floor(Value), u);
+
+ if (StrLen <= 5)
+ break;
+ }
+ else
+ {
+ StrLen = snprintf(Buffer.data(), Buffer.size(), "%.*f%s", i, Value, u);
+
+ if (StrLen <= 5)
+ break;
+ }
+ }
+
+ return StrLen;
+ }
+}
+
+size_t
+NiceNumToBuffer(uint64_t Num, std::span<char> Buffer)
+{
+ return NiceNumGeneral(Num, Buffer, kNicenum1024);
+}
+
+size_t
+NiceBytesToBuffer(uint64_t Num, std::span<char> Buffer)
+{
+ return NiceNumGeneral(Num, Buffer, kNicenumBytes);
+}
+
+size_t
+NiceByteRateToBuffer(uint64_t Num, uint64_t ElapsedMs, std::span<char> Buffer)
+{
+ size_t n = 0;
+
+ if (ElapsedMs)
+ {
+ n = NiceNumGeneral(Num * 1000 / ElapsedMs, Buffer, kNicenumBytes);
+ }
+ else
+ {
+ Buffer[n++] = '0';
+ Buffer[n++] = 'B';
+ }
+
+ Buffer[n++] = '/';
+ Buffer[n++] = 's';
+ Buffer[n++] = '\0';
+
+ return n;
+}
+
+size_t
+NiceLatencyNsToBuffer(uint64_t Nanos, std::span<char> Buffer)
+{
+ return NiceNumGeneral(Nanos, Buffer, kNicenumTime);
+}
+
+size_t
+NiceTimeSpanMsToBuffer(uint64_t Millis, std::span<char> Buffer)
+{
+ if (Millis < 1000)
+ {
+ return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "ms", Millis);
+ }
+ else if (Millis < 10000)
+ {
+ return snprintf(Buffer.data(), Buffer.size(), "%.2fs", Millis / 1000.0);
+ }
+ else if (Millis < 60000)
+ {
+ return snprintf(Buffer.data(), Buffer.size(), "%.1fs", Millis / 1000.0);
+ }
+ else if (Millis < 60 * 60000)
+ {
+ return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "m%02" PRIu64 "s", Millis / 60000, (Millis / 1000) % 60);
+ }
+ else
+ {
+ return snprintf(Buffer.data(), Buffer.size(), "%" PRIu64 "h%02" PRIu64 "m", Millis / 3600000, (Millis / 60000) % 60);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+template<typename C>
+StringBuilderImpl<C>::~StringBuilderImpl()
+{
+ if (m_IsDynamic)
+ {
+ FreeBuffer(m_Base, m_End - m_Base);
+ }
+}
+
+template<typename C>
+void
+StringBuilderImpl<C>::Extend(size_t extraCapacity)
+{
+ if (!m_IsExtendable)
+ {
+ Fail("exceeded capacity");
+ }
+
+ const size_t oldCapacity = m_End - m_Base;
+ const size_t newCapacity = NextPow2(oldCapacity + extraCapacity);
+
+ C* newBase = (C*)AllocBuffer(newCapacity);
+
+ size_t pos = m_CurPos - m_Base;
+ memcpy(newBase, m_Base, pos * sizeof(C));
+
+ if (m_IsDynamic)
+ {
+ FreeBuffer(m_Base, oldCapacity);
+ }
+
+ m_Base = newBase;
+ m_CurPos = newBase + pos;
+ m_End = newBase + newCapacity;
+ m_IsDynamic = true;
+}
+
+template<typename C>
+void*
+StringBuilderImpl<C>::AllocBuffer(size_t byteCount)
+{
+ return Memory::Alloc(byteCount * sizeof(C));
+}
+
+template<typename C>
+void
+StringBuilderImpl<C>::FreeBuffer(void* buffer, size_t byteCount)
+{
+ ZEN_UNUSED(byteCount);
+
+ Memory::Free(buffer);
+}
+
+template<typename C>
+[[noreturn]] void
+StringBuilderImpl<C>::Fail(const char* reason)
+{
+ throw std::runtime_error(reason);
+}
+
+// Instantiate templates once
+
+template class StringBuilderImpl<char>;
+template class StringBuilderImpl<wchar_t>;
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Unit tests
+//
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("niceNum")
+{
+ char Buffer[16];
+
+ SUBCASE("raw")
+ {
+ NiceNumGeneral(1, Buffer, kNicenumRaw);
+ CHECK(StringEquals(Buffer, "1"));
+
+ NiceNumGeneral(10, Buffer, kNicenumRaw);
+ CHECK(StringEquals(Buffer, "10"));
+
+ NiceNumGeneral(100, Buffer, kNicenumRaw);
+ CHECK(StringEquals(Buffer, "100"));
+
+ NiceNumGeneral(1000, Buffer, kNicenumRaw);
+ CHECK(StringEquals(Buffer, "1000"));
+
+ NiceNumGeneral(10000, Buffer, kNicenumRaw);
+ CHECK(StringEquals(Buffer, "10000"));
+
+ NiceNumGeneral(100000, Buffer, kNicenumRaw);
+ CHECK(StringEquals(Buffer, "100000"));
+ }
+
+ SUBCASE("1024")
+ {
+ NiceNumGeneral(1, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1"));
+
+ NiceNumGeneral(10, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "10"));
+
+ NiceNumGeneral(100, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "100"));
+
+ NiceNumGeneral(1000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1000"));
+
+ NiceNumGeneral(10000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "9.77K"));
+
+ NiceNumGeneral(100000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "97.7K"));
+
+ NiceNumGeneral(1000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "977K"));
+
+ NiceNumGeneral(10000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "9.54M"));
+
+ NiceNumGeneral(100000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "95.4M"));
+
+ NiceNumGeneral(1000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "954M"));
+
+ NiceNumGeneral(10000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "9.31G"));
+
+ NiceNumGeneral(100000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "93.1G"));
+
+ NiceNumGeneral(1000000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "931G"));
+
+ NiceNumGeneral(10000000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "9.09T"));
+
+ NiceNumGeneral(100000000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "90.9T"));
+
+ NiceNumGeneral(1000000000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "909T"));
+
+ NiceNumGeneral(10000000000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "8.88P"));
+
+ NiceNumGeneral(100000000000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "88.8P"));
+
+ NiceNumGeneral(1000000000000000000, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "888P"));
+
+ NiceNumGeneral(10000000000000000000ull, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "8.67E"));
+
+ // pow2
+
+ NiceNumGeneral(0, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "0"));
+
+ NiceNumGeneral(1, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1"));
+
+ NiceNumGeneral(1024, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1K"));
+
+ NiceNumGeneral(1024 * 1024, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1M"));
+
+ NiceNumGeneral(1024 * 1024 * 1024, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1G"));
+
+ NiceNumGeneral(1024llu * 1024 * 1024 * 1024, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1T"));
+
+ NiceNumGeneral(1024llu * 1024 * 1024 * 1024 * 1024, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1P"));
+
+ NiceNumGeneral(1024llu * 1024 * 1024 * 1024 * 1024 * 1024, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1E"));
+
+ // pow2-1
+
+ NiceNumGeneral(1023, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "1023"));
+
+ NiceNumGeneral(2047, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "2.00K"));
+
+ NiceNumGeneral(9 * 1024 - 1, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "9.00K"));
+
+ NiceNumGeneral(10 * 1024 - 1, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "10.0K"));
+
+ NiceNumGeneral(10 * 1024 - 5, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "10.0K"));
+
+ NiceNumGeneral(10 * 1024 - 6, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "9.99K"));
+
+ NiceNumGeneral(10 * 1024 - 10, Buffer, kNicenum1024);
+ CHECK(StringEquals(Buffer, "9.99K"));
+ }
+
+ SUBCASE("time")
+ {
+ NiceNumGeneral(1, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "1ns"));
+
+ NiceNumGeneral(100, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "100ns"));
+
+ NiceNumGeneral(1000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "1us"));
+
+ NiceNumGeneral(10000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "10us"));
+
+ NiceNumGeneral(100000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "100us"));
+
+ NiceNumGeneral(1000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "1ms"));
+
+ NiceNumGeneral(10000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "10ms"));
+
+ NiceNumGeneral(100000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "100ms"));
+
+ NiceNumGeneral(1000000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "1s"));
+
+ NiceNumGeneral(10000000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "10s"));
+
+ NiceNumGeneral(100000000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "100s"));
+
+ NiceNumGeneral(1000000000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "1000s"));
+
+ NiceNumGeneral(10000000000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "10000s"));
+
+ NiceNumGeneral(100000000000000, Buffer, kNicenumTime);
+ CHECK(StringEquals(Buffer, "100000s"));
+ }
+
+ SUBCASE("bytes")
+ {
+ NiceNumGeneral(1, Buffer, kNicenumBytes);
+ CHECK(StringEquals(Buffer, "1B"));
+
+ NiceNumGeneral(10, Buffer, kNicenumBytes);
+ CHECK(StringEquals(Buffer, "10B"));
+
+ NiceNumGeneral(100, Buffer, kNicenumBytes);
+ CHECK(StringEquals(Buffer, "100B"));
+
+ NiceNumGeneral(1000, Buffer, kNicenumBytes);
+ CHECK(StringEquals(Buffer, "1000B"));
+
+ NiceNumGeneral(10000, Buffer, kNicenumBytes);
+ CHECK(StringEquals(Buffer, "9.77K"));
+ }
+
+ SUBCASE("byteRate")
+ {
+ NiceByteRateToBuffer(1, 1, Buffer);
+ CHECK(StringEquals(Buffer, "1000B/s"));
+
+ NiceByteRateToBuffer(1000, 1000, Buffer);
+ CHECK(StringEquals(Buffer, "1000B/s"));
+
+ NiceByteRateToBuffer(1024, 1, Buffer);
+ CHECK(StringEquals(Buffer, "1000K/s"));
+
+ NiceByteRateToBuffer(1024, 1000, Buffer);
+ CHECK(StringEquals(Buffer, "1K/s"));
+ }
+
+ SUBCASE("timespan")
+ {
+ NiceTimeSpanMsToBuffer(1, Buffer);
+ CHECK(StringEquals(Buffer, "1ms"));
+
+ NiceTimeSpanMsToBuffer(900, Buffer);
+ CHECK(StringEquals(Buffer, "900ms"));
+
+ NiceTimeSpanMsToBuffer(1000, Buffer);
+ CHECK(StringEquals(Buffer, "1.00s"));
+
+ NiceTimeSpanMsToBuffer(1900, Buffer);
+ CHECK(StringEquals(Buffer, "1.90s"));
+
+ NiceTimeSpanMsToBuffer(19000, Buffer);
+ CHECK(StringEquals(Buffer, "19.0s"));
+
+ NiceTimeSpanMsToBuffer(60000, Buffer);
+ CHECK(StringEquals(Buffer, "1m00s"));
+
+ NiceTimeSpanMsToBuffer(600000, Buffer);
+ CHECK(StringEquals(Buffer, "10m00s"));
+
+ NiceTimeSpanMsToBuffer(3600000, Buffer);
+ CHECK(StringEquals(Buffer, "1h00m"));
+
+ NiceTimeSpanMsToBuffer(36000000, Buffer);
+ CHECK(StringEquals(Buffer, "10h00m"));
+
+ NiceTimeSpanMsToBuffer(360000000, Buffer);
+ CHECK(StringEquals(Buffer, "100h00m"));
+ }
+}
+
+void
+string_forcelink()
+{
+}
+
+TEST_CASE("StringBuilder")
+{
+ StringBuilder<64> sb;
+
+ SUBCASE("Empty init")
+ {
+ const char* str = sb.c_str();
+
+ CHECK(StringLength(str) == 0);
+ }
+
+ SUBCASE("Append single character")
+ {
+ sb.Append('a');
+
+ const char* str = sb.c_str();
+ CHECK(StringLength(str) == 1);
+ CHECK(str[0] == 'a');
+
+ sb.Append('b');
+ str = sb.c_str();
+ CHECK(StringLength(str) == 2);
+ CHECK(str[0] == 'a');
+ CHECK(str[1] == 'b');
+ }
+
+ SUBCASE("Append string")
+ {
+ sb.Append("a");
+
+ const char* str = sb.c_str();
+ CHECK(StringLength(str) == 1);
+ CHECK(str[0] == 'a');
+
+ sb.Append("b");
+ str = sb.c_str();
+ CHECK(StringLength(str) == 2);
+ CHECK(str[0] == 'a');
+ CHECK(str[1] == 'b');
+
+ sb.Append("cdefghijklmnopqrstuvwxyz");
+ CHECK(sb.Size() == 26);
+
+ sb.Append("abcdefghijklmnopqrstuvwxyz");
+ CHECK(sb.Size() == 52);
+
+ sb.Append("abcdefghijk");
+ CHECK(sb.Size() == 63);
+ }
+}
+
+TEST_CASE("ExtendableStringBuilder")
+{
+ ExtendableStringBuilder<16> sb;
+
+ SUBCASE("Empty init")
+ {
+ const char* str = sb.c_str();
+
+ CHECK(StringLength(str) == 0);
+ }
+
+ SUBCASE("Short append")
+ {
+ sb.Append("abcd");
+ CHECK(sb.IsDynamic() == false);
+ }
+
+ SUBCASE("Short+long append")
+ {
+ sb.Append("abcd");
+ CHECK(sb.IsDynamic() == false);
+ // This should trigger a dynamic buffer allocation since the required
+ // capacity exceeds the internal fixed buffer.
+ sb.Append("abcdefghijklmnopqrstuvwxyz");
+ CHECK(sb.IsDynamic() == true);
+ CHECK(sb.Size() == 30);
+ CHECK(sb.Size() == StringLength(sb.c_str()));
+ }
+}
+
+TEST_CASE("WideStringBuilder")
+{
+ WideStringBuilder<64> sb;
+
+ SUBCASE("Empty init")
+ {
+ const wchar_t* str = sb.c_str();
+
+ CHECK(StringLength(str) == 0);
+ }
+
+ SUBCASE("Append single character")
+ {
+ sb.Append(L'a');
+
+ const wchar_t* str = sb.c_str();
+ CHECK(StringLength(str) == 1);
+ CHECK(str[0] == L'a');
+
+ sb.Append(L'b');
+ str = sb.c_str();
+ CHECK(StringLength(str) == 2);
+ CHECK(str[0] == L'a');
+ CHECK(str[1] == L'b');
+ }
+
+ SUBCASE("Append string")
+ {
+ sb.Append(L"a");
+
+ const wchar_t* str = sb.c_str();
+ CHECK(StringLength(str) == 1);
+ CHECK(str[0] == L'a');
+
+ sb.Append(L"b");
+ str = sb.c_str();
+ CHECK(StringLength(str) == 2);
+ CHECK(str[0] == L'a');
+ CHECK(str[1] == L'b');
+
+ sb.Append(L"cdefghijklmnopqrstuvwxyz");
+ CHECK(sb.Size() == 26);
+
+ sb.Append(L"abcdefghijklmnopqrstuvwxyz");
+ CHECK(sb.Size() == 52);
+
+ sb.Append(L"abcdefghijk");
+ CHECK(sb.Size() == 63);
+ }
+}
+
+TEST_CASE("ExtendableWideStringBuilder")
+{
+ ExtendableWideStringBuilder<16> sb;
+
+ SUBCASE("Empty init")
+ {
+ CHECK(sb.Size() == 0);
+
+ const wchar_t* str = sb.c_str();
+ CHECK(StringLength(str) == 0);
+ }
+
+ SUBCASE("Short append")
+ {
+ sb.Append(L"abcd");
+ CHECK(sb.IsDynamic() == false);
+ }
+
+ SUBCASE("Short+long append")
+ {
+ sb.Append(L"abcd");
+ CHECK(sb.IsDynamic() == false);
+ // This should trigger a dynamic buffer allocation since the required
+ // capacity exceeds the internal fixed buffer.
+ sb.Append(L"abcdefghijklmnopqrstuvwxyz");
+ CHECK(sb.IsDynamic() == true);
+ CHECK(sb.Size() == 30);
+ CHECK(sb.Size() == StringLength(sb.c_str()));
+ }
+}
+
+TEST_CASE("utf8")
+{
+ SUBCASE("utf8towide")
+ {
+ // TODO: add more extensive testing here - this covers a very small space
+
+ WideStringBuilder<32> wout;
+ Utf8ToWide(u8"abcdefghi", wout);
+ CHECK(StringEquals(L"abcdefghi", wout.c_str()));
+
+ wout.Reset();
+
+ Utf8ToWide(u8"abc���", wout);
+ CHECK(StringEquals(L"abc���", wout.c_str()));
+ }
+
+ SUBCASE("widetoutf8")
+ {
+ // TODO: add more extensive testing here - this covers a very small space
+
+ StringBuilder<32> out;
+
+ WideToUtf8(L"abcdefghi", out);
+ CHECK(StringEquals("abcdefghi", out.c_str()));
+
+ out.Reset();
+
+ WideToUtf8(L"abc���", out);
+ CHECK(StringEquals(u8"abc���", out.c_str()));
+ }
+}
+
+TEST_CASE("filepath")
+{
+ CHECK(FilepathFindExtension("foo\\bar\\baz.txt", ".txt") != nullptr);
+ CHECK(FilepathFindExtension("foo\\bar\\baz.txt", ".zap") == nullptr);
+
+ CHECK(FilepathFindExtension("foo\\bar\\baz.txt") != nullptr);
+ CHECK(FilepathFindExtension("foo\\bar\\baz.txt") == std::string_view(".txt"));
+
+ CHECK(FilepathFindExtension(".txt") == std::string_view(".txt"));
+}
+
+TEST_CASE("string")
+{
+ using namespace std::literals;
+
+ SUBCASE("hash_djb2")
+ {
+ CHECK(HashStringAsLowerDjb2("AbcdZ"sv) == HashStringDjb2("abcdz"sv));
+ CHECK(HashStringAsLowerDjb2("aBCd"sv) == HashStringDjb2("abcd"sv));
+ CHECK(HashStringAsLowerDjb2("aBCd"sv) == HashStringDjb2(ToLower("aBCd"sv)));
+ }
+
+ SUBCASE("tolower")
+ {
+ CHECK_EQ(ToLower("te!st"sv), "te!st"sv);
+ CHECK_EQ(ToLower("TE%St"sv), "te%st"sv);
+ }
+
+ SUBCASE("StrCaseCompare")
+ {
+ CHECK(StrCaseCompare("foo", "FoO") == 0);
+ CHECK(StrCaseCompare("Bar", "bAs") < 0);
+ CHECK(StrCaseCompare("bAr", "Bas") < 0);
+ CHECK(StrCaseCompare("BBr", "Bar") > 0);
+ CHECK(StrCaseCompare("Bbr", "BAr") > 0);
+ CHECK(StrCaseCompare("foo", "FoO", 3) == 0);
+ CHECK(StrCaseCompare("Bar", "bAs", 3) < 0);
+ CHECK(StrCaseCompare("BBr", "Bar", 2) > 0);
+ }
+
+ SUBCASE("ForEachStrTok")
+ {
+ const auto Tokens = "here,is,my,different,tokens"sv;
+ int32_t ExpectedTokenCount = 5;
+ int32_t TokenCount = 0;
+ StringBuilder<512> Sb;
+
+ TokenCount = ForEachStrTok(Tokens, ',', [&Sb](const std::string_view& Token) {
+ if (Sb.Size())
+ {
+ Sb << ",";
+ }
+ Sb << Token;
+ return true;
+ });
+
+ CHECK(TokenCount == ExpectedTokenCount);
+ CHECK(Sb.ToString() == Tokens);
+
+ ExpectedTokenCount = 1;
+ const auto Str = "mosdef"sv;
+
+ Sb.Reset();
+ TokenCount = ForEachStrTok(Str, ' ', [&Sb](const std::string_view& Token) {
+ Sb << Token;
+ return true;
+ });
+ CHECK(Sb.ToString() == Str);
+ CHECK(TokenCount == ExpectedTokenCount);
+
+ ExpectedTokenCount = 0;
+ TokenCount = ForEachStrTok(""sv, ',', [](const std::string_view&) { return true; });
+ CHECK(TokenCount == ExpectedTokenCount);
+ }
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/testing.cpp b/src/zencore/testing.cpp
new file mode 100644
index 000000000..1599e9d1f
--- /dev/null
+++ b/src/zencore/testing.cpp
@@ -0,0 +1,54 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/testing.h"
+#include "zencore/logging.h"
+
+#if ZEN_WITH_TESTS
+
+namespace zen::testing {
+
+using namespace std::literals;
+
+struct TestRunner::Impl
+{
+ doctest::Context Session;
+};
+
+TestRunner::TestRunner()
+{
+ m_Impl = std::make_unique<Impl>();
+}
+
+TestRunner::~TestRunner()
+{
+}
+
+int
+TestRunner::ApplyCommandLine(int argc, char const* const* argv)
+{
+ m_Impl->Session.applyCommandLine(argc, argv);
+
+ for (int i = 1; i < argc; ++i)
+ {
+ if (argv[i] == "--debug"sv)
+ {
+ spdlog::set_level(spdlog::level::debug);
+ }
+ }
+
+ return 0;
+}
+
+int
+TestRunner::Run()
+{
+ int Rv = 0;
+
+ m_Impl->Session.run();
+
+ return Rv;
+}
+
+} // namespace zen::testing
+
+#endif
diff --git a/src/zencore/testutils.cpp b/src/zencore/testutils.cpp
new file mode 100644
index 000000000..dbc3ab5af
--- /dev/null
+++ b/src/zencore/testutils.cpp
@@ -0,0 +1,42 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zencore/testutils.h"
+#include <zencore/session.h>
+#include "zencore/string.h"
+
+#include <atomic>
+
+namespace zen {
+
+static std::atomic<int> Sequence{0};
+
+std::filesystem::path
+CreateTemporaryDirectory()
+{
+ std::error_code Ec;
+
+ std::filesystem::path DirPath = std::filesystem::temp_directory_path() / GetSessionIdString() / IntNum(++Sequence).c_str();
+ std::filesystem::remove_all(DirPath, Ec);
+ std::filesystem::create_directories(DirPath);
+
+ return DirPath;
+}
+
+ScopedTemporaryDirectory::ScopedTemporaryDirectory() : m_RootPath(CreateTemporaryDirectory())
+{
+}
+
+ScopedTemporaryDirectory::ScopedTemporaryDirectory(std::filesystem::path Directory) : m_RootPath(Directory)
+{
+ std::error_code Ec;
+ std::filesystem::remove_all(Directory, Ec);
+ std::filesystem::create_directories(Directory);
+}
+
+ScopedTemporaryDirectory::~ScopedTemporaryDirectory()
+{
+ std::error_code Ec;
+ std::filesystem::remove_all(m_RootPath, Ec);
+}
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zencore/thread.cpp b/src/zencore/thread.cpp
new file mode 100644
index 000000000..1597a7dd9
--- /dev/null
+++ b/src/zencore/thread.cpp
@@ -0,0 +1,1212 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/thread.h>
+
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#if ZEN_PLATFORM_LINUX
+# if !defined(_GNU_SOURCE)
+# define _GNU_SOURCE // for semtimedop()
+# endif
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+# include <shellapi.h>
+# include <Shlobj.h>
+# include <zencore/windows.h>
+#else
+# include <chrono>
+# include <condition_variable>
+# include <mutex>
+
+# include <fcntl.h>
+# include <pthread.h>
+# include <signal.h>
+# include <sys/file.h>
+# include <sys/sem.h>
+# include <sys/stat.h>
+# include <sys/syscall.h>
+# include <sys/wait.h>
+# include <time.h>
+# include <unistd.h>
+#endif
+
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+#if ZEN_PLATFORM_WINDOWS
+// The information on how to set the thread name comes from
+// a MSDN article: http://msdn2.microsoft.com/en-us/library/xcb2z8hs.aspx
+const DWORD kVCThreadNameException = 0x406D1388;
+typedef struct tagTHREADNAME_INFO
+{
+ DWORD dwType; // Must be 0x1000.
+ LPCSTR szName; // Pointer to name (in user addr space).
+ DWORD dwThreadID; // Thread ID (-1=caller thread).
+ DWORD dwFlags; // Reserved for future use, must be zero.
+} THREADNAME_INFO;
+// The SetThreadDescription API was brought in version 1607 of Windows 10.
+typedef HRESULT(WINAPI* SetThreadDescription)(HANDLE hThread, PCWSTR lpThreadDescription);
+// This function has try handling, so it is separated out of its caller.
+void
+SetNameInternal(DWORD thread_id, const char* name)
+{
+ THREADNAME_INFO info;
+ info.dwType = 0x1000;
+ info.szName = name;
+ info.dwThreadID = thread_id;
+ info.dwFlags = 0;
+ __try
+ {
+ RaiseException(kVCThreadNameException, 0, sizeof(info) / sizeof(DWORD), reinterpret_cast<DWORD_PTR*>(&info));
+ }
+ __except (EXCEPTION_CONTINUE_EXECUTION)
+ {
+ }
+}
+#endif
+
+#if ZEN_PLATFORM_LINUX
+const bool bNoZombieChildren = []() {
+ // When a child process exits it is put into a zombie state until the parent
+ // collects its result. This doesn't fit the Windows-like model that Zen uses
+ // where there is a less strict familial model and no zombification. Ignoring
+ // SIGCHLD siganals removes the need to call wait() on zombies. Another option
+ // would be for the child to call setsid() but that would detatch the child
+ // from the terminal.
+ struct sigaction Action = {};
+ sigemptyset(&Action.sa_mask);
+ Action.sa_handler = SIG_IGN;
+ sigaction(SIGCHLD, &Action, nullptr);
+ return true;
+}();
+#endif
+
+void
+SetCurrentThreadName([[maybe_unused]] std::string_view ThreadName)
+{
+#if ZEN_PLATFORM_WINDOWS
+ // The SetThreadDescription API works even if no debugger is attached.
+ static auto SetThreadDescriptionFunc =
+ reinterpret_cast<SetThreadDescription>(::GetProcAddress(::GetModuleHandle(L"Kernel32.dll"), "SetThreadDescription"));
+
+ if (SetThreadDescriptionFunc)
+ {
+ SetThreadDescriptionFunc(::GetCurrentThread(), Utf8ToWide(ThreadName).c_str());
+ }
+ // The debugger needs to be around to catch the name in the exception. If
+ // there isn't a debugger, we are just needlessly throwing an exception.
+ if (!::IsDebuggerPresent())
+ return;
+
+ std::string ThreadNameZ{ThreadName};
+ SetNameInternal(GetCurrentThreadId(), ThreadNameZ.c_str());
+#else
+ std::string ThreadNameZ{ThreadName};
+# if ZEN_PLATFORM_MAC
+ pthread_setname_np(ThreadNameZ.c_str());
+# else
+ pthread_setname_np(pthread_self(), ThreadNameZ.c_str());
+# endif
+#endif
+} // namespace zen
+
+void
+RwLock::AcquireShared()
+{
+ m_Mutex.lock_shared();
+}
+
+void
+RwLock::ReleaseShared()
+{
+ m_Mutex.unlock_shared();
+}
+
+void
+RwLock::AcquireExclusive()
+{
+ m_Mutex.lock();
+}
+
+void
+RwLock::ReleaseExclusive()
+{
+ m_Mutex.unlock();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if !ZEN_PLATFORM_WINDOWS
+struct EventInner
+{
+ std::mutex Mutex;
+ std::condition_variable CondVar;
+ bool volatile bSet = false;
+};
+#endif // !ZEN_PLATFORM_WINDOWS
+
+Event::Event()
+{
+ bool bManualReset = true;
+ bool bInitialState = false;
+
+#if ZEN_PLATFORM_WINDOWS
+ m_EventHandle = CreateEvent(nullptr, bManualReset, bInitialState, nullptr);
+#else
+ ZEN_UNUSED(bManualReset);
+ auto* Inner = new EventInner();
+ Inner->bSet = bInitialState;
+ m_EventHandle = Inner;
+#endif
+}
+
+Event::~Event()
+{
+ Close();
+}
+
+void
+Event::Set()
+{
+#if ZEN_PLATFORM_WINDOWS
+ SetEvent(m_EventHandle);
+#else
+ auto* Inner = (EventInner*)m_EventHandle;
+ {
+ std::unique_lock Lock(Inner->Mutex);
+ Inner->bSet = true;
+ }
+ Inner->CondVar.notify_all();
+#endif
+}
+
+void
+Event::Reset()
+{
+#if ZEN_PLATFORM_WINDOWS
+ ResetEvent(m_EventHandle);
+#else
+ auto* Inner = (EventInner*)m_EventHandle;
+ {
+ std::unique_lock Lock(Inner->Mutex);
+ Inner->bSet = false;
+ }
+#endif
+}
+
+void
+Event::Close()
+{
+#if ZEN_PLATFORM_WINDOWS
+ CloseHandle(m_EventHandle);
+#else
+ auto* Inner = (EventInner*)m_EventHandle;
+ delete Inner;
+#endif
+ m_EventHandle = nullptr;
+}
+
+bool
+Event::Wait(int TimeoutMs)
+{
+#if ZEN_PLATFORM_WINDOWS
+ using namespace std::literals;
+
+ const DWORD Timeout = (TimeoutMs < 0) ? INFINITE : TimeoutMs;
+
+ DWORD Result = WaitForSingleObject(m_EventHandle, Timeout);
+
+ if (Result == WAIT_FAILED)
+ {
+ zen::ThrowLastError("Event wait failed"sv);
+ }
+
+ return (Result == WAIT_OBJECT_0);
+#else
+ auto* Inner = (EventInner*)m_EventHandle;
+
+ if (TimeoutMs >= 0)
+ {
+ std::unique_lock Lock(Inner->Mutex);
+
+ if (Inner->bSet)
+ {
+ return true;
+ }
+
+ return Inner->CondVar.wait_for(Lock, std::chrono::milliseconds(TimeoutMs), [&] { return Inner->bSet; });
+ }
+
+ std::unique_lock Lock(Inner->Mutex);
+
+ if (!Inner->bSet)
+ {
+ Inner->CondVar.wait(Lock, [&] { return Inner->bSet; });
+ }
+
+ return true;
+#endif
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+NamedEvent::NamedEvent(std::string_view EventName)
+{
+#if ZEN_PLATFORM_WINDOWS
+ using namespace std::literals;
+
+ ExtendableStringBuilder<64> Name;
+ Name << "Local\\"sv;
+ Name << EventName;
+
+ m_EventHandle = CreateEventA(nullptr, true, false, Name.c_str());
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ // Create a file to back the semaphore
+ ExtendableStringBuilder<64> EventPath;
+ EventPath << "/tmp/" << EventName;
+
+ int Fd = open(EventPath.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666);
+ if (Fd < 0)
+ {
+ ThrowLastError(fmt::format("Failed to create '{}' for named event", EventPath));
+ }
+ fchmod(Fd, 0666);
+
+ // Use the file path to generate an IPC key
+ key_t IpcKey = ftok(EventPath.c_str(), 1);
+ if (IpcKey < 0)
+ {
+ close(Fd);
+ ThrowLastError("Failed to create an SysV IPC key");
+ }
+
+ // Use the key to create/open the semaphore
+ int Sem = semget(IpcKey, 1, 0600 | IPC_CREAT);
+ if (Sem < 0)
+ {
+ close(Fd);
+ ThrowLastError("Failed creating an SysV semaphore");
+ }
+
+ // Atomically claim ownership of the semaphore's key. The owner initialises
+ // the semaphore to 1 so we can use the wait-for-zero op as that does not
+ // modify the semaphore's value on a successful wait.
+ int LockResult = flock(Fd, LOCK_EX | LOCK_NB);
+ if (LockResult == 0)
+ {
+ // This isn't thread safe really. Another thread could open the same
+ // semaphore and successfully wait on it in the period of time where
+ // this comment is but before the semaphore's initialised.
+ semctl(Sem, 0, SETVAL, 1);
+ }
+
+ // Pack into the handle
+ static_assert(sizeof(Sem) + sizeof(Fd) <= sizeof(void*), "Semaphore packing assumptions not met");
+ intptr_t Packed;
+ Packed = intptr_t(Sem) << 32;
+ Packed |= intptr_t(Fd) & 0xffff'ffff;
+ m_EventHandle = (void*)Packed;
+#endif
+}
+
+NamedEvent::~NamedEvent()
+{
+ Close();
+}
+
+void
+NamedEvent::Close()
+{
+ if (m_EventHandle == nullptr)
+ {
+ return;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ CloseHandle(m_EventHandle);
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Fd = int(intptr_t(m_EventHandle) & 0xffff'ffff);
+
+ if (flock(Fd, LOCK_EX | LOCK_NB) == 0)
+ {
+ std::filesystem::path Name = PathFromHandle((void*)(intptr_t(Fd)));
+ unlink(Name.c_str());
+
+ flock(Fd, LOCK_UN | LOCK_NB);
+ close(Fd);
+
+ int Sem = int(intptr_t(m_EventHandle) >> 32);
+ semctl(Sem, 0, IPC_RMID);
+ }
+#endif
+
+ m_EventHandle = nullptr;
+}
+
+void
+NamedEvent::Set()
+{
+#if ZEN_PLATFORM_WINDOWS
+ SetEvent(m_EventHandle);
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Sem = int(intptr_t(m_EventHandle) >> 32);
+ semctl(Sem, 0, SETVAL, 0);
+#endif
+}
+
+bool
+NamedEvent::Wait(int TimeoutMs)
+{
+#if ZEN_PLATFORM_WINDOWS
+ const DWORD Timeout = (TimeoutMs < 0) ? INFINITE : TimeoutMs;
+
+ DWORD Result = WaitForSingleObject(m_EventHandle, Timeout);
+
+ if (Result == WAIT_FAILED)
+ {
+ using namespace std::literals;
+ zen::ThrowLastError("Event wait failed"sv);
+ }
+
+ return (Result == WAIT_OBJECT_0);
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Sem = int(intptr_t(m_EventHandle) >> 32);
+
+ int Result;
+ struct sembuf SemOp = {};
+
+ if (TimeoutMs < 0)
+ {
+ Result = semop(Sem, &SemOp, 1);
+ return Result == 0;
+ }
+
+# if defined(_GNU_SOURCE)
+ struct timespec TimeoutValue = {
+ .tv_sec = TimeoutMs >> 10,
+ .tv_nsec = (TimeoutMs & 0x3ff) << 20,
+ };
+ Result = semtimedop(Sem, &SemOp, 1, &TimeoutValue);
+# else
+ const int SleepTimeMs = 10;
+ SemOp.sem_flg = IPC_NOWAIT;
+ do
+ {
+ Result = semop(Sem, &SemOp, 1);
+ if (Result == 0 || errno != EAGAIN)
+ {
+ break;
+ }
+
+ Sleep(SleepTimeMs);
+ TimeoutMs -= SleepTimeMs;
+ } while (TimeoutMs > 0);
+# endif // _GNU_SOURCE
+
+ return Result == 0;
+#endif
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+NamedMutex::~NamedMutex()
+{
+#if ZEN_PLATFORM_WINDOWS
+ if (m_MutexHandle)
+ {
+ CloseHandle(m_MutexHandle);
+ }
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Inner = int(intptr_t(m_MutexHandle));
+ flock(Inner, LOCK_UN);
+ close(Inner);
+#endif
+}
+
+bool
+NamedMutex::Create(std::string_view MutexName)
+{
+#if ZEN_PLATFORM_WINDOWS
+ ZEN_ASSERT(m_MutexHandle == nullptr);
+
+ using namespace std::literals;
+
+ ExtendableStringBuilder<64> Name;
+ Name << "Global\\"sv;
+ Name << MutexName;
+
+ m_MutexHandle = CreateMutexA(nullptr, /* InitialOwner */ TRUE, Name.c_str());
+
+ return !!m_MutexHandle;
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ ExtendableStringBuilder<64> Name;
+ Name << "/tmp/" << MutexName;
+
+ int Inner = open(Name.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666);
+ if (Inner < 0)
+ {
+ return false;
+ }
+ fchmod(Inner, 0666);
+
+ if (flock(Inner, LOCK_EX) != 0)
+ {
+ close(Inner);
+ Inner = 0;
+ return false;
+ }
+
+ m_MutexHandle = (void*)(intptr_t(Inner));
+ return true;
+#endif // ZEN_PLATFORM_WINDOWS
+}
+
+bool
+NamedMutex::Exists(std::string_view MutexName)
+{
+#if ZEN_PLATFORM_WINDOWS
+ using namespace std::literals;
+
+ ExtendableStringBuilder<64> Name;
+ Name << "Global\\"sv;
+ Name << MutexName;
+
+ void* MutexHandle = OpenMutexA(SYNCHRONIZE, /* InheritHandle */ FALSE, Name.c_str());
+
+ if (MutexHandle == nullptr)
+ {
+ return false;
+ }
+
+ CloseHandle(MutexHandle);
+
+ return true;
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ ExtendableStringBuilder<64> Name;
+ Name << "/tmp/" << MutexName;
+
+ bool bExists = false;
+ int Fd = open(Name.c_str(), O_RDWR | O_CLOEXEC);
+ if (Fd >= 0)
+ {
+ if (flock(Fd, LOCK_EX | LOCK_NB) == 0)
+ {
+ flock(Fd, LOCK_UN | LOCK_NB);
+ }
+ else
+ {
+ bExists = true;
+ }
+ close(Fd);
+ }
+
+ return bExists;
+#endif // ZEN_PLATFORM_WINDOWS
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ProcessHandle::ProcessHandle() = default;
+
+#if ZEN_PLATFORM_WINDOWS
+void
+ProcessHandle::Initialize(void* ProcessHandle)
+{
+ ZEN_ASSERT(m_ProcessHandle == nullptr);
+
+ if (ProcessHandle == INVALID_HANDLE_VALUE)
+ {
+ ProcessHandle = nullptr;
+ }
+
+ // TODO: perform some debug verification here to verify it's a valid handle?
+ m_ProcessHandle = ProcessHandle;
+ m_Pid = GetProcessId(m_ProcessHandle);
+}
+#endif // ZEN_PLATFORM_WINDOWS
+
+ProcessHandle::~ProcessHandle()
+{
+ Reset();
+}
+
+void
+ProcessHandle::Initialize(int Pid)
+{
+ ZEN_ASSERT(m_ProcessHandle == nullptr);
+
+#if ZEN_PLATFORM_WINDOWS
+ m_ProcessHandle = OpenProcess(PROCESS_QUERY_INFORMATION | SYNCHRONIZE, FALSE, Pid);
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ if (Pid > 0)
+ {
+ m_ProcessHandle = (void*)(intptr_t(Pid));
+ }
+#endif
+
+ if (!m_ProcessHandle)
+ {
+ ThrowLastError(fmt::format("ProcessHandle::Initialize(pid: {}) failed", Pid));
+ }
+
+ m_Pid = Pid;
+}
+
+bool
+ProcessHandle::IsRunning() const
+{
+ bool bActive = false;
+
+#if ZEN_PLATFORM_WINDOWS
+ DWORD ExitCode = 0;
+ GetExitCodeProcess(m_ProcessHandle, &ExitCode);
+ bActive = (ExitCode == STILL_ACTIVE);
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ bActive = (kill(pid_t(m_Pid), 0) == 0);
+#endif
+
+ return bActive;
+}
+
+bool
+ProcessHandle::IsValid() const
+{
+ return (m_ProcessHandle != nullptr);
+}
+
+void
+ProcessHandle::Terminate(int ExitCode)
+{
+ if (!IsRunning())
+ {
+ return;
+ }
+
+ bool bSuccess = false;
+
+#if ZEN_PLATFORM_WINDOWS
+ TerminateProcess(m_ProcessHandle, ExitCode);
+ DWORD WaitResult = WaitForSingleObject(m_ProcessHandle, INFINITE);
+ bSuccess = (WaitResult != WAIT_OBJECT_0);
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ ZEN_UNUSED(ExitCode);
+ bSuccess = (kill(m_Pid, SIGKILL) == 0);
+#endif
+
+ if (!bSuccess)
+ {
+ // What might go wrong here, and what is meaningful to act on?
+ }
+}
+
+void
+ProcessHandle::Reset()
+{
+ if (IsValid())
+ {
+#if ZEN_PLATFORM_WINDOWS
+ CloseHandle(m_ProcessHandle);
+#endif
+ m_ProcessHandle = nullptr;
+ m_Pid = 0;
+ }
+}
+
+bool
+ProcessHandle::Wait(int TimeoutMs)
+{
+ using namespace std::literals;
+
+#if ZEN_PLATFORM_WINDOWS
+ const DWORD Timeout = (TimeoutMs < 0) ? INFINITE : TimeoutMs;
+
+ const DWORD WaitResult = WaitForSingleObject(m_ProcessHandle, Timeout);
+
+ switch (WaitResult)
+ {
+ case WAIT_OBJECT_0:
+ return true;
+
+ case WAIT_TIMEOUT:
+ return false;
+
+ case WAIT_FAILED:
+ break;
+ }
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ const int SleepMs = 20;
+ timespec SleepTime = {0, SleepMs * 1000 * 1000};
+ for (int i = 0;; i += SleepMs)
+ {
+# if ZEN_PLATFORM_MAC
+ int WaitState = 0;
+ waitpid(m_Pid, &WaitState, WNOHANG | WCONTINUED | WUNTRACED);
+# endif
+
+ if (kill(m_Pid, 0) < 0)
+ {
+ if (zen::GetLastError() == ESRCH)
+ {
+ return true;
+ }
+ break;
+ }
+
+ if (TimeoutMs >= 0 && i >= TimeoutMs)
+ {
+ return false;
+ }
+
+ nanosleep(&SleepTime, nullptr);
+ }
+#endif
+
+ // What might go wrong here, and what is meaningful to act on?
+ ThrowLastError("Process::Wait failed"sv);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if !ZEN_PLATFORM_WINDOWS || ZEN_WITH_TESTS
+static void
+BuildArgV(std::vector<char*>& Out, char* CommandLine)
+{
+ char* Cursor = CommandLine;
+ while (true)
+ {
+ // Skip leading whitespace
+ for (; *Cursor == ' '; ++Cursor)
+ ;
+
+ // Check for nullp terminator
+ if (*Cursor == '\0')
+ {
+ break;
+ }
+
+ Out.push_back(Cursor);
+
+ // Extract word
+ int QuoteCount = 0;
+ do
+ {
+ QuoteCount += (*Cursor == '\"');
+ if (*Cursor == ' ' && !(QuoteCount & 1))
+ {
+ break;
+ }
+ ++Cursor;
+ } while (*Cursor != '\0');
+
+ if (*Cursor == '\0')
+ {
+ break;
+ }
+
+ *Cursor = '\0';
+ ++Cursor;
+ }
+}
+#endif // !WINDOWS || TESTS
+
+#if ZEN_PLATFORM_WINDOWS
+static CreateProcResult
+CreateProcNormal(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options)
+{
+ PROCESS_INFORMATION ProcessInfo{};
+ STARTUPINFO StartupInfo{.cb = sizeof(STARTUPINFO)};
+
+ const bool InheritHandles = false;
+ void* Environment = nullptr;
+ LPSECURITY_ATTRIBUTES ProcessAttributes = nullptr;
+ LPSECURITY_ATTRIBUTES ThreadAttributes = nullptr;
+
+ DWORD CreationFlags = 0;
+ if (Options.Flags & CreateProcOptions::Flag_NewConsole)
+ {
+ CreationFlags |= CREATE_NEW_CONSOLE;
+ }
+
+ const wchar_t* WorkingDir = nullptr;
+ if (Options.WorkingDirectory != nullptr)
+ {
+ WorkingDir = Options.WorkingDirectory->c_str();
+ }
+
+ ExtendableWideStringBuilder<256> CommandLineZ;
+ CommandLineZ << CommandLine;
+
+ BOOL Success = CreateProcessW(Executable.c_str(),
+ CommandLineZ.Data(),
+ ProcessAttributes,
+ ThreadAttributes,
+ InheritHandles,
+ CreationFlags,
+ Environment,
+ WorkingDir,
+ &StartupInfo,
+ &ProcessInfo);
+
+ if (!Success)
+ {
+ return nullptr;
+ }
+
+ CloseHandle(ProcessInfo.hThread);
+ return ProcessInfo.hProcess;
+}
+
+static CreateProcResult
+CreateProcUnelevated(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options)
+{
+ /* Launches a binary with the shell as its parent. The shell (such as
+ Explorer) should be an unelevated process. */
+
+ // No sense in using this route if we are not elevated in the first place
+ if (IsUserAnAdmin() == FALSE)
+ {
+ return CreateProcNormal(Executable, CommandLine, Options);
+ }
+
+ // Get the users' shell process and open it for process creation
+ HWND ShellWnd = GetShellWindow();
+ if (ShellWnd == nullptr)
+ {
+ return nullptr;
+ }
+
+ DWORD ShellPid;
+ GetWindowThreadProcessId(ShellWnd, &ShellPid);
+
+ HANDLE Process = OpenProcess(PROCESS_CREATE_PROCESS, FALSE, ShellPid);
+ if (Process == nullptr)
+ {
+ return nullptr;
+ }
+ auto $0 = MakeGuard([&] { CloseHandle(Process); });
+
+ // Creating a process as a child of another process is done by setting a
+ // thread-attribute list on the startup info passed to CreateProcess()
+ SIZE_T AttrListSize;
+ InitializeProcThreadAttributeList(nullptr, 1, 0, &AttrListSize);
+
+ auto AttrList = (PPROC_THREAD_ATTRIBUTE_LIST)malloc(AttrListSize);
+ auto $1 = MakeGuard([&] { free(AttrList); });
+
+ if (!InitializeProcThreadAttributeList(AttrList, 1, 0, &AttrListSize))
+ {
+ return nullptr;
+ }
+
+ BOOL bOk =
+ UpdateProcThreadAttribute(AttrList, 0, PROC_THREAD_ATTRIBUTE_PARENT_PROCESS, (HANDLE*)&Process, sizeof(Process), nullptr, nullptr);
+ if (!bOk)
+ {
+ return nullptr;
+ }
+
+ // By this point we know we are an elevated process. It is not allowed to
+ // create a process as a child of another unelevated process that share our
+ // elevated console window if we have one. So we'll need to create a new one.
+ uint32_t CreateProcFlags = EXTENDED_STARTUPINFO_PRESENT;
+ if (GetConsoleWindow() != nullptr)
+ {
+ CreateProcFlags |= CREATE_NEW_CONSOLE;
+ }
+ else
+ {
+ CreateProcFlags |= DETACHED_PROCESS;
+ }
+
+ // Everything is set up now so we can proceed and launch the process
+ STARTUPINFOEXW StartupInfo = {
+ .StartupInfo = {.cb = sizeof(STARTUPINFOEXW)},
+ .lpAttributeList = AttrList,
+ };
+ PROCESS_INFORMATION ProcessInfo = {};
+
+ if (Options.Flags & CreateProcOptions::Flag_NewConsole)
+ {
+ CreateProcFlags |= CREATE_NEW_CONSOLE;
+ }
+
+ ExtendableWideStringBuilder<256> CommandLineZ;
+ CommandLineZ << CommandLine;
+
+ bOk = CreateProcessW(Executable.c_str(),
+ CommandLineZ.Data(),
+ nullptr,
+ nullptr,
+ FALSE,
+ CreateProcFlags,
+ nullptr,
+ nullptr,
+ &StartupInfo.StartupInfo,
+ &ProcessInfo);
+ if (bOk == FALSE)
+ {
+ return nullptr;
+ }
+
+ CloseHandle(ProcessInfo.hThread);
+ return ProcessInfo.hProcess;
+}
+
+static CreateProcResult
+CreateProcElevated(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options)
+{
+ ExtendableWideStringBuilder<256> CommandLineZ;
+ CommandLineZ << CommandLine;
+
+ SHELLEXECUTEINFO ShellExecuteInfo;
+ ZeroMemory(&ShellExecuteInfo, sizeof(ShellExecuteInfo));
+ ShellExecuteInfo.cbSize = sizeof(ShellExecuteInfo);
+ ShellExecuteInfo.fMask = SEE_MASK_UNICODE | SEE_MASK_NOCLOSEPROCESS;
+ ShellExecuteInfo.lpFile = Executable.c_str();
+ ShellExecuteInfo.lpVerb = TEXT("runas");
+ ShellExecuteInfo.nShow = SW_SHOW;
+ ShellExecuteInfo.lpParameters = CommandLineZ.c_str();
+
+ if (Options.WorkingDirectory != nullptr)
+ {
+ ShellExecuteInfo.lpDirectory = Options.WorkingDirectory->c_str();
+ }
+
+ if (::ShellExecuteEx(&ShellExecuteInfo))
+ {
+ return ShellExecuteInfo.hProcess;
+ }
+
+ return nullptr;
+}
+#endif // ZEN_PLATFORM_WINDOWS
+
+CreateProcResult
+CreateProc(const std::filesystem::path& Executable, std::string_view CommandLine, const CreateProcOptions& Options)
+{
+#if ZEN_PLATFORM_WINDOWS
+ if (Options.Flags & CreateProcOptions::Flag_Unelevated)
+ {
+ return CreateProcUnelevated(Executable, CommandLine, Options);
+ }
+
+ if (Options.Flags & CreateProcOptions::Flag_Elevated)
+ {
+ return CreateProcElevated(Executable, CommandLine, Options);
+ }
+
+ return CreateProcNormal(Executable, CommandLine, Options);
+#else
+ std::vector<char*> ArgV;
+ std::string CommandLineZ(CommandLine);
+ BuildArgV(ArgV, CommandLineZ.data());
+ ArgV.push_back(nullptr);
+
+ int ChildPid = fork();
+ if (ChildPid < 0)
+ {
+ ThrowLastError("Failed to fork a new child process");
+ }
+ else if (ChildPid == 0)
+ {
+ if (Options.WorkingDirectory != nullptr)
+ {
+ int Result = chdir(Options.WorkingDirectory->c_str());
+ ZEN_UNUSED(Result);
+ }
+
+ if (execv(Executable.c_str(), ArgV.data()) < 0)
+ {
+ ThrowLastError("Failed to exec() a new process image");
+ }
+ }
+
+ return ChildPid;
+#endif
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ProcessMonitor::ProcessMonitor()
+{
+}
+
+ProcessMonitor::~ProcessMonitor()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (HandleType& Proc : m_ProcessHandles)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ CloseHandle(Proc);
+#endif
+ Proc = 0;
+ }
+}
+
+bool
+ProcessMonitor::IsRunning()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ bool FoundOne = false;
+
+ for (HandleType& Proc : m_ProcessHandles)
+ {
+ bool ProcIsActive;
+
+#if ZEN_PLATFORM_WINDOWS
+ DWORD ExitCode = 0;
+ GetExitCodeProcess(Proc, &ExitCode);
+
+ ProcIsActive = (ExitCode == STILL_ACTIVE);
+ if (!ProcIsActive)
+ {
+ CloseHandle(Proc);
+ }
+#else
+ int Pid = int(intptr_t(Proc));
+ ProcIsActive = IsProcessRunning(Pid);
+#endif
+
+ if (!ProcIsActive)
+ {
+ Proc = 0;
+ }
+
+ // Still alive
+ FoundOne |= ProcIsActive;
+ }
+
+ std::erase_if(m_ProcessHandles, [](HandleType Handle) { return Handle == 0; });
+
+ return FoundOne;
+}
+
+void
+ProcessMonitor::AddPid(int Pid)
+{
+ HandleType ProcessHandle;
+
+#if ZEN_PLATFORM_WINDOWS
+ ProcessHandle = OpenProcess(PROCESS_QUERY_INFORMATION | SYNCHRONIZE, FALSE, Pid);
+#else
+ ProcessHandle = HandleType(intptr_t(Pid));
+#endif
+
+ if (ProcessHandle)
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_ProcessHandles.push_back(ProcessHandle);
+ }
+}
+
+bool
+ProcessMonitor::IsActive() const
+{
+ RwLock::SharedLockScope _(m_Lock);
+ return m_ProcessHandles.empty() == false;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+IsProcessRunning(int pid)
+{
+ // This function is arguably not super useful, a pid can be re-used
+ // by the OS so holding on to a pid and polling it over some time
+ // period will not necessarily tell you what you probably want to know.
+
+#if ZEN_PLATFORM_WINDOWS
+ HANDLE hProc = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid);
+
+ if (!hProc)
+ {
+ DWORD Error = zen::GetLastError();
+
+ if (Error == ERROR_INVALID_PARAMETER)
+ {
+ return false;
+ }
+
+ ThrowSystemError(Error, fmt::format("failed to open process with pid {}", pid));
+ }
+
+ CloseHandle(hProc);
+
+ return true;
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ return (kill(pid_t(pid), 0) == 0);
+#endif
+}
+
+int
+GetCurrentProcessId()
+{
+#if ZEN_PLATFORM_WINDOWS
+ return ::GetCurrentProcessId();
+#else
+ return int(getpid());
+#endif
+}
+
+int
+GetCurrentThreadId()
+{
+#if ZEN_PLATFORM_WINDOWS
+ return ::GetCurrentThreadId();
+#elif ZEN_PLATFORM_LINUX
+ return int(syscall(SYS_gettid));
+#elif ZEN_PLATFORM_MAC
+ return int(pthread_mach_thread_np(pthread_self()));
+#endif
+}
+
+void
+Sleep(int ms)
+{
+#if ZEN_PLATFORM_WINDOWS
+ ::Sleep(ms);
+#else
+ usleep(ms * 1000U);
+#endif
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+void
+thread_forcelink()
+{
+}
+
+TEST_CASE("Thread")
+{
+ int Pid = GetCurrentProcessId();
+ CHECK(Pid > 0);
+ CHECK(IsProcessRunning(Pid));
+
+ CHECK_FALSE(GetCurrentThreadId() == 0);
+}
+
+TEST_CASE("BuildArgV")
+{
+ const char* Words[] = {"one", "two", "three", "four", "five"};
+ struct
+ {
+ int WordCount;
+ const char* Input;
+ } Cases[] = {
+ {0, ""},
+ {0, " "},
+ {1, "one"},
+ {1, " one"},
+ {1, "one "},
+ {2, "one two"},
+ {2, " one two"},
+ {2, "one two "},
+ {2, " one two"},
+ {2, "one two "},
+ {2, "one two "},
+ {3, "one two three"},
+ {3, "\"one\" two \"three\""},
+ {5, "one two three four five"},
+ };
+
+ for (const auto& Case : Cases)
+ {
+ std::vector<char*> OutArgs;
+ StringBuilder<64> Mutable;
+ Mutable << Case.Input;
+ BuildArgV(OutArgs, Mutable.Data());
+
+ CHECK_EQ(OutArgs.size(), Case.WordCount);
+
+ for (int i = 0, n = int(OutArgs.size()); i < n; ++i)
+ {
+ const char* Truth = Words[i];
+ size_t TruthLen = strlen(Truth);
+
+ const char* Candidate = OutArgs[i];
+ bool bQuoted = (Candidate[0] == '\"');
+ Candidate += bQuoted;
+
+ CHECK(strncmp(Truth, Candidate, TruthLen) == 0);
+
+ if (bQuoted)
+ {
+ CHECK_EQ(Candidate[TruthLen], '\"');
+ }
+ }
+ }
+}
+
+TEST_CASE("NamedEvent")
+{
+ std::string Name = "zencore_test_event";
+ NamedEvent TestEvent(Name);
+
+ // Timeout test
+ for (uint32_t i = 0; i < 8; ++i)
+ {
+ bool bEventSet = TestEvent.Wait(100);
+ CHECK(!bEventSet);
+ }
+
+ // Thread check
+ std::thread Waiter = std::thread([Name]() {
+ NamedEvent ReadyEvent(Name + "_ready");
+ ReadyEvent.Set();
+
+ NamedEvent TestEvent(Name);
+ TestEvent.Wait(1000);
+ });
+
+ NamedEvent ReadyEvent(Name + "_ready");
+ ReadyEvent.Wait();
+
+ zen::Sleep(500);
+ TestEvent.Set();
+
+ Waiter.join();
+
+ // Manual reset property
+ for (uint32_t i = 0; i < 8; ++i)
+ {
+ bool bEventSet = TestEvent.Wait(100);
+ CHECK(bEventSet);
+ }
+}
+
+TEST_CASE("NamedMutex")
+{
+ static const char* Name = "zen_test_mutex";
+
+ CHECK(!NamedMutex::Exists(Name));
+
+ {
+ NamedMutex TestMutex;
+ CHECK(TestMutex.Create(Name));
+ CHECK(NamedMutex::Exists(Name));
+ }
+
+ CHECK(!NamedMutex::Exists(Name));
+}
+
+#endif // ZEN_WITH_TESTS
+
+} // namespace zen
diff --git a/src/zencore/timer.cpp b/src/zencore/timer.cpp
new file mode 100644
index 000000000..1655e912d
--- /dev/null
+++ b/src/zencore/timer.cpp
@@ -0,0 +1,105 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+
+#include <zencore/testing.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#elif ZEN_PLATFORM_LINUX
+# include <time.h>
+# include <unistd.h>
+#endif
+
+namespace zen {
+
+uint64_t
+GetHifreqTimerValue()
+{
+ uint64_t Timestamp;
+
+#if ZEN_PLATFORM_WINDOWS
+ LARGE_INTEGER li;
+ QueryPerformanceCounter(&li);
+
+ Timestamp = li.QuadPart;
+#else
+ struct timespec ts;
+ clock_gettime(CLOCK_MONOTONIC, &ts);
+ Timestamp = (uint64_t(ts.tv_sec) * 1000000ull) + (uint64_t(ts.tv_nsec) / 1000ull);
+#endif
+
+ return Timestamp;
+}
+
+uint64_t
+InternalGetHifreqTimerFrequency()
+{
+#if ZEN_PLATFORM_WINDOWS
+ LARGE_INTEGER li;
+ QueryPerformanceFrequency(&li);
+
+ return li.QuadPart;
+#else
+ return 1000000ull;
+#endif
+}
+
+uint64_t QpcFreq = InternalGetHifreqTimerFrequency();
+static const double QpcFactor = 1.0 / InternalGetHifreqTimerFrequency();
+
+uint64_t
+GetHifreqTimerFrequency()
+{
+ return QpcFreq;
+}
+
+double
+GetHifreqTimerToSeconds()
+{
+ return QpcFactor;
+}
+
+uint64_t
+GetHifreqTimerFrequencySafe()
+{
+ if (!QpcFreq)
+ {
+ QpcFreq = InternalGetHifreqTimerFrequency();
+ }
+
+ return QpcFreq;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+uint64_t detail::g_LofreqTimerValue = GetHifreqTimerValue();
+
+void
+UpdateLofreqTimerValue()
+{
+ detail::g_LofreqTimerValue = GetHifreqTimerValue();
+}
+
+uint64_t
+GetLofreqTimerFrequency()
+{
+ return GetHifreqTimerFrequencySafe();
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+void
+timer_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/trace.cpp b/src/zencore/trace.cpp
new file mode 100644
index 000000000..788dcec07
--- /dev/null
+++ b/src/zencore/trace.cpp
@@ -0,0 +1,45 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+/* clang-format off */
+
+#if ZEN_WITH_TRACE
+
+#include <zencore/zencore.h>
+
+#define TRACE_IMPLEMENT 1
+#include <zencore/trace.h>
+
+void
+TraceInit(const char* HostOrPath, TraceType Type)
+{
+ bool EnableEvents = true;
+
+ switch (Type)
+ {
+ case TraceType::Network:
+ trace::SendTo(HostOrPath);
+ break;
+
+ case TraceType::File:
+ trace::WriteTo(HostOrPath);
+ break;
+
+ case TraceType::None:
+ EnableEvents = false;
+ break;
+ }
+
+ trace::FInitializeDesc Desc = {
+ .bUseImportantCache = false,
+ };
+ trace::Initialize(Desc);
+
+ if (EnableEvents)
+ {
+ trace::ToggleChannel("cpu", true);
+ }
+}
+
+#endif // ZEN_WITH_TRACE
+
+/* clang-format on */
diff --git a/src/zencore/uid.cpp b/src/zencore/uid.cpp
new file mode 100644
index 000000000..86cdfae3a
--- /dev/null
+++ b/src/zencore/uid.cpp
@@ -0,0 +1,148 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/uid.h>
+
+#include <zencore/endian.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#include <atomic>
+#include <bit>
+#include <chrono>
+#include <random>
+#include <set>
+#include <unordered_map>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+ static bool OidInitialised;
+ static uint32_t RunId;
+ static std::atomic_uint32_t Serial;
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+const Oid Oid::Zero = {{0u, 0u, 0u}};
+const Oid Oid::Max = {{~0u, ~0u, ~0u}};
+
+void
+Oid::Initialize()
+{
+ if (!detail::OidInitialised)
+ {
+ std::random_device Rng;
+ detail::RunId = Rng();
+ detail::Serial = Rng();
+
+ detail::OidInitialised = true;
+ }
+}
+
+const Oid&
+Oid::Generate()
+{
+ if (!detail::OidInitialised)
+ {
+ Oid::Initialize();
+ }
+
+ const uint64_t kOffset = 1'609'459'200; // Seconds from 1970 -> 2021
+ const uint64_t Time = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()) - kOffset;
+
+ OidBits[0] = ToNetworkOrder(uint32_t(Time));
+ OidBits[1] = ToNetworkOrder(uint32_t(++detail::Serial));
+ OidBits[2] = detail::RunId;
+
+ return *this;
+}
+
+Oid
+Oid::NewOid()
+{
+ return Oid().Generate();
+}
+
+Oid
+Oid::FromHexString(const std::string_view String)
+{
+ ZEN_ASSERT(String.size() == 2 * sizeof(Oid::OidBits));
+
+ Oid Id;
+
+ if (ParseHexBytes(String.data(), String.size(), reinterpret_cast<uint8_t*>(Id.OidBits)))
+ {
+ return Id;
+ }
+ else
+ {
+ return Oid::Zero;
+ }
+}
+
+Oid
+Oid::FromMemory(const void* Ptr)
+{
+ Oid Id;
+ memcpy(Id.OidBits, Ptr, sizeof Id);
+ return Id;
+}
+
+void
+Oid::ToString(char OutString[StringLength])
+{
+ ToHexBytes(reinterpret_cast<const uint8_t*>(OidBits), sizeof(Oid::OidBits), OutString);
+}
+
+StringBuilderBase&
+Oid::ToString(StringBuilderBase& OutString) const
+{
+ String_t Str;
+ ToHexBytes(reinterpret_cast<const uint8_t*>(OidBits), sizeof(Oid::OidBits), Str);
+
+ OutString.AppendRange(Str, &Str[StringLength]);
+
+ return OutString;
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("Oid")
+{
+ SUBCASE("Basic")
+ {
+ Oid id1 = Oid::NewOid();
+ ZEN_UNUSED(id1);
+
+ std::vector<Oid> ids;
+ std::set<Oid> idset;
+ std::unordered_map<Oid, int, Oid::Hasher> idmap;
+
+ const int Count = 1000;
+
+ for (int i = 0; i < Count; ++i)
+ {
+ Oid id;
+ id.Generate();
+
+ ids.emplace_back(id);
+ idset.insert(id);
+ idmap.insert({id, i});
+ }
+
+ CHECK(ids.size() == Count);
+ CHECK(idset.size() == Count); // All ids should be unique
+ CHECK(idmap.size() == Count); // Ditto
+ }
+}
+
+void
+uid_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zencore/workthreadpool.cpp b/src/zencore/workthreadpool.cpp
new file mode 100644
index 000000000..b4328cdbd
--- /dev/null
+++ b/src/zencore/workthreadpool.cpp
@@ -0,0 +1,83 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/workthreadpool.h>
+
+#include <zencore/logging.h>
+
+namespace zen {
+
+namespace detail {
+ struct LambdaWork : IWork
+ {
+ LambdaWork(auto Work) : WorkFunction(Work) {}
+ virtual void Execute() override { WorkFunction(); }
+
+ std::function<void()> WorkFunction;
+ };
+} // namespace detail
+
+WorkerThreadPool::WorkerThreadPool(int InThreadCount)
+{
+ for (int i = 0; i < InThreadCount; ++i)
+ {
+ m_WorkerThreads.emplace_back(&WorkerThreadPool::WorkerThreadFunction, this);
+ }
+}
+
+WorkerThreadPool::~WorkerThreadPool()
+{
+ m_WorkQueue.CompleteAdding();
+
+ for (std::thread& Thread : m_WorkerThreads)
+ {
+ Thread.join();
+ }
+
+ m_WorkerThreads.clear();
+}
+
+void
+WorkerThreadPool::ScheduleWork(Ref<IWork> Work)
+{
+ m_WorkQueue.Enqueue(std::move(Work));
+}
+
+void
+WorkerThreadPool::ScheduleWork(std::function<void()>&& Work)
+{
+ m_WorkQueue.Enqueue(Ref<IWork>(new detail::LambdaWork(Work)));
+}
+
+[[nodiscard]] size_t
+WorkerThreadPool::PendingWork() const
+{
+ return m_WorkQueue.Size();
+}
+
+void
+WorkerThreadPool::WorkerThreadFunction()
+{
+ do
+ {
+ Ref<IWork> Work;
+ if (m_WorkQueue.WaitAndDequeue(Work))
+ {
+ try
+ {
+ Work->Execute();
+ }
+ catch (std::exception& e)
+ {
+ Work->m_Exception = std::current_exception();
+
+ ZEN_WARN("Caught exception in worker thread: {}", e.what());
+ }
+ }
+ else
+ {
+ return;
+ }
+ } while (true);
+}
+
+} // namespace zen
diff --git a/src/zencore/xmake.lua b/src/zencore/xmake.lua
new file mode 100644
index 000000000..e1e649c1d
--- /dev/null
+++ b/src/zencore/xmake.lua
@@ -0,0 +1,61 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zencore')
+ set_kind("static")
+ add_headerfiles("**.h")
+ add_configfiles("include/zencore/config.h.in")
+ on_load(function (target)
+ local version = io.readfile("VERSION.txt")
+ version = string.gsub(version,"%-pre.*", "")
+ target:set("version", version:trim(), {build = "%Y%m%d%H%M"})
+ end)
+ set_configdir("include/zencore")
+ add_files("**.cpp")
+ add_includedirs("include", {public=true})
+ add_includedirs("$(projectdir)/thirdparty/utfcpp/source")
+ add_includedirs("$(projectdir)/thirdparty/trace", {public=true})
+ if is_os("windows") then
+ add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Win64")
+ elseif is_os("linux") then
+ add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Linux_x64")
+ add_links("oo2corelinux64")
+ add_syslinks("pthread")
+ elseif is_os("macosx") then
+ add_linkdirs("$(projectdir)/thirdparty/Oodle/lib/Mac_x64")
+ add_links("oo2coremac64")
+ end
+ add_options("zentrace")
+ add_packages(
+ "vcpkg::blake3",
+ "vcpkg::cpr",
+ "vcpkg::curl", -- required by cpr
+ "vcpkg::doctest",
+ "vcpkg::fmt",
+ "vcpkg::gsl-lite",
+ "vcpkg::json11",
+ "vcpkg::lz4",
+ "vcpkg::mimalloc",
+ "vcpkg::openssl", -- required by curl
+ "vcpkg::spdlog",
+ "vcpkg::zlib", -- required by curl
+ "vcpkg::xxhash")
+
+ if is_plat("linux") then
+ -- The 'vcpkg::openssl' package is two libraries; ssl and crypto, with
+ -- ssl being dependent on symbols in crypto. When GCC-like linkers read
+ -- object files from their command line, those object files only resolve
+ -- symbols of objects previously encountered. Thus crypto must appear
+ -- after ssl so it can fill out ssl's unresolved symbol table. Xmake's
+ -- vcpkg support is basic and works by parsing .list files. Openssl's
+ -- archives are listed alphabetically causing crypto to be _before_ ssl
+ -- and resulting in link errors. The links are restated here to force
+ -- xmake to use the correct order, and "syslinks" is used to force the
+ -- arguments to the end of the line (otherwise they can appear before
+ -- curl and cause more errors).
+ add_syslinks("crypto")
+ add_syslinks("dl")
+ end
+
+ if is_plat("linux") then
+ add_syslinks("rt")
+ end
diff --git a/src/zencore/xxhash.cpp b/src/zencore/xxhash.cpp
new file mode 100644
index 000000000..450131d19
--- /dev/null
+++ b/src/zencore/xxhash.cpp
@@ -0,0 +1,50 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/xxhash.h>
+
+#include <zencore/string.h>
+#include <zencore/testing.h>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+XXH3_128 XXH3_128::Zero; // Initialized to all zeros
+
+XXH3_128
+XXH3_128::FromHexString(const char* InString)
+{
+ return FromHexString({InString, sizeof(XXH3_128::Hash) * 2});
+}
+
+XXH3_128
+XXH3_128::FromHexString(std::string_view InString)
+{
+ ZEN_ASSERT(InString.size() == 2 * sizeof(XXH3_128::Hash));
+
+ XXH3_128 Xx;
+ ParseHexBytes(InString.data(), InString.size(), Xx.Hash);
+ return Xx;
+}
+
+const char*
+XXH3_128::ToHexString(char* OutString /* 40 characters + NUL terminator */) const
+{
+ ToHexBytes(Hash, sizeof(XXH3_128), OutString);
+ OutString[2 * sizeof(XXH3_128)] = '\0';
+
+ return OutString;
+}
+
+StringBuilderBase&
+XXH3_128::ToHexString(StringBuilderBase& OutBuilder) const
+{
+ String_t str;
+ ToHexString(str);
+
+ OutBuilder.AppendRange(str, &str[StringLength]);
+
+ return OutBuilder;
+}
+
+} // namespace zen
diff --git a/src/zencore/zencore.cpp b/src/zencore/zencore.cpp
new file mode 100644
index 000000000..2a7c5755e
--- /dev/null
+++ b/src/zencore/zencore.cpp
@@ -0,0 +1,175 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+#if ZEN_PLATFORM_LINUX
+# include <pthread.h>
+#endif
+
+#include <zencore/blake3.h>
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/compress.h>
+#include <zencore/crypto.h>
+#include <zencore/filesystem.h>
+#include <zencore/intmath.h>
+#include <zencore/iobuffer.h>
+#include <zencore/memory.h>
+#include <zencore/mpscqueue.h>
+#include <zencore/refcount.h>
+#include <zencore/sha1.h>
+#include <zencore/stats.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+#include <zencore/uid.h>
+
+namespace zen {
+
+AssertImpl AssertImpl::DefaultAssertImpl;
+AssertImpl* AssertImpl::CurrentAssertImpl = &AssertImpl::DefaultAssertImpl;
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+IsDebuggerPresent()
+{
+#if ZEN_PLATFORM_WINDOWS
+ return ::IsDebuggerPresent();
+#else
+ return false;
+#endif
+}
+
+std::optional<bool> InteractiveSessionFlag;
+
+void
+SetIsInteractiveSession(bool Value)
+{
+ InteractiveSessionFlag = Value;
+}
+
+bool
+IsInteractiveSession()
+{
+ if (!InteractiveSessionFlag.has_value())
+ {
+#if ZEN_PLATFORM_WINDOWS
+ DWORD dwSessionId = 0;
+ if (ProcessIdToSessionId(GetCurrentProcessId(), &dwSessionId))
+ {
+ InteractiveSessionFlag = (dwSessionId != 0);
+ }
+ else
+ {
+ InteractiveSessionFlag = false;
+ }
+#else
+ // TODO: figure out what actually makes sense here
+ InteractiveSessionFlag = true;
+#endif
+ }
+
+ return InteractiveSessionFlag.value();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+static int s_ApplicationExitCode = 0;
+static bool s_ApplicationExitRequested;
+
+bool
+IsApplicationExitRequested()
+{
+ return s_ApplicationExitRequested;
+}
+
+void
+RequestApplicationExit(int ExitCode)
+{
+ s_ApplicationExitCode = ExitCode;
+ s_ApplicationExitRequested = true;
+}
+
+#if ZEN_WITH_TESTS
+void
+zencore_forcelinktests()
+{
+ zen::blake3_forcelink();
+ zen::compositebuffer_forcelink();
+ zen::compress_forcelink();
+ zen::filesystem_forcelink();
+ zen::intmath_forcelink();
+ zen::iobuffer_forcelink();
+ zen::memory_forcelink();
+ zen::mpscqueue_forcelink();
+ zen::refcount_forcelink();
+ zen::sha1_forcelink();
+ zen::stats_forcelink();
+ zen::stream_forcelink();
+ zen::string_forcelink();
+ zen::thread_forcelink();
+ zen::timer_forcelink();
+ zen::uid_forcelink();
+ zen::uson_forcelink();
+ zen::usonbuilder_forcelink();
+ zen::usonpackage_forcelink();
+ zen::crypto_forcelink();
+}
+} // namespace zen
+
+# include <zencore/testing.h>
+
+namespace zen {
+
+TEST_CASE("Assert.Default")
+{
+ bool A = true;
+ bool B = false;
+ CHECK_THROWS_WITH(ZEN_ASSERT(A == B), "A == B");
+}
+
+TEST_CASE("Assert.Custom")
+{
+ struct MyAssertImpl : AssertImpl
+ {
+ ZEN_FORCENOINLINE ZEN_DEBUG_SECTION MyAssertImpl() : PrevAssertImpl(CurrentAssertImpl) { CurrentAssertImpl = this; }
+ virtual ZEN_FORCENOINLINE ZEN_DEBUG_SECTION ~MyAssertImpl() { CurrentAssertImpl = PrevAssertImpl; }
+ virtual void ZEN_FORCENOINLINE ZEN_DEBUG_SECTION OnAssert(const char* Filename,
+ int LineNumber,
+ const char* FunctionName,
+ const char* Msg)
+ {
+ AssertFileName = Filename;
+ Line = LineNumber;
+ FuncName = FunctionName;
+ Message = Msg;
+ }
+ AssertImpl* PrevAssertImpl;
+
+ const char* AssertFileName = nullptr;
+ int Line = -1;
+ const char* FuncName = nullptr;
+ const char* Message = nullptr;
+ };
+
+ MyAssertImpl MyAssert;
+ bool A = true;
+ bool B = false;
+ CHECK_THROWS_WITH(ZEN_ASSERT(A == B), "A == B");
+ CHECK(MyAssert.AssertFileName != nullptr);
+ CHECK(MyAssert.Line != -1);
+ CHECK(MyAssert.FuncName != nullptr);
+ CHECK(strcmp(MyAssert.Message, "A == B") == 0);
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp
new file mode 100644
index 000000000..79b2c0a3d
--- /dev/null
+++ b/src/zenhttp/httpasio.cpp
@@ -0,0 +1,1372 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpasio.h"
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+#include <deque>
+#include <memory>
+#include <string_view>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+# include <mstcpip.h>
+#endif
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#define ASIO_VERBOSE_TRACE 0
+
+#if ASIO_VERBOSE_TRACE
+# define ZEN_TRACE_VERBOSE ZEN_TRACE
+#else
+# define ZEN_TRACE_VERBOSE(fmtstr, ...)
+#endif
+
+namespace zen::asio_http {
+
+using namespace std::literals;
+
+struct HttpAcceptor;
+struct HttpRequest;
+struct HttpResponse;
+struct HttpServerConnection;
+
+static constinit uint32_t HashContentLength = HashStringAsLowerDjb2("Content-Length"sv);
+static constinit uint32_t HashContentType = HashStringAsLowerDjb2("Content-Type"sv);
+static constinit uint32_t HashAccept = HashStringAsLowerDjb2("Accept"sv);
+static constinit uint32_t HashExpect = HashStringAsLowerDjb2("Expect"sv);
+static constinit uint32_t HashSession = HashStringAsLowerDjb2("UE-Session"sv);
+static constinit uint32_t HashRequest = HashStringAsLowerDjb2("UE-Request"sv);
+static constinit uint32_t HashRange = HashStringAsLowerDjb2("Range"sv);
+
+inline spdlog::logger&
+InitLogger()
+{
+ spdlog::logger& Logger = logging::Get("asio");
+ // Logger.set_level(spdlog::level::trace);
+ return Logger;
+}
+
+inline spdlog::logger&
+Log()
+{
+ static spdlog::logger& g_Logger = InitLogger();
+ return g_Logger;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAsioServerImpl
+{
+public:
+ HttpAsioServerImpl();
+ ~HttpAsioServerImpl();
+
+ int Start(uint16_t Port, int ThreadCount);
+ void Stop();
+ void RegisterService(const char* UrlPath, HttpService& Service);
+ HttpService* RouteRequest(std::string_view Url);
+
+ asio::io_service m_IoService;
+ asio::io_service::work m_Work{m_IoService};
+ std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
+ std::vector<std::thread> m_ThreadPool;
+
+ struct ServiceEntry
+ {
+ std::string ServiceUrlPath;
+ HttpService* Service;
+ };
+
+ RwLock m_Lock;
+ std::vector<ServiceEntry> m_UriHandlers;
+};
+
+/**
+ * This is the class which request handlers use to interact with the server instance
+ */
+
+class HttpAsioServerRequest : public HttpServerRequest
+{
+public:
+ HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer);
+ ~HttpAsioServerRequest();
+
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
+
+ virtual IoBuffer ReadPayload() override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
+ virtual bool TryGetRanges(HttpRanges& Ranges) override;
+
+ using HttpServerRequest::WriteResponse;
+
+ HttpAsioServerRequest(const HttpAsioServerRequest&) = delete;
+ HttpAsioServerRequest& operator=(const HttpAsioServerRequest&) = delete;
+
+ asio_http::HttpRequest& m_Request;
+ IoBuffer m_PayloadBuffer;
+ std::unique_ptr<HttpResponse> m_Response;
+};
+
+struct HttpRequest
+{
+ explicit HttpRequest(HttpServerConnection& Connection) : m_Connection(Connection) {}
+
+ void Initialize();
+ size_t ConsumeData(const char* InputData, size_t DataSize);
+ void ResetState();
+
+ HttpVerb RequestVerb() const { return m_RequestVerb; }
+ bool IsKeepAlive() const { return m_KeepAlive; }
+ std::string_view Url() const { return m_NormalizedUrl.empty() ? std::string_view(m_Url, m_UrlLength) : m_NormalizedUrl; }
+ std::string_view QueryString() const { return std::string_view(m_QueryString, m_QueryLength); }
+ IoBuffer Body() { return m_BodyBuffer; }
+
+ inline HttpContentType ContentType()
+ {
+ if (m_ContentTypeHeaderIndex < 0)
+ {
+ return HttpContentType::kUnknownContentType;
+ }
+
+ return ParseContentType(m_Headers[m_ContentTypeHeaderIndex].Value);
+ }
+
+ inline HttpContentType AcceptType()
+ {
+ if (m_AcceptHeaderIndex < 0)
+ {
+ return HttpContentType::kUnknownContentType;
+ }
+
+ return ParseContentType(m_Headers[m_AcceptHeaderIndex].Value);
+ }
+
+ Oid SessionId() const { return m_SessionId; }
+ int RequestId() const { return m_RequestId; }
+
+ std::string_view RangeHeader() const { return m_RangeHeaderIndex != -1 ? m_Headers[m_RangeHeaderIndex].Value : std::string_view(); }
+
+private:
+ struct HeaderEntry
+ {
+ HeaderEntry() = default;
+
+ HeaderEntry(std::string_view InName, std::string_view InValue) : Name(InName), Value(InValue) {}
+
+ std::string_view Name;
+ std::string_view Value;
+ };
+
+ HttpServerConnection& m_Connection;
+ char* m_HeaderCursor = m_HeaderBuffer;
+ char* m_Url = nullptr;
+ size_t m_UrlLength = 0;
+ char* m_QueryString = nullptr;
+ size_t m_QueryLength = 0;
+ char* m_CurrentHeaderName = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderNameLength = 0;
+ char* m_CurrentHeaderValue = nullptr; // Used while parsing headers
+ size_t m_CurrentHeaderValueLength = 0;
+ std::vector<HeaderEntry> m_Headers;
+ int8_t m_ContentLengthHeaderIndex;
+ int8_t m_AcceptHeaderIndex;
+ int8_t m_ContentTypeHeaderIndex;
+ int8_t m_RangeHeaderIndex;
+ HttpVerb m_RequestVerb;
+ bool m_KeepAlive = false;
+ bool m_Expect100Continue = false;
+ int m_RequestId = -1;
+ Oid m_SessionId{};
+ IoBuffer m_BodyBuffer;
+ uint64_t m_BodyPosition = 0;
+ http_parser m_Parser;
+ char m_HeaderBuffer[1024];
+ std::string m_NormalizedUrl;
+
+ void AppendCurrentHeader();
+
+ int OnMessageBegin();
+ int OnUrl(const char* Data, size_t Bytes);
+ int OnHeader(const char* Data, size_t Bytes);
+ int OnHeaderValue(const char* Data, size_t Bytes);
+ int OnHeadersComplete();
+ int OnBody(const char* Data, size_t Bytes);
+ int OnMessageComplete();
+
+ static HttpRequest* GetThis(http_parser* Parser) { return reinterpret_cast<HttpRequest*>(Parser->data); }
+ static http_parser_settings s_ParserSettings;
+};
+
+struct HttpResponse
+{
+public:
+ HttpResponse() = default;
+ explicit HttpResponse(HttpContentType ContentType) : m_ContentType(ContentType) {}
+
+ void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
+ {
+ m_ResponseCode = ResponseCode;
+ const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
+
+ m_DataBuffers.reserve(ChunkCount);
+
+ for (IoBuffer& Buffer : BlobList)
+ {
+#if 1
+ m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
+#else
+ IoBuffer TempBuffer = std::move(Buffer);
+ TempBuffer.MakeOwned();
+ m_DataBuffers.emplace_back(IoBufferBuilder::ReadFromFileMaybe(TempBuffer));
+#endif
+ }
+
+ uint64_t LocalDataSize = 0;
+
+ m_AsioBuffers.push_back({}); // Placeholder for header
+
+ for (IoBuffer& Buffer : m_DataBuffers)
+ {
+ uint64_t BufferDataSize = Buffer.Size();
+
+ ZEN_ASSERT(BufferDataSize);
+
+ LocalDataSize += BufferDataSize;
+
+ IoBufferFileReference FileRef;
+ if (Buffer.GetFileReference(/* out */ FileRef))
+ {
+ // TODO: Use direct file transfer, via TransmitFile/sendfile
+ //
+ // this looks like it requires some custom asio plumbing however
+
+ m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
+ }
+ else
+ {
+ // Send from memory
+
+ m_AsioBuffers.push_back({Buffer.Data(), Buffer.Size()});
+ }
+ }
+ m_ContentLength = LocalDataSize;
+
+ auto Headers = GetHeaders();
+ m_AsioBuffers[0] = asio::const_buffer(Headers.data(), Headers.size());
+ }
+
+ uint16_t ResponseCode() const { return m_ResponseCode; }
+ uint64_t ContentLength() const { return m_ContentLength; }
+
+ const std::vector<asio::const_buffer>& AsioBuffers() const { return m_AsioBuffers; }
+
+ std::string_view GetHeaders()
+ {
+ m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n"
+ << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n"
+ << "Content-Length: " << ContentLength() << "\r\n"sv;
+
+ if (!m_IsKeepAlive)
+ {
+ m_Headers << "Connection: close\r\n"sv;
+ }
+
+ m_Headers << "\r\n"sv;
+
+ return m_Headers;
+ }
+
+ void SuppressPayload() { m_AsioBuffers.resize(1); }
+
+private:
+ uint16_t m_ResponseCode = 0;
+ bool m_IsKeepAlive = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ uint64_t m_ContentLength = 0;
+ std::vector<IoBuffer> m_DataBuffers;
+ std::vector<asio::const_buffer> m_AsioBuffers;
+ ExtendableStringBuilder<160> m_Headers;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpServerConnection : std::enable_shared_from_this<HttpServerConnection>
+{
+ HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket);
+ ~HttpServerConnection();
+
+ void HandleNewRequest();
+ void TerminateConnection();
+ void HandleRequest();
+
+ std::shared_ptr<HttpServerConnection> AsSharedPtr() { return shared_from_this(); }
+
+private:
+ enum class RequestState
+ {
+ kInitialState,
+ kInitialRead,
+ kReadingMore,
+ kWriting,
+ kWritingFinal,
+ kDone,
+ kTerminated
+ };
+
+ RequestState m_RequestState = RequestState::kInitialState;
+ HttpRequest m_RequestData{*this};
+
+ void EnqueueRead();
+ void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
+ void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount, bool Pop = false);
+ void OnError();
+
+ HttpAsioServerImpl& m_Server;
+ asio::streambuf m_RequestBuffer;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::atomic<uint32_t> m_RequestCounter{0};
+ uint32_t m_ConnectionId = 0;
+ Ref<IHttpPackageHandler> m_PackageHandler;
+
+ RwLock m_ResponsesLock;
+ std::deque<std::unique_ptr<HttpResponse>> m_Responses;
+};
+
+std::atomic<uint32_t> g_ConnectionIdCounter{0};
+
+HttpServerConnection::HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket)
+: m_Server(Server)
+, m_Socket(std::move(Socket))
+, m_ConnectionId(g_ConnectionIdCounter.fetch_add(1))
+{
+ ZEN_TRACE_VERBOSE("new connection #{}", m_ConnectionId);
+}
+
+HttpServerConnection::~HttpServerConnection()
+{
+ ZEN_TRACE_VERBOSE("destroying connection #{}", m_ConnectionId);
+}
+
+void
+HttpServerConnection::HandleNewRequest()
+{
+ m_RequestData.Initialize();
+
+ EnqueueRead();
+}
+
+void
+HttpServerConnection::TerminateConnection()
+{
+ m_RequestState = RequestState::kTerminated;
+
+ std::error_code Ec;
+ m_Socket->close(Ec);
+}
+
+void
+HttpServerConnection::EnqueueRead()
+{
+ if (m_RequestState == RequestState::kInitialRead)
+ {
+ m_RequestState = RequestState::kReadingMore;
+ }
+ else
+ {
+ m_RequestState = RequestState::kInitialRead;
+ }
+
+ m_RequestBuffer.prepare(64 * 1024);
+
+ asio::async_read(*m_Socket.get(),
+ m_RequestBuffer,
+ asio::transfer_at_least(1),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); });
+}
+
+void
+HttpServerConnection::OnDataReceived(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount)
+{
+ if (Ec)
+ {
+ if (m_RequestState == RequestState::kDone || m_RequestState == RequestState::kInitialRead)
+ {
+ ZEN_TRACE_VERBOSE("on data received ERROR (EXPECTED), connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ return;
+ }
+ else
+ {
+ ZEN_WARN("on data received ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ return OnError();
+ }
+ }
+
+ ZEN_TRACE_VERBOSE("on data received, connection '{}', request '{}', thread '{}', bytes '{}'",
+ m_ConnectionId,
+ m_RequestCounter.load(std::memory_order_relaxed),
+ zen::GetCurrentThreadId(),
+ NiceBytes(ByteCount));
+
+ while (m_RequestBuffer.size())
+ {
+ const asio::const_buffer& InputBuffer = m_RequestBuffer.data();
+
+ size_t Result = m_RequestData.ConsumeData((const char*)InputBuffer.data(), InputBuffer.size());
+ if (Result == ~0ull)
+ {
+ return OnError();
+ }
+
+ m_RequestBuffer.consume(Result);
+ }
+
+ switch (m_RequestState)
+ {
+ case RequestState::kDone:
+ case RequestState::kWritingFinal:
+ case RequestState::kTerminated:
+ break;
+
+ default:
+ EnqueueRead();
+ break;
+ }
+}
+
+void
+HttpServerConnection::OnResponseDataSent(const asio::error_code& Ec, [[maybe_unused]] std::size_t ByteCount, bool Pop)
+{
+ if (Ec)
+ {
+ ZEN_WARN("on data sent ERROR, connection '{}' reason '{}'", m_ConnectionId, Ec.message());
+ OnError();
+ }
+ else
+ {
+ ZEN_TRACE_VERBOSE("on data sent, connection '{}', request '{}', thread '{}', bytes '{}'",
+ m_ConnectionId,
+ m_RequestCounter.load(std::memory_order_relaxed),
+ zen::GetCurrentThreadId(),
+ NiceBytes(ByteCount));
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ m_RequestState = RequestState::kDone;
+
+ m_Socket->close();
+ }
+ else
+ {
+ if (Pop)
+ {
+ RwLock::ExclusiveLockScope _(m_ResponsesLock);
+ m_Responses.pop_front();
+ }
+
+ m_RequestCounter.fetch_add(1);
+ }
+ }
+}
+
+void
+HttpServerConnection::OnError()
+{
+ m_Socket->close();
+}
+
+void
+HttpServerConnection::HandleRequest()
+{
+ if (!m_RequestData.IsKeepAlive())
+ {
+ m_RequestState = RequestState::kWritingFinal;
+
+ std::error_code Ec;
+ m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("socket shutdown ERROR, reason '{}'", Ec.message());
+ }
+ }
+ else
+ {
+ m_RequestState = RequestState::kWriting;
+ }
+
+ if (HttpService* Service = m_Server.RouteRequest(m_RequestData.Url()))
+ {
+ HttpAsioServerRequest Request(m_RequestData, *Service, m_RequestData.Body());
+
+ ZEN_TRACE_VERBOSE("handle request, connection '{}' request '{}'", m_ConnectionId, m_RequestCounter.load(std::memory_order_relaxed));
+
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ {
+ try
+ {
+ Service->HandleRequest(Request);
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
+
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
+ }
+
+ if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
+ {
+ // Transmit the response
+
+ if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ Response->SuppressPayload();
+ }
+
+ auto ResponseBuffers = Response->AsioBuffers();
+
+ uint64_t ResponseLength = 0;
+
+ for (auto& Buffer : ResponseBuffers)
+ {
+ ResponseLength += Buffer.size();
+ }
+
+ {
+ RwLock::ExclusiveLockScope _(m_ResponsesLock);
+ m_Responses.push_back(std::move(Response));
+ }
+
+ // TODO: should cork/uncork for Linux?
+
+ asio::async_write(*m_Socket.get(),
+ ResponseBuffers,
+ asio::transfer_exactly(ResponseLength),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) {
+ Conn->OnResponseDataSent(Ec, ByteCount, true);
+ });
+
+ return;
+ }
+ }
+
+ if (m_RequestData.RequestVerb() == HttpVerb::kHead)
+ {
+ std::string_view Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "\r\n"sv;
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Connection: close\r\n"
+ "\r\n"sv;
+ }
+
+ asio::async_write(
+ *m_Socket.get(),
+ asio::buffer(Response),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
+ }
+ else
+ {
+ std::string_view Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+
+ if (!m_RequestData.IsKeepAlive())
+ {
+ Response =
+ "HTTP/1.1 404 NOT FOUND\r\n"
+ "Content-Length: 23\r\n"
+ "Content-Type: text/plain\r\n"
+ "Connection: close\r\n"
+ "\r\n"
+ "No suitable route found"sv;
+ }
+
+ asio::async_write(
+ *m_Socket.get(),
+ asio::buffer(Response),
+ [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnResponseDataSent(Ec, ByteCount); });
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// HttpRequest
+//
+
+http_parser_settings HttpRequest::s_ParserSettings{
+ .on_message_begin = [](http_parser* p) { return GetThis(p)->OnMessageBegin(); },
+ .on_url = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnUrl(Data, ByteCount); },
+ .on_status =
+ [](http_parser* p, const char* Data, size_t ByteCount) {
+ ZEN_UNUSED(p, Data, ByteCount);
+ return 0;
+ },
+ .on_header_field = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeader(Data, ByteCount); },
+ .on_header_value = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnHeaderValue(Data, ByteCount); },
+ .on_headers_complete = [](http_parser* p) { return GetThis(p)->OnHeadersComplete(); },
+ .on_body = [](http_parser* p, const char* Data, size_t ByteCount) { return GetThis(p)->OnBody(Data, ByteCount); },
+ .on_message_complete = [](http_parser* p) { return GetThis(p)->OnMessageComplete(); },
+ .on_chunk_header{},
+ .on_chunk_complete{}};
+
+void
+HttpRequest::Initialize()
+{
+ http_parser_init(&m_Parser, HTTP_REQUEST);
+ m_Parser.data = this;
+
+ ResetState();
+}
+
+size_t
+HttpRequest::ConsumeData(const char* InputData, size_t DataSize)
+{
+ const size_t ConsumedBytes = http_parser_execute(&m_Parser, &s_ParserSettings, InputData, DataSize);
+
+ http_errno HttpErrno = HTTP_PARSER_ERRNO((&m_Parser));
+
+ if (HttpErrno && HttpErrno != HPE_INVALID_EOF_STATE)
+ {
+ ZEN_WARN("HTTP parser error {} ('{}'). Closing connection", http_errno_name(HttpErrno), http_errno_description(HttpErrno));
+ return ~0ull;
+ }
+
+ return ConsumedBytes;
+}
+
+int
+HttpRequest::OnUrl(const char* Data, size_t Bytes)
+{
+ if (!m_Url)
+ {
+ ZEN_ASSERT_SLOW(m_UrlLength == 0);
+ m_Url = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming request, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_UrlLength += Bytes;
+
+ return 0;
+}
+
+int
+HttpRequest::OnHeader(const char* Data, size_t Bytes)
+{
+ if (m_CurrentHeaderValueLength)
+ {
+ AppendCurrentHeader();
+
+ m_CurrentHeaderNameLength = 0;
+ m_CurrentHeaderValueLength = 0;
+ m_CurrentHeaderName = m_HeaderCursor;
+ }
+ else if (m_CurrentHeaderName == nullptr)
+ {
+ m_CurrentHeaderName = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming header name, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_CurrentHeaderNameLength += Bytes;
+
+ return 0;
+}
+
+void
+HttpRequest::AppendCurrentHeader()
+{
+ std::string_view HeaderName(m_CurrentHeaderName, m_CurrentHeaderNameLength);
+ std::string_view HeaderValue(m_CurrentHeaderValue, m_CurrentHeaderValueLength);
+
+ const uint32_t HeaderHash = HashStringAsLowerDjb2(HeaderName);
+
+ if (HeaderHash == HashContentLength)
+ {
+ m_ContentLengthHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashAccept)
+ {
+ m_AcceptHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashContentType)
+ {
+ m_ContentTypeHeaderIndex = (int8_t)m_Headers.size();
+ }
+ else if (HeaderHash == HashSession)
+ {
+ m_SessionId = Oid::FromHexString(HeaderValue);
+ }
+ else if (HeaderHash == HashRequest)
+ {
+ std::from_chars(HeaderValue.data(), HeaderValue.data() + HeaderValue.size(), m_RequestId);
+ }
+ else if (HeaderHash == HashExpect)
+ {
+ if (HeaderValue == "100-continue"sv)
+ {
+ // We don't currently do anything with this
+ m_Expect100Continue = true;
+ }
+ else
+ {
+ ZEN_INFO("Unexpected expect - Expect: {}", HeaderValue);
+ }
+ }
+ else if (HeaderHash == HashRange)
+ {
+ m_RangeHeaderIndex = (int8_t)m_Headers.size();
+ }
+
+ m_Headers.emplace_back(HeaderName, HeaderValue);
+}
+
+int
+HttpRequest::OnHeaderValue(const char* Data, size_t Bytes)
+{
+ if (m_CurrentHeaderValueLength == 0)
+ {
+ m_CurrentHeaderValue = m_HeaderCursor;
+ }
+
+ const size_t RemainingBufferSpace = sizeof m_HeaderBuffer + m_HeaderBuffer - m_HeaderCursor;
+ if (RemainingBufferSpace < Bytes)
+ {
+ ZEN_WARN("HTTP parser does not have enough space for incoming header value, need {} more bytes", Bytes - RemainingBufferSpace);
+ return 1;
+ }
+
+ memcpy(m_HeaderCursor, Data, Bytes);
+ m_HeaderCursor += Bytes;
+ m_CurrentHeaderValueLength += Bytes;
+
+ return 0;
+}
+
+static void
+NormalizeUrlPath(const char* Url, size_t UrlLength, std::string& NormalizedUrl)
+{
+ bool LastCharWasSeparator = false;
+ for (std::string_view::size_type UrlIndex = 0; UrlIndex < UrlLength; ++UrlIndex)
+ {
+ const char UrlChar = Url[UrlIndex];
+ const bool IsSeparator = (UrlChar == '/');
+
+ if (IsSeparator && LastCharWasSeparator)
+ {
+ if (NormalizedUrl.empty())
+ {
+ NormalizedUrl.reserve(UrlLength);
+ NormalizedUrl.append(Url, UrlIndex);
+ }
+
+ if (!LastCharWasSeparator)
+ {
+ NormalizedUrl.push_back('/');
+ }
+ }
+ else if (!NormalizedUrl.empty())
+ {
+ NormalizedUrl.push_back(UrlChar);
+ }
+
+ LastCharWasSeparator = IsSeparator;
+ }
+}
+
+int
+HttpRequest::OnHeadersComplete()
+{
+ if (m_CurrentHeaderValueLength)
+ {
+ AppendCurrentHeader();
+ }
+
+ if (m_ContentLengthHeaderIndex >= 0)
+ {
+ std::string_view& Value = m_Headers[m_ContentLengthHeaderIndex].Value;
+ uint64_t ContentLength = 0;
+ std::from_chars(Value.data(), Value.data() + Value.size(), ContentLength);
+
+ if (ContentLength)
+ {
+ m_BodyBuffer = IoBuffer(ContentLength);
+ }
+
+ m_BodyBuffer.SetContentType(ContentType());
+
+ m_BodyPosition = 0;
+ }
+
+ m_KeepAlive = !!http_should_keep_alive(&m_Parser);
+
+ switch (m_Parser.method)
+ {
+ case HTTP_GET:
+ m_RequestVerb = HttpVerb::kGet;
+ break;
+
+ case HTTP_POST:
+ m_RequestVerb = HttpVerb::kPost;
+ break;
+
+ case HTTP_PUT:
+ m_RequestVerb = HttpVerb::kPut;
+ break;
+
+ case HTTP_DELETE:
+ m_RequestVerb = HttpVerb::kDelete;
+ break;
+
+ case HTTP_HEAD:
+ m_RequestVerb = HttpVerb::kHead;
+ break;
+
+ case HTTP_COPY:
+ m_RequestVerb = HttpVerb::kCopy;
+ break;
+
+ case HTTP_OPTIONS:
+ m_RequestVerb = HttpVerb::kOptions;
+ break;
+
+ default:
+ ZEN_WARN("invalid HTTP method: '{}'", http_method_str((http_method)m_Parser.method));
+ break;
+ }
+
+ std::string_view Url(m_Url, m_UrlLength);
+
+ if (auto QuerySplit = Url.find_first_of('?'); QuerySplit != std::string_view::npos)
+ {
+ m_UrlLength = QuerySplit;
+ m_QueryString = m_Url + QuerySplit + 1;
+ m_QueryLength = Url.size() - QuerySplit - 1;
+ }
+
+ NormalizeUrlPath(m_Url, m_UrlLength, m_NormalizedUrl);
+
+ return 0;
+}
+
+int
+HttpRequest::OnBody(const char* Data, size_t Bytes)
+{
+ if (m_BodyPosition + Bytes > m_BodyBuffer.Size())
+ {
+ ZEN_WARN("HTTP parser incoming body is larger than content size, need {} more bytes",
+ (m_BodyPosition + Bytes) - m_BodyBuffer.Size());
+ return 1;
+ }
+ memcpy(reinterpret_cast<uint8_t*>(m_BodyBuffer.MutableData()) + m_BodyPosition, Data, Bytes);
+ m_BodyPosition += Bytes;
+
+ if (http_body_is_final(&m_Parser))
+ {
+ if (m_BodyPosition != m_BodyBuffer.Size())
+ {
+ ZEN_WARN("Body mismatch! {} != {}", m_BodyPosition, m_BodyBuffer.Size());
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+void
+HttpRequest::ResetState()
+{
+ m_HeaderCursor = m_HeaderBuffer;
+ m_CurrentHeaderName = nullptr;
+ m_CurrentHeaderNameLength = 0;
+ m_CurrentHeaderValue = nullptr;
+ m_CurrentHeaderValueLength = 0;
+ m_CurrentHeaderName = nullptr;
+ m_Url = nullptr;
+ m_UrlLength = 0;
+ m_QueryString = nullptr;
+ m_QueryLength = 0;
+ m_ContentLengthHeaderIndex = -1;
+ m_AcceptHeaderIndex = -1;
+ m_ContentTypeHeaderIndex = -1;
+ m_RangeHeaderIndex = -1;
+ m_Expect100Continue = false;
+ m_BodyBuffer = {};
+ m_BodyPosition = 0;
+ m_Headers.clear();
+ m_NormalizedUrl.clear();
+}
+
+int
+HttpRequest::OnMessageBegin()
+{
+ return 0;
+}
+
+int
+HttpRequest::OnMessageComplete()
+{
+ m_Connection.HandleRequest();
+
+ ResetState();
+
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpAcceptor
+{
+ HttpAcceptor(HttpAsioServerImpl& Server, asio::io_service& IoService, uint16_t BasePort)
+ : m_Server(Server)
+ , m_IoService(IoService)
+ , m_Acceptor(m_IoService, asio::ip::tcp::v6())
+ {
+ m_Acceptor.set_option(asio::ip::v6_only(false));
+#if ZEN_PLATFORM_WINDOWS
+ // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
+ typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> excluse_address;
+ m_Acceptor.set_option(excluse_address(true));
+#else // ZEN_PLATFORM_WINDOWS
+ m_Acceptor.set_option(asio::socket_base::reuse_address(false));
+#endif // ZEN_PLATFORM_WINDOWS
+
+ m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ uint16_t EffectivePort = BasePort;
+
+ asio::error_code BindErrorCode;
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ // Sharing violation implies the port is being used by another process
+ for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset)
+ {
+ EffectivePort = BasePort + (PortOffset * 100);
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ }
+ if (BindErrorCode == asio::error::access_denied)
+ {
+ EffectivePort = 0;
+ m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
+ }
+ if (BindErrorCode)
+ {
+ ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message());
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor.native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+#endif
+ m_Acceptor.listen();
+
+ ZEN_INFO("Started asio server at port '{}'", EffectivePort);
+ }
+
+ void Start()
+ {
+ m_Acceptor.listen();
+ InitAccept();
+ }
+
+ void Stop() { m_IsStopped = true; }
+
+ void InitAccept()
+ {
+ auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
+ asio::ip::tcp::socket& SocketRef = *SocketPtr.get();
+
+ m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
+ if (Ec)
+ {
+ ZEN_WARN("asio async_accept, connection failed to '{}:{}' reason '{}'",
+ m_Acceptor.local_endpoint().address().to_string(),
+ m_Acceptor.local_endpoint().port(),
+ Ec.message());
+ }
+ else
+ {
+ // New connection established, pass socket ownership into connection object
+ // and initiate request handling loop. The connection lifetime is
+ // managed by the async read/write loop by passing the shared
+ // reference to the callbacks.
+
+ Socket->set_option(asio::ip::tcp::no_delay(true));
+ Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
+ Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));
+
+ auto Conn = std::make_shared<HttpServerConnection>(m_Server, std::move(Socket));
+ Conn->HandleNewRequest();
+ }
+
+ if (!m_IsStopped.load())
+ {
+ InitAccept();
+ }
+ else
+ {
+ m_Acceptor.close();
+ }
+ });
+ }
+
+ int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }
+
+private:
+ HttpAsioServerImpl& m_Server;
+ asio::io_service& m_IoService;
+ asio::ip::tcp::acceptor m_Acceptor;
+ std::atomic<bool> m_IsStopped{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpAsioServerRequest::HttpAsioServerRequest(asio_http::HttpRequest& Request, HttpService& Service, IoBuffer PayloadBuffer)
+: m_Request(Request)
+, m_PayloadBuffer(std::move(PayloadBuffer))
+{
+ const int PrefixLength = Service.UriPrefixLength();
+
+ std::string_view Uri = Request.Url();
+ Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size())));
+ m_Uri = Uri;
+ m_UriWithExtension = Uri;
+ m_QueryString = Request.QueryString();
+
+ m_Verb = Request.RequestVerb();
+ m_ContentLength = Request.Body().Size();
+ m_ContentType = Request.ContentType();
+
+ HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
+
+ // Parse any extension, to allow requesting a particular response encoding via the URL
+
+ {
+ std::string_view UriSuffix8{m_Uri};
+
+ const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
+
+ if (LastComponentIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastComponentIndex);
+ }
+
+ const size_t LastDotIndex = UriSuffix8.find_last_of('.');
+
+ if (LastDotIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastDotIndex + 1);
+
+ AcceptContentType = ParseContentType(UriSuffix8);
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1));
+ }
+ }
+ }
+
+ // It an explicit content type extension was specified then we'll use that over any
+ // Accept: header value that may be present
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_AcceptType = AcceptContentType;
+ }
+ else
+ {
+ m_AcceptType = Request.AcceptType();
+ }
+}
+
+HttpAsioServerRequest::~HttpAsioServerRequest()
+{
+}
+
+Oid
+HttpAsioServerRequest::ParseSessionId() const
+{
+ return m_Request.SessionId();
+}
+
+uint32_t
+HttpAsioServerRequest::ParseRequestId() const
+{
+ return m_Request.RequestId();
+}
+
+IoBuffer
+HttpAsioServerRequest::ReadPayload()
+{
+ return m_PayloadBuffer;
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpResponse(HttpContentType::kBinary));
+ std::array<IoBuffer, 0> Empty;
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty);
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(!m_Response);
+
+ m_Response.reset(new HttpResponse(ContentType));
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs);
+}
+
+void
+HttpAsioServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
+{
+ ZEN_ASSERT(!m_Response);
+ m_Response.reset(new HttpResponse(ContentType));
+
+ IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size());
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList);
+}
+
+void
+HttpAsioServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
+{
+ ZEN_ASSERT(!m_Response);
+
+ // Not one bit async, innit
+ ContinuationHandler(*this);
+}
+
+bool
+HttpAsioServerRequest::TryGetRanges(HttpRanges& Ranges)
+{
+ return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpAsioServerImpl::HttpAsioServerImpl()
+{
+}
+
+HttpAsioServerImpl::~HttpAsioServerImpl()
+{
+}
+
+int
+HttpAsioServerImpl::Start(uint16_t Port, int ThreadCount)
+{
+ ZEN_ASSERT(ThreadCount > 0);
+
+ ZEN_INFO("starting asio http with {} service threads", ThreadCount);
+
+ m_Acceptor.reset(new asio_http::HttpAcceptor(*this, m_IoService, Port));
+ m_Acceptor->Start();
+
+ for (int i = 0; i < ThreadCount; ++i)
+ {
+ m_ThreadPool.emplace_back([this, Index = i + 1] {
+ SetCurrentThreadName(fmt::format("asio worker {}", Index));
+
+ try
+ {
+ m_IoService.run();
+ }
+ catch (std::exception& e)
+ {
+ ZEN_ERROR("Exception caught in asio event loop: '{}'", e.what());
+ }
+ });
+ }
+
+ ZEN_INFO("asio http started (port {})", m_Acceptor->GetAcceptPort());
+
+ return m_Acceptor->GetAcceptPort();
+}
+
+void
+HttpAsioServerImpl::Stop()
+{
+ m_Acceptor->Stop();
+ m_IoService.stop();
+ for (auto& Thread : m_ThreadPool)
+ {
+ Thread.join();
+ }
+}
+
+void
+HttpAsioServerImpl::RegisterService(const char* InUrlPath, HttpService& Service)
+{
+ std::string_view UrlPath(InUrlPath);
+ Service.SetUriPrefixLength(UrlPath.size());
+ if (!UrlPath.empty() && UrlPath.back() == '/')
+ {
+ UrlPath.remove_suffix(1);
+ }
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_UriHandlers.push_back({std::string(UrlPath), &Service});
+}
+
+HttpService*
+HttpAsioServerImpl::RouteRequest(std::string_view Url)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ HttpService* CandidateService = nullptr;
+ std::string::size_type CandidateMatchSize = 0;
+ for (const ServiceEntry& SvcEntry : m_UriHandlers)
+ {
+ const std::string& SvcUrl = SvcEntry.ServiceUrlPath;
+ const std::string::size_type SvcUrlSize = SvcUrl.size();
+ if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 &&
+ ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/')))
+ {
+ CandidateMatchSize = SvcUrl.size();
+ CandidateService = SvcEntry.Service;
+ }
+ }
+
+ return CandidateService;
+}
+
+} // namespace zen::asio_http
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+HttpAsioServer::HttpAsioServer() : m_Impl(std::make_unique<asio_http::HttpAsioServerImpl>())
+{
+ ZEN_DEBUG("Request object size: {} ({:#x})", sizeof(asio_http::HttpRequest), sizeof(asio_http::HttpRequest));
+}
+
+HttpAsioServer::~HttpAsioServer()
+{
+ try
+ {
+ m_Impl->Stop();
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_WARN("Caught exception stopping http asio server: {}", ex.what());
+ }
+}
+
+void
+HttpAsioServer::RegisterService(HttpService& Service)
+{
+ m_Impl->RegisterService(Service.BaseUri(), Service);
+}
+
+int
+HttpAsioServer::Initialize(int BasePort)
+{
+ m_BasePort = m_Impl->Start(gsl::narrow<uint16_t>(BasePort), Max(std::thread::hardware_concurrency(), 8u));
+ return m_BasePort;
+}
+
+void
+HttpAsioServer::Run(bool IsInteractive)
+{
+ const bool TestMode = !IsInteractive;
+
+ int WaitTimeout = -1;
+ if (!TestMode)
+ {
+ WaitTimeout = 1000;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Press ESC or Q to quit");
+ }
+
+ do
+ {
+ if (!TestMode && _kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#else
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (asio HTTP). Ctrl-C to quit");
+ }
+
+ do
+ {
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#endif
+}
+
+void
+HttpAsioServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h
new file mode 100644
index 000000000..716145955
--- /dev/null
+++ b/src/zenhttp/httpasio.h
@@ -0,0 +1,36 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+
+#include <memory>
+
+namespace zen {
+
+namespace asio_http {
+ struct HttpServerConnection;
+ struct HttpAcceptor;
+ struct HttpAsioServerImpl;
+} // namespace asio_http
+
+class HttpAsioServer : public HttpServer
+{
+public:
+ HttpAsioServer();
+ ~HttpAsioServer();
+
+ virtual void RegisterService(HttpService& Service) override;
+ virtual int Initialize(int BasePort) override;
+ virtual void Run(bool IsInteractiveSession) override;
+ virtual void RequestExit() override;
+
+private:
+ Event m_ShutdownEvent;
+ int m_BasePort = 0;
+
+ std::unique_ptr<asio_http::HttpAsioServerImpl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
new file mode 100644
index 000000000..e6813d407
--- /dev/null
+++ b/src/zenhttp/httpclient.cpp
@@ -0,0 +1,176 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpserver.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/session.h>
+#include <zencore/sharedbuffer.h>
+#include <zencore/stream.h>
+#include <zencore/testing.h>
+#include <zenhttp/httpshared.h>
+
+static std::atomic<uint32_t> HttpClientRequestIdCounter{0};
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpClient::Response
+FromCprResponse(cpr::Response& InResponse)
+{
+ return {.StatusCode = int(InResponse.status_code)};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpClient::HttpClient(std::string_view BaseUri) : m_BaseUri(BaseUri)
+{
+ StringBuilder<32> SessionId;
+ GetSessionId().ToString(SessionId);
+ m_SessionId = SessionId;
+}
+
+HttpClient::~HttpClient()
+{
+}
+
+HttpClient::Response
+HttpClient::TransactPackage(std::string_view Url, CbPackage Package)
+{
+ cpr::Session Sess;
+ Sess.SetUrl(m_BaseUri + std::string(Url));
+
+ // First, list of offered chunks for filtering on the server end
+
+ std::vector<IoHash> AttachmentsToSend;
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+
+ const uint32_t RequestId = ++HttpClientRequestIdCounter;
+ auto RequestIdString = fmt::to_string(RequestId);
+
+ if (Attachments.empty() == false)
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("offer");
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ IoHash Hash = Attachment.GetHash();
+
+ Writer.AddHash(Hash);
+ }
+
+ Writer.EndArray();
+
+ BinaryWriter MemWriter;
+ Writer.Save(MemWriter);
+
+ Sess.SetHeader({{"Content-Type", "application/x-ue-offer"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
+ Sess.SetBody(cpr::Body{(const char*)MemWriter.Data(), MemWriter.Size()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (FilterResponse.status_code == 200)
+ {
+ IoBuffer ResponseBuffer(IoBuffer::Wrap, FilterResponse.text.data(), FilterResponse.text.size());
+ CbObject ResponseObject = LoadCompactBinaryObject(ResponseBuffer);
+
+ for (auto& Entry : ResponseObject["need"])
+ {
+ ZEN_ASSERT(Entry.IsHash());
+ AttachmentsToSend.push_back(Entry.AsHash());
+ }
+ }
+ }
+
+ // Prepare package for send
+
+ CbPackage SendPackage;
+ SendPackage.SetObject(Package.GetObject(), Package.GetObjectHash());
+
+ for (const IoHash& AttachmentCid : AttachmentsToSend)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(AttachmentCid);
+
+ if (Attachment)
+ {
+ SendPackage.AddAttachment(*Attachment);
+ }
+ else
+ {
+ // This should be an error -- server asked to have something we can't find
+ }
+ }
+
+ // Transmit package payload
+
+ CompositeBuffer Message = FormatPackageMessageBuffer(SendPackage);
+ SharedBuffer FlatMessage = Message.Flatten();
+
+ Sess.SetHeader({{"Content-Type", "application/x-ue-cbpkg"}, {"UE-Session", m_SessionId}, {"UE-Request", RequestIdString}});
+ Sess.SetBody(cpr::Body{(const char*)FlatMessage.GetData(), FlatMessage.GetSize()});
+
+ cpr::Response FilterResponse = Sess.Post();
+
+ if (!IsHttpSuccessCode(FilterResponse.status_code))
+ {
+ return FromCprResponse(FilterResponse);
+ }
+
+ IoBuffer ResponseBuffer(IoBuffer::Clone, FilterResponse.text.data(), FilterResponse.text.size());
+
+ if (auto It = FilterResponse.header.find("Content-Type"); It != FilterResponse.header.end())
+ {
+ HttpContentType ContentType = ParseContentType(It->second);
+
+ ResponseBuffer.SetContentType(ContentType);
+ }
+
+ return {.StatusCode = int(FilterResponse.status_code), .ResponsePayload = ResponseBuffer};
+}
+
+HttpClient::Response
+HttpClient::Put(std::string_view Url, IoBuffer Payload)
+{
+ ZEN_UNUSED(Url);
+ ZEN_UNUSED(Payload);
+ return {};
+}
+
+HttpClient::Response
+HttpClient::Get(std::string_view Url)
+{
+ ZEN_UNUSED(Url);
+ return {};
+}
+
+HttpClient::Response
+HttpClient::Delete(std::string_view Url)
+{
+ ZEN_UNUSED(Url);
+ return {};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("httpclient")
+{
+ using namespace std::literals;
+
+ SUBCASE("client") {}
+}
+
+void
+httpclient_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/httpnull.cpp b/src/zenhttp/httpnull.cpp
new file mode 100644
index 000000000..a6e1d3567
--- /dev/null
+++ b/src/zenhttp/httpnull.cpp
@@ -0,0 +1,83 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpnull.h"
+
+#include <zencore/logging.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+#endif
+
+namespace zen {
+
+HttpNullServer::HttpNullServer()
+{
+}
+
+HttpNullServer::~HttpNullServer()
+{
+}
+
+void
+HttpNullServer::RegisterService(HttpService& Service)
+{
+ ZEN_UNUSED(Service);
+}
+
+int
+HttpNullServer::Initialize(int BasePort)
+{
+ return BasePort;
+}
+
+void
+HttpNullServer::Run(bool IsInteractiveSession)
+{
+ const bool TestMode = !IsInteractiveSession;
+
+ int WaitTimeout = -1;
+ if (!TestMode)
+ {
+ WaitTimeout = 1000;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Press ESC or Q to quit");
+ }
+
+ do
+ {
+ if (!TestMode && _kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#else
+ if (TestMode == false)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running (null HTTP). Ctrl-C to quit");
+ }
+
+ do
+ {
+ m_ShutdownEvent.Wait(WaitTimeout);
+ } while (!IsApplicationExitRequested());
+#endif
+}
+
+void
+HttpNullServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+
+} // namespace zen
diff --git a/src/zenhttp/httpnull.h b/src/zenhttp/httpnull.h
new file mode 100644
index 000000000..74f021f6b
--- /dev/null
+++ b/src/zenhttp/httpnull.h
@@ -0,0 +1,29 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+/**
+ * @brief Null implementation of "http" server. Does nothing
+ */
+
+class HttpNullServer : public HttpServer
+{
+public:
+ HttpNullServer();
+ ~HttpNullServer();
+
+ virtual void RegisterService(HttpService& Service) override;
+ virtual int Initialize(int BasePort) override;
+ virtual void Run(bool IsInteractiveSession) override;
+ virtual void RequestExit() override;
+
+private:
+ Event m_ShutdownEvent;
+};
+
+} // namespace zen
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
new file mode 100644
index 000000000..671cbd319
--- /dev/null
+++ b/src/zenhttp/httpserver.cpp
@@ -0,0 +1,885 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpserver.h>
+
+#include "httpasio.h"
+#include "httpnull.h"
+#include "httpsys.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/refcount.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/thread.h>
+#include <zenhttp/httpshared.h>
+
+#include <charconv>
+#include <mutex>
+#include <span>
+#include <string_view>
+
+namespace zen {
+
+using namespace std::literals;
+
+std::string_view
+MapContentTypeToString(HttpContentType ContentType)
+{
+ switch (ContentType)
+ {
+ default:
+ case HttpContentType::kUnknownContentType:
+ case HttpContentType::kBinary:
+ return "application/octet-stream"sv;
+
+ case HttpContentType::kText:
+ return "text/plain"sv;
+
+ case HttpContentType::kJSON:
+ return "application/json"sv;
+
+ case HttpContentType::kCbObject:
+ return "application/x-ue-cb"sv;
+
+ case HttpContentType::kCbPackage:
+ return "application/x-ue-cbpkg"sv;
+
+ case HttpContentType::kCbPackageOffer:
+ return "application/x-ue-offer"sv;
+
+ case HttpContentType::kCompressedBinary:
+ return "application/x-ue-comp"sv;
+
+ case HttpContentType::kYAML:
+ return "text/yaml"sv;
+
+ case HttpContentType::kHTML:
+ return "text/html"sv;
+
+ case HttpContentType::kJavaScript:
+ return "application/javascript"sv;
+
+ case HttpContentType::kCSS:
+ return "text/css"sv;
+
+ case HttpContentType::kPNG:
+ return "image/png"sv;
+
+ case HttpContentType::kIcon:
+ return "image/x-icon"sv;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Note that in addition to MIME types we accept abbreviated versions, for
+// use in suffix parsing as well as for convenience when using curl
+
+static constinit uint32_t HashBinary = HashStringDjb2("application/octet-stream"sv);
+static constinit uint32_t HashJson = HashStringDjb2("json"sv);
+static constinit uint32_t HashApplicationJson = HashStringDjb2("application/json"sv);
+static constinit uint32_t HashYaml = HashStringDjb2("yaml"sv);
+static constinit uint32_t HashTextYaml = HashStringDjb2("text/yaml"sv);
+static constinit uint32_t HashText = HashStringDjb2("text/plain"sv);
+static constinit uint32_t HashApplicationCompactBinary = HashStringDjb2("application/x-ue-cb"sv);
+static constinit uint32_t HashCompactBinary = HashStringDjb2("ucb"sv);
+static constinit uint32_t HashCompactBinaryPackage = HashStringDjb2("application/x-ue-cbpkg"sv);
+static constinit uint32_t HashCompactBinaryPackageShort = HashStringDjb2("cbpkg"sv);
+static constinit uint32_t HashCompactBinaryPackageOffer = HashStringDjb2("application/x-ue-offer"sv);
+static constinit uint32_t HashCompressedBinary = HashStringDjb2("application/x-ue-comp"sv);
+static constinit uint32_t HashHtml = HashStringDjb2("html"sv);
+static constinit uint32_t HashTextHtml = HashStringDjb2("text/html"sv);
+static constinit uint32_t HashJavaScript = HashStringDjb2("js"sv);
+static constinit uint32_t HashApplicationJavaScript = HashStringDjb2("application/javascript"sv);
+static constinit uint32_t HashCss = HashStringDjb2("css"sv);
+static constinit uint32_t HashTextCss = HashStringDjb2("text/css"sv);
+static constinit uint32_t HashPng = HashStringDjb2("png"sv);
+static constinit uint32_t HashImagePng = HashStringDjb2("image/png"sv);
+static constinit uint32_t HashIcon = HashStringDjb2("ico"sv);
+static constinit uint32_t HashImageIcon = HashStringDjb2("image/x-icon"sv);
+
+std::once_flag InitContentTypeLookup;
+
+struct HashedTypeEntry
+{
+ uint32_t Hash;
+ HttpContentType Type;
+} TypeHashTable[] = {
+ // clang-format off
+ {HashBinary, HttpContentType::kBinary},
+ {HashApplicationCompactBinary, HttpContentType::kCbObject},
+ {HashCompactBinary, HttpContentType::kCbObject},
+ {HashCompactBinaryPackage, HttpContentType::kCbPackage},
+ {HashCompactBinaryPackageShort, HttpContentType::kCbPackage},
+ {HashCompactBinaryPackageOffer, HttpContentType::kCbPackageOffer},
+ {HashJson, HttpContentType::kJSON},
+ {HashApplicationJson, HttpContentType::kJSON},
+ {HashYaml, HttpContentType::kYAML},
+ {HashTextYaml, HttpContentType::kYAML},
+ {HashText, HttpContentType::kText},
+ {HashCompressedBinary, HttpContentType::kCompressedBinary},
+ {HashHtml, HttpContentType::kHTML},
+ {HashTextHtml, HttpContentType::kHTML},
+ {HashJavaScript, HttpContentType::kJavaScript},
+ {HashApplicationJavaScript, HttpContentType::kJavaScript},
+ {HashCss, HttpContentType::kCSS},
+ {HashTextCss, HttpContentType::kCSS},
+ {HashPng, HttpContentType::kPNG},
+ {HashImagePng, HttpContentType::kPNG},
+ {HashIcon, HttpContentType::kIcon},
+ {HashImageIcon, HttpContentType::kIcon},
+ // clang-format on
+};
+
+HttpContentType
+ParseContentTypeImpl(const std::string_view& ContentTypeString)
+{
+ if (!ContentTypeString.empty())
+ {
+ const uint32_t CtHash = HashStringDjb2(ContentTypeString);
+
+ if (auto It = std::lower_bound(std::begin(TypeHashTable),
+ std::end(TypeHashTable),
+ CtHash,
+ [](const HashedTypeEntry& Lhs, const uint32_t Rhs) { return Lhs.Hash < Rhs; });
+ It != std::end(TypeHashTable))
+ {
+ if (It->Hash == CtHash)
+ {
+ return It->Type;
+ }
+ }
+ }
+
+ return HttpContentType::kUnknownContentType;
+}
+
+HttpContentType
+ParseContentTypeInit(const std::string_view& ContentTypeString)
+{
+ std::call_once(InitContentTypeLookup, [] {
+ std::sort(std::begin(TypeHashTable), std::end(TypeHashTable), [](const HashedTypeEntry& Lhs, const HashedTypeEntry& Rhs) {
+ return Lhs.Hash < Rhs.Hash;
+ });
+
+ // validate that there are no hash collisions
+
+ uint32_t LastHash = 0;
+
+ for (const auto& Item : TypeHashTable)
+ {
+ ZEN_ASSERT(LastHash != Item.Hash);
+ LastHash = Item.Hash;
+ }
+ });
+
+ ParseContentType = ParseContentTypeImpl;
+
+ return ParseContentTypeImpl(ContentTypeString);
+}
+
+HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString) = &ParseContentTypeInit;
+
+bool
+TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges)
+{
+ if (RangeHeader.empty())
+ {
+ return false;
+ }
+
+ const size_t Count = Ranges.size();
+
+ std::size_t UnitDelim = RangeHeader.find_first_of('=');
+ if (UnitDelim == std::string_view::npos)
+ {
+ return false;
+ }
+
+ // only bytes for now
+ std::string_view Unit = RangeHeader.substr(0, UnitDelim);
+ if (Unit != "bytes"sv)
+ {
+ return false;
+ }
+
+ std::string_view Tokens = RangeHeader.substr(UnitDelim);
+ while (!Tokens.empty())
+ {
+ // Skip =,
+ Tokens = Tokens.substr(1);
+
+ size_t Delim = Tokens.find_first_of(',');
+ if (Delim == std::string_view::npos)
+ {
+ Delim = Tokens.length();
+ }
+
+ std::string_view Token = Tokens.substr(0, Delim);
+ Tokens = Tokens.substr(Delim);
+
+ Delim = Token.find_first_of('-');
+ if (Delim == std::string_view::npos)
+ {
+ return false;
+ }
+
+ const auto Start = ParseInt<uint32_t>(Token.substr(0, Delim));
+ const auto End = ParseInt<uint32_t>(Token.substr(Delim + 1));
+
+ if (Start.has_value() && End.has_value() && End.value() > Start.value())
+ {
+ Ranges.push_back({.Start = Start.value(), .End = End.value()});
+ }
+ else if (Start)
+ {
+ Ranges.push_back({.Start = Start.value()});
+ }
+ else if (End)
+ {
+ Ranges.push_back({.End = End.value()});
+ }
+ }
+
+ return Count != Ranges.size();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+const std::string_view
+ToString(HttpVerb Verb)
+{
+ switch (Verb)
+ {
+ case HttpVerb::kGet:
+ return "GET"sv;
+ case HttpVerb::kPut:
+ return "PUT"sv;
+ case HttpVerb::kPost:
+ return "POST"sv;
+ case HttpVerb::kDelete:
+ return "DELETE"sv;
+ case HttpVerb::kHead:
+ return "HEAD"sv;
+ case HttpVerb::kCopy:
+ return "COPY"sv;
+ case HttpVerb::kOptions:
+ return "OPTIONS"sv;
+ default:
+ return "???"sv;
+ }
+}
+
+std::string_view
+ReasonStringForHttpResultCode(int HttpCode)
+{
+ switch (HttpCode)
+ {
+ // 1xx Informational
+
+ case 100:
+ return "Continue"sv;
+ case 101:
+ return "Switching Protocols"sv;
+
+ // 2xx Success
+
+ case 200:
+ return "OK"sv;
+ case 201:
+ return "Created"sv;
+ case 202:
+ return "Accepted"sv;
+ case 204:
+ return "No Content"sv;
+ case 205:
+ return "Reset Content"sv;
+ case 206:
+ return "Partial Content"sv;
+
+ // 3xx Redirection
+
+ case 300:
+ return "Multiple Choices"sv;
+ case 301:
+ return "Moved Permanently"sv;
+ case 302:
+ return "Found"sv;
+ case 303:
+ return "See Other"sv;
+ case 304:
+ return "Not Modified"sv;
+ case 305:
+ return "Use Proxy"sv;
+ case 306:
+ return "Switch Proxy"sv;
+ case 307:
+ return "Temporary Redirect"sv;
+ case 308:
+ return "Permanent Redirect"sv;
+
+ // 4xx Client errors
+
+ case 400:
+ return "Bad Request"sv;
+ case 401:
+ return "Unauthorized"sv;
+ case 402:
+ return "Payment Required"sv;
+ case 403:
+ return "Forbidden"sv;
+ case 404:
+ return "Not Found"sv;
+ case 405:
+ return "Method Not Allowed"sv;
+ case 406:
+ return "Not Acceptable"sv;
+ case 407:
+ return "Proxy Authentication Required"sv;
+ case 408:
+ return "Request Timeout"sv;
+ case 409:
+ return "Conflict"sv;
+ case 410:
+ return "Gone"sv;
+ case 411:
+ return "Length Required"sv;
+ case 412:
+ return "Precondition Failed"sv;
+ case 413:
+ return "Payload Too Large"sv;
+ case 414:
+ return "URI Too Long"sv;
+ case 415:
+ return "Unsupported Media Type"sv;
+ case 416:
+ return "Range Not Satisifiable"sv;
+ case 417:
+ return "Expectation Failed"sv;
+ case 418:
+ return "I'm a teapot"sv;
+ case 421:
+ return "Misdirected Request"sv;
+ case 422:
+ return "Unprocessable Entity"sv;
+ case 423:
+ return "Locked"sv;
+ case 424:
+ return "Failed Dependency"sv;
+ case 425:
+ return "Too Early"sv;
+ case 426:
+ return "Upgrade Required"sv;
+ case 428:
+ return "Precondition Required"sv;
+ case 429:
+ return "Too Many Requests"sv;
+ case 431:
+ return "Request Header Fields Too Large"sv;
+
+ // 5xx Server errors
+
+ case 500:
+ return "Internal Server Error"sv;
+ case 501:
+ return "Not Implemented"sv;
+ case 502:
+ return "Bad Gateway"sv;
+ case 503:
+ return "Service Unavailable"sv;
+ case 504:
+ return "Gateway Timeout"sv;
+ case 505:
+ return "HTTP Version Not Supported"sv;
+ case 506:
+ return "Variant Also Negotiates"sv;
+ case 507:
+ return "Insufficient Storage"sv;
+ case 508:
+ return "Loop Detected"sv;
+ case 510:
+ return "Not Extended"sv;
+ case 511:
+ return "Network Authentication Required"sv;
+
+ default:
+ return "Unknown Result"sv;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+Ref<IHttpPackageHandler>
+HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
+{
+ ZEN_UNUSED(HttpServiceRequest);
+
+ return Ref<IHttpPackageHandler>();
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpServerRequest::HttpServerRequest()
+{
+}
+
+HttpServerRequest::~HttpServerRequest()
+{
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbPackage Data)
+{
+ std::vector<IoBuffer> ResponseBuffers = FormatPackageMessage(Data);
+ return WriteResponse(ResponseCode, HttpContentType::kCbPackage, ResponseBuffers);
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbObject Data)
+{
+ if (m_AcceptType == HttpContentType::kJSON)
+ {
+ ExtendableStringBuilder<1024> Sb;
+ WriteResponse(ResponseCode, HttpContentType::kJSON, Data.ToJson(Sb).ToView());
+ }
+ else
+ {
+ SharedBuffer Buf = Data.GetBuffer();
+ std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())};
+ return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers);
+ }
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, CbArray Array)
+{
+ if (m_AcceptType == HttpContentType::kJSON)
+ {
+ ExtendableStringBuilder<1024> Sb;
+ WriteResponse(ResponseCode, HttpContentType::kJSON, Array.ToJson(Sb).ToView());
+ }
+ else
+ {
+ SharedBuffer Buf = Array.GetBuffer();
+ std::array<IoBuffer, 1> Buffers{IoBufferBuilder::MakeCloneFromMemory(Buf.GetData(), Buf.GetSize())};
+ return WriteResponse(ResponseCode, HttpContentType::kCbObject, Buffers);
+ }
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString)
+{
+ return WriteResponse(ResponseCode, ContentType, std::u8string_view{(char8_t*)ResponseString.data(), ResponseString.size()});
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob)
+{
+ std::array<IoBuffer, 1> Buffers{Blob};
+ return WriteResponse(ResponseCode, ContentType, Buffers);
+}
+
+void
+HttpServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload)
+{
+ std::span<const SharedBuffer> Segments = Payload.GetSegments();
+
+ std::vector<IoBuffer> Buffers;
+
+ for (auto& Segment : Segments)
+ {
+ Buffers.push_back(Segment.AsIoBuffer());
+ }
+
+ WriteResponse(ResponseCode, ContentType, Buffers);
+}
+
+HttpServerRequest::QueryParams
+HttpServerRequest::GetQueryParams()
+{
+ QueryParams Params;
+
+ const std::string_view QStr = QueryString();
+
+ const char* QueryIt = QStr.data();
+ const char* QueryEnd = QueryIt + QStr.size();
+
+ while (QueryIt != QueryEnd)
+ {
+ if (*QueryIt == '&')
+ {
+ ++QueryIt;
+ continue;
+ }
+
+ size_t QueryLen = ptrdiff_t(QueryEnd - QueryIt);
+ const std::string_view Query{QueryIt, QueryLen};
+
+ size_t DelimIndex = Query.find('&', 0);
+
+ if (DelimIndex == std::string_view::npos)
+ {
+ DelimIndex = Query.size();
+ }
+
+ std::string_view ThisQuery{QueryIt, DelimIndex};
+
+ size_t EqIndex = ThisQuery.find('=', 0);
+
+ if (EqIndex != std::string_view::npos)
+ {
+ std::string_view Param{ThisQuery.data(), EqIndex};
+ ThisQuery.remove_prefix(EqIndex + 1);
+
+ Params.KvPairs.emplace_back(Param, ThisQuery);
+ }
+
+ QueryIt += DelimIndex;
+ }
+
+ return Params;
+}
+
+Oid
+HttpServerRequest::SessionId() const
+{
+ if (m_Flags & kHaveSessionId)
+ {
+ return m_SessionId;
+ }
+
+ m_SessionId = ParseSessionId();
+ m_Flags |= kHaveSessionId;
+ return m_SessionId;
+}
+
+uint32_t
+HttpServerRequest::RequestId() const
+{
+ if (m_Flags & kHaveRequestId)
+ {
+ return m_RequestId;
+ }
+
+ m_RequestId = ParseRequestId();
+ m_Flags |= kHaveRequestId;
+ return m_RequestId;
+}
+
+CbObject
+HttpServerRequest::ReadPayloadObject()
+{
+ if (IoBuffer Payload = ReadPayload())
+ {
+ return LoadCompactBinaryObject(std::move(Payload));
+ }
+
+ return {};
+}
+
+CbPackage
+HttpServerRequest::ReadPayloadPackage()
+{
+ if (IoBuffer Payload = ReadPayload())
+ {
+ return ParsePackageMessage(std::move(Payload));
+ }
+
+ return {};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+void
+HttpRequestRouter::AddPattern(const char* Id, const char* Regex)
+{
+ ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end());
+
+ m_PatternMap.insert({Id, Regex});
+}
+
+void
+HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs)
+{
+ ExtendableStringBuilder<128> ExpandedRegex;
+ ProcessRegexSubstitutions(Regex, ExpandedRegex);
+
+ m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex);
+}
+
+void
+HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& OutExpandedRegex)
+{
+ size_t RegexLen = strlen(Regex);
+
+ for (size_t i = 0; i < RegexLen;)
+ {
+ bool matched = false;
+
+ if (Regex[i] == '{' && ((i == 0) || (Regex[i - 1] != '\\')))
+ {
+ // Might have a pattern reference - find closing brace
+
+ for (size_t j = i + 1; j < RegexLen; ++j)
+ {
+ if (Regex[j] == '}')
+ {
+ std::string Pattern(&Regex[i + 1], j - i - 1);
+
+ if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end())
+ {
+ OutExpandedRegex.Append(it->second.c_str());
+ }
+ else
+ {
+ // Default to anything goes (or should this just be an error?)
+
+ OutExpandedRegex.Append("(.+?)");
+ }
+
+ // skip ahead
+ i = j + 1;
+
+ matched = true;
+
+ break;
+ }
+ }
+ }
+
+ if (!matched)
+ {
+ OutExpandedRegex.Append(Regex[i++]);
+ }
+ }
+}
+
+bool
+HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
+{
+ const HttpVerb Verb = Request.RequestVerb();
+
+ std::string_view Uri = Request.RelativeUri();
+ HttpRouterRequest RouterRequest(Request);
+
+ for (const auto& Handler : m_Handlers)
+ {
+ if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx))
+ {
+ Handler.Handler(RouterRequest);
+
+ return true; // Route matched
+ }
+ }
+
+ return false; // No route matched
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpRpcHandler::HttpRpcHandler()
+{
+}
+
+HttpRpcHandler::~HttpRpcHandler()
+{
+}
+
+void
+HttpRpcHandler::AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction)
+{
+ ZEN_UNUSED(RpcId, HandlerFunction);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+enum class HttpServerClass
+{
+ kHttpAsio,
+ kHttpSys,
+ kHttpNull
+};
+
+// Implemented in httpsys.cpp
+Ref<HttpServer> CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads);
+
+Ref<HttpServer>
+CreateHttpServer(std::string_view ServerClass)
+{
+ using namespace std::literals;
+
+ HttpServerClass Class = HttpServerClass::kHttpNull;
+
+#if ZEN_WITH_HTTPSYS
+ Class = HttpServerClass::kHttpSys;
+#elif 1
+ Class = HttpServerClass::kHttpAsio;
+#endif
+
+ if (ServerClass == "asio"sv)
+ {
+ Class = HttpServerClass::kHttpAsio;
+ }
+ else if (ServerClass == "httpsys"sv)
+ {
+ Class = HttpServerClass::kHttpSys;
+ }
+ else if (ServerClass == "null"sv)
+ {
+ Class = HttpServerClass::kHttpNull;
+ }
+
+ switch (Class)
+ {
+ default:
+ case HttpServerClass::kHttpAsio:
+ ZEN_INFO("using asio HTTP server implementation");
+ return Ref<HttpServer>(new HttpAsioServer());
+
+#if ZEN_WITH_HTTPSYS
+ case HttpServerClass::kHttpSys:
+ ZEN_INFO("using http.sys server implementation");
+ return Ref<HttpServer>(new HttpSysServer(std::thread::hardware_concurrency(), /* background worker threads */ 16));
+#endif
+
+ case HttpServerClass::kHttpNull:
+ ZEN_INFO("using null HTTP server implementation");
+ return Ref<HttpServer>(new HttpNullServer);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef)
+{
+ if (Request.RequestVerb() == HttpVerb::kPost)
+ {
+ if (Request.RequestContentType() == HttpContentType::kCbPackageOffer)
+ {
+ // The client is presenting us with a package attachments offer, we need
+ // to filter it down to the list of attachments we need them to send in
+ // the follow-up request
+
+ PackageHandlerRef = Service.HandlePackageRequest(Request);
+
+ if (PackageHandlerRef)
+ {
+ CbObject OfferMessage = LoadCompactBinaryObject(Request.ReadPayload());
+
+ std::vector<IoHash> OfferCids;
+
+ for (auto& CidEntry : OfferMessage["offer"])
+ {
+ if (!CidEntry.IsHash())
+ {
+ // Should yield bad request response?
+
+ ZEN_WARN("found invalid entry in offer");
+
+ continue;
+ }
+
+ OfferCids.push_back(CidEntry.AsHash());
+ }
+
+ ZEN_TRACE("request #{} -> filtering offer of {} entries", Request.RequestId(), OfferCids.size());
+
+ PackageHandlerRef->FilterOffer(OfferCids);
+
+ ZEN_TRACE("request #{} -> filtered to {} entries", Request.RequestId(), OfferCids.size());
+
+ CbObjectWriter ResponseWriter;
+ ResponseWriter.BeginArray("need");
+
+ for (const IoHash& Cid : OfferCids)
+ {
+ ResponseWriter.AddHash(Cid);
+ }
+
+ ResponseWriter.EndArray();
+
+ // Emit filter response
+ Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ return true;
+ }
+ }
+ else if (Request.RequestContentType() == HttpContentType::kCbPackage)
+ {
+ // Process chunks in package request
+
+ PackageHandlerRef = Service.HandlePackageRequest(Request);
+
+ // TODO: this should really be done in a streaming fashion, currently this emulates
+ // the intended flow from an API perspective
+
+ if (PackageHandlerRef)
+ {
+ PackageHandlerRef->OnRequestBegin();
+
+ auto CreateBuffer = [&](const IoHash& Cid, uint64_t Size) -> IoBuffer {
+ return PackageHandlerRef->CreateTarget(Cid, Size);
+ };
+
+ CbPackage Package = ParsePackageMessage(Request.ReadPayload(), CreateBuffer);
+
+ PackageHandlerRef->OnRequestComplete();
+ }
+ }
+ }
+ return false;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("http.common")
+{
+ using namespace std::literals;
+
+ SUBCASE("router")
+ {
+ HttpRequestRouter r;
+ r.AddPattern("a", "[[:alpha:]]+");
+ r.RegisterRoute(
+ "{a}",
+ [&](auto) {},
+ HttpVerb::kGet);
+
+ // struct TestHttpServerRequest : public HttpServerRequest
+ //{
+ // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {}
+ //};
+
+ // TestHttpServerRequest req{};
+ // r.HandleRequest(req);
+ }
+
+ SUBCASE("content-type")
+ {
+ for (uint8_t i = 0; i < uint8_t(HttpContentType::kCOUNT); ++i)
+ {
+ HttpContentType Ct{i};
+
+ if (Ct != HttpContentType::kUnknownContentType)
+ {
+ CHECK_EQ(Ct, ParseContentType(MapContentTypeToString(Ct)));
+ }
+ }
+ }
+}
+
+void
+http_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/httpshared.cpp b/src/zenhttp/httpshared.cpp
new file mode 100644
index 000000000..7aade56d2
--- /dev/null
+++ b/src/zenhttp/httpshared.cpp
@@ -0,0 +1,809 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpshared.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+
+#include <span>
+#include <vector>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+const std::string_view HandlePrefix(":?#:");
+
+std::vector<IoBuffer>
+FormatPackageMessage(const CbPackage& Data, int TargetProcessPid)
+{
+ return FormatPackageMessage(Data, FormatFlags::kDefault, TargetProcessPid);
+}
+CompositeBuffer
+FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid)
+{
+ return FormatPackageMessageBuffer(Data, FormatFlags::kDefault, TargetProcessPid);
+}
+
+CompositeBuffer
+FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid)
+{
+ std::vector<IoBuffer> Message = FormatPackageMessage(Data, Flags, TargetProcessPid);
+
+ std::vector<SharedBuffer> Buffers;
+
+ for (IoBuffer& Buf : Message)
+ {
+ Buffers.push_back(SharedBuffer(Buf));
+ }
+
+ return CompositeBuffer(std::move(Buffers));
+}
+
+std::vector<IoBuffer>
+FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid)
+{
+ void* TargetProcessHandle = nullptr;
+#if ZEN_PLATFORM_WINDOWS
+ std::vector<HANDLE> DuplicatedHandles;
+ auto _ = MakeGuard([&DuplicatedHandles, &TargetProcessHandle]() {
+ if (TargetProcessHandle == nullptr)
+ {
+ return;
+ }
+
+ for (HANDLE DuplicatedHandle : DuplicatedHandles)
+ {
+ HANDLE ClosingHandle;
+ if (::DuplicateHandle((HANDLE)TargetProcessHandle,
+ DuplicatedHandle,
+ GetCurrentProcess(),
+ &ClosingHandle,
+ 0,
+ FALSE,
+ DUPLICATE_CLOSE_SOURCE | DUPLICATE_SAME_ACCESS) == TRUE)
+ {
+ ::CloseHandle(ClosingHandle);
+ }
+ }
+ ::CloseHandle((HANDLE)TargetProcessHandle);
+ TargetProcessHandle = nullptr;
+ });
+
+ if (EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && TargetProcessPid != 0)
+ {
+ TargetProcessHandle = OpenProcess(PROCESS_DUP_HANDLE, FALSE, TargetProcessPid);
+ }
+#else
+ ZEN_UNUSED(TargetProcessPid);
+ void* DuplicatedHandles = nullptr;
+#endif // ZEN_PLATFORM_WINDOWS
+
+ const std::span<const CbAttachment>& Attachments = Data.GetAttachments();
+ std::vector<IoBuffer> ResponseBuffers;
+
+ ResponseBuffers.reserve(3 + Attachments.size()); // TODO: may want to use an additional fudge factor here to avoid growing since each
+ // attachment is likely to consist of several buffers
+
+ // Fixed size header
+
+ CbPackageHeader Hdr{.HeaderMagic = kCbPkgMagic, .AttachmentCount = gsl::narrow<uint32_t>(Attachments.size())};
+
+ ResponseBuffers.push_back(IoBufferBuilder::MakeCloneFromMemory(&Hdr, sizeof Hdr));
+
+ // Attachment metadata array
+
+ IoBuffer AttachmentMetadataBuffer = IoBuffer{sizeof(CbAttachmentEntry) * (Attachments.size() + /* root */ 1)};
+ CbAttachmentEntry* AttachmentInfo = reinterpret_cast<CbAttachmentEntry*>(AttachmentMetadataBuffer.MutableData());
+
+ ResponseBuffers.push_back(AttachmentMetadataBuffer); // Attachment metadata
+
+ // Root object
+
+ IoBuffer RootIoBuffer = Data.GetObject().GetBuffer().AsIoBuffer();
+ ResponseBuffers.push_back(RootIoBuffer); // Root object
+
+ *AttachmentInfo++ = {.PayloadSize = RootIoBuffer.Size(), .Flags = CbAttachmentEntry::kIsObject, .AttachmentHash = Data.GetObjectHash()};
+
+ // Attachment payloads
+
+ auto MarshalLocal = [&AttachmentInfo, &ResponseBuffers](const std::string& Path8,
+ CbAttachmentReferenceHeader& LocalRef,
+ const IoHash& AttachmentHash,
+ bool IsCompressed) {
+ IoBuffer RefBuffer(sizeof(CbAttachmentReferenceHeader) + Path8.size());
+
+ CbAttachmentReferenceHeader* RefHdr = RefBuffer.MutableData<CbAttachmentReferenceHeader>();
+ *RefHdr++ = LocalRef;
+ memcpy(RefHdr, Path8.data(), Path8.size());
+
+ *AttachmentInfo++ = {.PayloadSize = RefBuffer.GetSize(),
+ .Flags = (IsCompressed ? uint32_t(CbAttachmentEntry::kIsCompressed) : 0u) | CbAttachmentEntry::kIsLocalRef,
+ .AttachmentHash = AttachmentHash};
+
+ ResponseBuffers.push_back(std::move(RefBuffer));
+ };
+
+ tsl::robin_map<void*, std::string> FileNameMap;
+
+ auto IsLocalRef = [&FileNameMap, &DuplicatedHandles](const CompositeBuffer& AttachmentBinary,
+ bool DenyPartialLocalReferences,
+ void* TargetProcessHandle,
+ CbAttachmentReferenceHeader& LocalRef,
+ std::string& Path8) -> bool {
+ const SharedBuffer& Segment = AttachmentBinary.GetSegments().front();
+ IoBufferFileReference Ref;
+ const IoBuffer& SegmentBuffer = Segment.AsIoBuffer();
+
+ if (!SegmentBuffer.GetFileReference(Ref))
+ {
+ return false;
+ }
+
+ if (DenyPartialLocalReferences && !SegmentBuffer.IsWholeFile())
+ {
+ return false;
+ }
+
+ if (auto It = FileNameMap.find(Ref.FileHandle); It != FileNameMap.end())
+ {
+ Path8 = It->second;
+ }
+ else
+ {
+ bool UseFilePath = true;
+#if ZEN_PLATFORM_WINDOWS
+ if (TargetProcessHandle != nullptr)
+ {
+ HANDLE TargetHandle = INVALID_HANDLE_VALUE;
+ BOOL OK = ::DuplicateHandle(GetCurrentProcess(),
+ Ref.FileHandle,
+ (HANDLE)TargetProcessHandle,
+ &TargetHandle,
+ FILE_GENERIC_READ,
+ FALSE,
+ 0);
+ if (OK)
+ {
+ DuplicatedHandles.push_back(TargetHandle);
+ Path8 = fmt::format("{}{}", HandlePrefix, reinterpret_cast<uint64_t>(TargetHandle));
+ UseFilePath = false;
+ }
+ }
+#else // ZEN_PLATFORM_WINDOWS
+ ZEN_UNUSED(TargetProcessHandle);
+ // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes and to
+ // deal with acceess rights etc.
+#endif // ZEN_PLATFORM_WINDOWS
+ if (UseFilePath)
+ {
+ ExtendablePathBuilder<256> LocalRefFile;
+ LocalRefFile.Append(std::filesystem::absolute(PathFromHandle(Ref.FileHandle)));
+ Path8 = LocalRefFile.ToUtf8();
+ }
+ FileNameMap.insert_or_assign(Ref.FileHandle, Path8);
+ }
+
+ LocalRef.AbsolutePathLength = gsl::narrow<uint16_t>(Path8.size());
+ LocalRef.PayloadByteOffset = Ref.FileChunkOffset;
+ LocalRef.PayloadByteSize = Ref.FileChunkSize;
+
+ return true;
+ };
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ if (Attachment.IsNull())
+ {
+ ZEN_NOT_IMPLEMENTED("Null attachments are not supported");
+ }
+ else if (CompressedBuffer AttachmentBuffer = Attachment.AsCompressedBinary())
+ {
+ CompositeBuffer Compressed = AttachmentBuffer.GetCompressed();
+ IoHash AttachmentHash = Attachment.GetHash();
+
+ // If the data is either not backed by a file, or there are multiple
+ // fragments then we cannot marshal it by local reference. We might
+ // want/need to extend this in the future to allow multiple chunk
+ // segments to be marshaled at once
+
+ bool MarshalByLocalRef = EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (Compressed.GetSegments().size() == 1);
+ bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
+ CbAttachmentReferenceHeader LocalRef;
+ std::string Path8;
+
+ if (MarshalByLocalRef)
+ {
+ MarshalByLocalRef = IsLocalRef(Compressed, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8);
+ }
+
+ if (MarshalByLocalRef)
+ {
+ const bool IsCompressed = true;
+ bool IsHandle = false;
+#if ZEN_PLATFORM_WINDOWS
+ IsHandle = Path8.starts_with(HandlePrefix);
+#endif
+ MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed);
+ ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", Compressed.GetSize());
+ }
+ else
+ {
+ *AttachmentInfo++ = {.PayloadSize = AttachmentBuffer.GetCompressedSize(),
+ .Flags = CbAttachmentEntry::kIsCompressed,
+ .AttachmentHash = AttachmentHash};
+
+ for (const SharedBuffer& Segment : Compressed.GetSegments())
+ {
+ ResponseBuffers.push_back(Segment.AsIoBuffer());
+ }
+ }
+ }
+ else if (CbObject AttachmentObject = Attachment.AsObject())
+ {
+ IoBuffer ObjIoBuffer = AttachmentObject.GetBuffer().AsIoBuffer();
+ ResponseBuffers.push_back(ObjIoBuffer);
+
+ *AttachmentInfo++ = {.PayloadSize = ObjIoBuffer.Size(),
+ .Flags = CbAttachmentEntry::kIsObject,
+ .AttachmentHash = Attachment.GetHash()};
+ }
+ else if (CompositeBuffer AttachmentBinary = Attachment.AsCompositeBinary())
+ {
+ IoHash AttachmentHash = Attachment.GetHash();
+ bool MarshalByLocalRef =
+ EnumHasAllFlags(Flags, FormatFlags::kAllowLocalReferences) && (AttachmentBinary.GetSegments().size() == 1);
+ bool DenyPartialLocalReferences = EnumHasAllFlags(Flags, FormatFlags::kDenyPartialLocalReferences);
+
+ CbAttachmentReferenceHeader LocalRef;
+ std::string Path8;
+
+ if (MarshalByLocalRef)
+ {
+ MarshalByLocalRef = IsLocalRef(AttachmentBinary, DenyPartialLocalReferences, TargetProcessHandle, LocalRef, Path8);
+ }
+
+ if (MarshalByLocalRef)
+ {
+ const bool IsCompressed = false;
+ bool IsHandle = false;
+#if ZEN_PLATFORM_WINDOWS
+ IsHandle = Path8.starts_with(HandlePrefix);
+#endif
+ MarshalLocal(Path8, LocalRef, AttachmentHash, IsCompressed);
+ ZEN_DEBUG("Marshalled '{}' as file {} of {} bytes", Path8, IsHandle ? "handle" : "path", AttachmentBinary.GetSize());
+ }
+ else
+ {
+ *AttachmentInfo++ = {.PayloadSize = AttachmentBinary.GetSize(), .Flags = 0, .AttachmentHash = Attachment.GetHash()};
+
+ for (const SharedBuffer& Segment : AttachmentBinary.GetSegments())
+ {
+ ResponseBuffers.push_back(Segment.AsIoBuffer());
+ }
+ }
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Unknown attachment kind");
+ }
+ }
+ FileNameMap.clear();
+#if ZEN_PLATFORM_WINDOWS
+ DuplicatedHandles.clear();
+#endif // ZEN_PLATFORM_WINDOWS
+
+ return ResponseBuffers;
+}
+
+bool
+IsPackageMessage(IoBuffer Payload)
+{
+ if (!Payload)
+ {
+ return false;
+ }
+
+ BinaryReader Reader(Payload);
+
+ CbPackageHeader Hdr;
+ Reader.Read(&Hdr, sizeof Hdr);
+
+ if (Hdr.HeaderMagic != kCbPkgMagic)
+ {
+ return false;
+ }
+
+ return true;
+}
+
+CbPackage
+ParsePackageMessage(IoBuffer Payload, std::function<IoBuffer(const IoHash&, uint64_t)> CreateBuffer)
+{
+ if (!Payload)
+ {
+ return {};
+ }
+
+ BinaryReader Reader(Payload);
+
+ CbPackageHeader Hdr;
+ Reader.Read(&Hdr, sizeof Hdr);
+
+ if (Hdr.HeaderMagic != kCbPkgMagic)
+ {
+ throw std::runtime_error("invalid CbPackage header magic");
+ }
+
+ const uint32_t ChunkCount = Hdr.AttachmentCount + 1;
+
+ std::unique_ptr<CbAttachmentEntry[]> AttachmentEntries{new CbAttachmentEntry[ChunkCount]};
+
+ Reader.Read(AttachmentEntries.get(), sizeof(CbAttachmentEntry) * ChunkCount);
+
+ CbPackage Package;
+
+ std::vector<CbAttachment> Attachments;
+ Attachments.reserve(ChunkCount); // Guessing here...
+
+ tsl::robin_map<std::string, IoBuffer> PartialFileBuffers;
+
+ // TODO: Throwing before this loop completes could result in leaking handles as we might not have picked up all the handles in the
+ // message
+ for (uint32_t i = 0; i < ChunkCount; ++i)
+ {
+ const CbAttachmentEntry& Entry = AttachmentEntries[i];
+ const uint64_t AttachmentSize = Entry.PayloadSize;
+
+ const IoBuffer AttachmentBuffer(Payload, Reader.CurrentOffset(), AttachmentSize);
+ Reader.Skip(AttachmentSize);
+
+ if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ // Marshal local reference - a "pointer" to the chunk backing file
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+ const char* PathPointer = reinterpret_cast<const char*>(AttachRefHdr + 1);
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+ std::string_view PathView(PathPointer, AttachRefHdr->AbsolutePathLength);
+
+ IoBuffer FullFileBuffer;
+
+ std::filesystem::path Path(Utf8ToWide(PathView));
+ if (auto It = PartialFileBuffers.find(Path.string()); It != PartialFileBuffers.end())
+ {
+ FullFileBuffer = It->second;
+ }
+ else
+ {
+ if (PathView.starts_with(HandlePrefix))
+ {
+#if ZEN_PLATFORM_WINDOWS
+ std::string_view HandleString(PathView.substr(HandlePrefix.length()));
+ std::optional<uint64_t> HandleNumber(ParseInt<uint64_t>(HandleString));
+ if (HandleNumber.has_value())
+ {
+ HANDLE FileHandle = HANDLE(HandleNumber.value());
+ ULARGE_INTEGER liFileSize;
+ liFileSize.LowPart = ::GetFileSize(FileHandle, &liFileSize.HighPart);
+ if (liFileSize.LowPart != INVALID_FILE_SIZE)
+ {
+ FullFileBuffer = IoBuffer(IoBuffer::File, (void*)FileHandle, 0, uint64_t(liFileSize.QuadPart));
+ PartialFileBuffers.insert_or_assign(Path.string(), FullFileBuffer);
+ }
+ }
+#else // ZEN_PLATFORM_WINDOWS
+ // Not supported on Linux/Mac. Could potentially use pidfd_getfd() but that requires a fairly new Linux kernel/includes
+ // and to deal with acceess rights etc.
+ ZEN_ASSERT(false);
+#endif // ZEN_PLATFORM_WINDOWS
+ }
+ else
+ {
+ FullFileBuffer = PartialFileBuffers.insert_or_assign(Path.string(), IoBufferBuilder::MakeFromFile(Path)).first->second;
+ }
+ }
+
+ if (!FullFileBuffer)
+ {
+ // Unable to open chunk reference
+ throw std::runtime_error(fmt::format("unable to resolve chunk #{} at '{}' (offset {}, size {})",
+ i,
+ Path,
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize));
+ }
+
+ IoBuffer ChunkReference = AttachRefHdr->PayloadByteOffset == 0 && AttachRefHdr->PayloadByteSize == FullFileBuffer.GetSize()
+ ? FullFileBuffer
+ : IoBuffer(FullFileBuffer, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
+
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkReference)));
+ if (!CompBuf)
+ {
+ throw std::runtime_error(fmt::format("invalid format for chunk #{} at '{}' (offset {}, size {})",
+ i,
+ Path,
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize));
+ }
+ Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ if (i == 0)
+ {
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
+ if (!CompBuf)
+ {
+ throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for CbObject", i));
+ }
+ // First payload is always a compact binary object
+ Package.SetObject(LoadCompactBinaryObject(std::move(CompBuf)));
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported");
+ }
+ }
+ else
+ {
+ CompressedBuffer CompBuf(CompressedBuffer::FromCompressedNoValidate(IoBuffer(AttachmentBuffer)));
+ if (!CompBuf)
+ {
+ throw std::runtime_error(fmt::format("invalid format for chunk #{} expected compressed buffer for attachment", i));
+ }
+ Attachments.emplace_back(CbAttachment(std::move(CompBuf), Entry.AttachmentHash));
+ }
+ }
+ else /* not compressed */
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ if (i == 0)
+ {
+ Package.SetObject(LoadCompactBinaryObject(AttachmentBuffer));
+ }
+ else
+ {
+ ZEN_NOT_IMPLEMENTED("Object attachments are not currently supported");
+ }
+ }
+ else
+ {
+ // Make a copy of the buffer so we attachements don't reference the entire payload
+ IoBuffer AttachmentBufferCopy = CreateBuffer(Entry.AttachmentHash, AttachmentSize);
+ ZEN_ASSERT(AttachmentBufferCopy);
+ ZEN_ASSERT(AttachmentBufferCopy.Size() == AttachmentSize);
+ AttachmentBufferCopy.GetMutableView().CopyFrom(AttachmentBuffer.GetView());
+
+ CbAttachment Attachment(SharedBuffer{AttachmentBufferCopy});
+ Attachments.emplace_back(SharedBuffer{AttachmentBufferCopy});
+ }
+ }
+ }
+ PartialFileBuffers.clear();
+
+ Package.AddAttachments(Attachments);
+
+ return Package;
+}
+
+bool
+ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage)
+{
+ if (IsPackageMessage(Response))
+ {
+ OutPackage = ParsePackageMessage(Response);
+ return true;
+ }
+ return OutPackage.TryLoad(Response);
+}
+
+CbPackageReader::CbPackageReader() : m_CreateBuffer([](const IoHash&, uint64_t Size) -> IoBuffer { return IoBuffer{Size}; })
+{
+}
+
+CbPackageReader::~CbPackageReader()
+{
+}
+
+void
+CbPackageReader::SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer)
+{
+ m_CreateBuffer = CreateBuffer;
+}
+
+uint64_t
+CbPackageReader::ProcessPackageHeaderData(const void* Data, uint64_t DataBytes)
+{
+ ZEN_ASSERT(m_CurrentState != State::kReadingBuffers);
+
+ switch (m_CurrentState)
+ {
+ case State::kInitialState:
+ ZEN_ASSERT(Data == nullptr);
+ m_CurrentState = State::kReadingHeader;
+ return sizeof m_PackageHeader;
+
+ case State::kReadingHeader:
+ ZEN_ASSERT(DataBytes == sizeof m_PackageHeader);
+ memcpy(&m_PackageHeader, Data, sizeof m_PackageHeader);
+ ZEN_ASSERT(m_PackageHeader.HeaderMagic == kCbPkgMagic);
+ m_CurrentState = State::kReadingAttachmentEntries;
+ m_AttachmentEntries.resize(m_PackageHeader.AttachmentCount + 1);
+ return (m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry);
+
+ case State::kReadingAttachmentEntries:
+ ZEN_ASSERT(DataBytes == ((m_PackageHeader.AttachmentCount + 1) * sizeof(CbAttachmentEntry)));
+ memcpy(m_AttachmentEntries.data(), Data, DataBytes);
+
+ for (CbAttachmentEntry& Entry : m_AttachmentEntries)
+ {
+ // This preallocates memory for payloads but note that for the local references
+ // the caller will need to handle the payload differently (i.e it's a
+ // CbAttachmentReferenceHeader not the actual payload)
+
+ m_PayloadBuffers.push_back(IoBuffer{Entry.PayloadSize});
+ }
+
+ m_CurrentState = State::kReadingBuffers;
+ return 0;
+
+ default:
+ ZEN_ASSERT(false);
+ return 0;
+ }
+}
+
+IoBuffer
+CbPackageReader::MarshalLocalChunkReference(IoBuffer AttachmentBuffer)
+{
+ // Marshal local reference - a "pointer" to the chunk backing file
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= sizeof(CbAttachmentReferenceHeader));
+
+ const CbAttachmentReferenceHeader* AttachRefHdr = AttachmentBuffer.Data<CbAttachmentReferenceHeader>();
+ const char8_t* PathPointer = reinterpret_cast<const char8_t*>(AttachRefHdr + 1);
+
+ ZEN_ASSERT(AttachmentBuffer.Size() >= (sizeof(CbAttachmentReferenceHeader) + AttachRefHdr->AbsolutePathLength));
+
+ std::u8string_view PathView{PathPointer, AttachRefHdr->AbsolutePathLength};
+
+ std::filesystem::path Path{PathView};
+
+ IoBuffer ChunkReference = IoBufferBuilder::MakeFromFile(Path, AttachRefHdr->PayloadByteOffset, AttachRefHdr->PayloadByteSize);
+
+ if (!ChunkReference)
+ {
+ // Unable to open chunk reference
+
+ throw std::runtime_error(fmt::format("unable to resolve local reference to '{}' (offset {}, size {})",
+ PathToUtf8(Path),
+ AttachRefHdr->PayloadByteOffset,
+ AttachRefHdr->PayloadByteSize));
+ }
+
+ return ChunkReference;
+};
+
+void
+CbPackageReader::Finalize()
+{
+ if (m_AttachmentEntries.empty())
+ {
+ return;
+ }
+
+ m_Attachments.reserve(m_AttachmentEntries.size() - 1);
+
+ int CurrentAttachmentIndex = 0;
+ for (CbAttachmentEntry& Entry : m_AttachmentEntries)
+ {
+ IoBuffer AttachmentBuffer = m_PayloadBuffers[CurrentAttachmentIndex];
+
+ if (CurrentAttachmentIndex == 0)
+ {
+ // Root object
+ if (Entry.Flags & CbAttachmentEntry::kIsObject)
+ {
+ if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ m_RootObject = LoadCompactBinaryObject(MarshalLocalChunkReference(AttachmentBuffer));
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(AttachmentBuffer), RawHash, RawSize);
+ if (RawHash == Entry.AttachmentHash)
+ {
+ m_RootObject = LoadCompactBinaryObject(Compressed);
+ }
+ }
+ else
+ {
+ m_RootObject = LoadCompactBinaryObject(std::move(AttachmentBuffer));
+ }
+ }
+ else
+ {
+ throw std::runtime_error("missing or invalid root object");
+ }
+ }
+ else if (Entry.Flags & CbAttachmentEntry::kIsLocalRef)
+ {
+ IoBuffer ChunkReference = MarshalLocalChunkReference(AttachmentBuffer);
+
+ if (Entry.Flags & CbAttachmentEntry::kIsCompressed)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(ChunkReference), RawHash, RawSize);
+ if (RawHash == Entry.AttachmentHash)
+ {
+ m_Attachments.push_back(CbAttachment(Compressed, Entry.AttachmentHash));
+ }
+ }
+ else
+ {
+ CompressedBuffer Compressed =
+ CompressedBuffer::Compress(SharedBuffer(ChunkReference), OodleCompressor::NotSet, OodleCompressionLevel::None);
+ m_Attachments.push_back(CbAttachment(std::move(Compressed), Compressed.DecodeRawHash()));
+ }
+ }
+
+ ++CurrentAttachmentIndex;
+ }
+}
+
+/**
+ ______________________ _____________________________
+ \__ ___/\_ _____// _____/\__ ___/ _____/
+ | | | __)_ \_____ \ | | \_____ \
+ | | | \/ \ | | / \
+ |____| /_______ /_______ / |____| /_______ /
+ \/ \/ \/
+ */
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("CbPackage.Serialization")
+{
+ // Make a test package
+
+ CbAttachment Attach1{SharedBuffer::MakeView(MakeMemoryView("abcd"))};
+ CbAttachment Attach2{SharedBuffer::MakeView(MakeMemoryView("efgh"))};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+TEST_CASE("CbPackage.LocalRef")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ auto Path1 = TempDir.Path() / "abcd";
+ auto Path2 = TempDir.Path() / "efgh";
+
+ {
+ IoBuffer Buffer1 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("abcd"));
+ IoBuffer Buffer2 = IoBufferBuilder::MakeCloneFromMemory(MakeMemoryView("efgh"));
+
+ WriteFile(Path1, Buffer1);
+ WriteFile(Path2, Buffer2);
+ }
+
+ // Make a test package
+
+ IoBuffer FileBuffer1 = IoBufferBuilder::MakeFromFile(Path1);
+ IoBuffer FileBuffer2 = IoBufferBuilder::MakeFromFile(Path2);
+
+ CbAttachment Attach1{SharedBuffer(FileBuffer1)};
+ CbAttachment Attach2{SharedBuffer(FileBuffer2)};
+
+ CbObjectWriter Cbo;
+ Cbo.AddAttachment("abcd", Attach1);
+ Cbo.AddAttachment("efgh", Attach2);
+
+ CbPackage Pkg;
+ Pkg.AddAttachment(Attach1);
+ Pkg.AddAttachment(Attach2);
+ Pkg.SetObject(Cbo.Save());
+
+ SharedBuffer Buffer = FormatPackageMessageBuffer(Pkg, FormatFlags::kAllowLocalReferences).Flatten();
+ const uint8_t* CursorPtr = reinterpret_cast<const uint8_t*>(Buffer.GetData());
+ uint64_t RemainingBytes = Buffer.GetSize();
+
+ auto ConsumeBytes = [&](uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ void* ReturnPtr = (void*)CursorPtr;
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ return ReturnPtr;
+ };
+
+ auto CopyBytes = [&](void* TargetBuffer, uint64_t ByteCount) {
+ ZEN_ASSERT(ByteCount <= RemainingBytes);
+ memcpy(TargetBuffer, CursorPtr, ByteCount);
+ CursorPtr += ByteCount;
+ RemainingBytes -= ByteCount;
+ };
+
+ CbPackageReader Reader;
+ uint64_t InitialRead = Reader.ProcessPackageHeaderData(nullptr, 0);
+ uint64_t NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(InitialRead), InitialRead);
+ NextBytes = Reader.ProcessPackageHeaderData(ConsumeBytes(NextBytes), NextBytes);
+ auto Buffers = Reader.GetPayloadBuffers();
+
+ for (auto& PayloadBuffer : Buffers)
+ {
+ CopyBytes(PayloadBuffer.MutableData(), PayloadBuffer.GetSize());
+ }
+
+ Reader.Finalize();
+}
+
+void
+forcelink_httpshared()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenhttp/httpsys.cpp b/src/zenhttp/httpsys.cpp
new file mode 100644
index 000000000..c733d618d
--- /dev/null
+++ b/src/zenhttp/httpsys.cpp
@@ -0,0 +1,1674 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpsys.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/except.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+#include <zencore/timer.h>
+#include <zenhttp/httpshared.h>
+
+#if ZEN_WITH_HTTPSYS
+
+# include <conio.h>
+# include <mstcpip.h>
+# pragma comment(lib, "httpapi.lib")
+
+std::wstring
+UTF8_to_UTF16(const char* InPtr)
+{
+ std::wstring OutString;
+ unsigned int Codepoint;
+
+ while (*InPtr != 0)
+ {
+ unsigned char InChar = static_cast<unsigned char>(*InPtr);
+
+ if (InChar <= 0x7f)
+ Codepoint = InChar;
+ else if (InChar <= 0xbf)
+ Codepoint = (Codepoint << 6) | (InChar & 0x3f);
+ else if (InChar <= 0xdf)
+ Codepoint = InChar & 0x1f;
+ else if (InChar <= 0xef)
+ Codepoint = InChar & 0x0f;
+ else
+ Codepoint = InChar & 0x07;
+
+ ++InPtr;
+
+ if (((*InPtr & 0xc0) != 0x80) && (Codepoint <= 0x10ffff))
+ {
+ if (Codepoint > 0xffff)
+ {
+ OutString.append(1, static_cast<wchar_t>(0xd800 + (Codepoint >> 10)));
+ OutString.append(1, static_cast<wchar_t>(0xdc00 + (Codepoint & 0x03ff)));
+ }
+ else if (Codepoint < 0xd800 || Codepoint >= 0xe000)
+ {
+ OutString.append(1, static_cast<wchar_t>(Codepoint));
+ }
+ }
+ }
+
+ return OutString;
+}
+
+namespace zen {
+
+using namespace std::literals;
+
+class HttpSysServer;
+class HttpSysTransaction;
+class HttpMessageResponseRequest;
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpVerb
+TranslateHttpVerb(HTTP_VERB ReqVerb)
+{
+ switch (ReqVerb)
+ {
+ case HttpVerbOPTIONS:
+ return HttpVerb::kOptions;
+
+ case HttpVerbGET:
+ return HttpVerb::kGet;
+
+ case HttpVerbHEAD:
+ return HttpVerb::kHead;
+
+ case HttpVerbPOST:
+ return HttpVerb::kPost;
+
+ case HttpVerbPUT:
+ return HttpVerb::kPut;
+
+ case HttpVerbDELETE:
+ return HttpVerb::kDelete;
+
+ case HttpVerbCOPY:
+ return HttpVerb::kCopy;
+
+ default:
+ // TODO: invalid request?
+ return (HttpVerb)0;
+ }
+}
+
+uint64_t
+GetContentLength(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& clh = HttpRequest->Headers.KnownHeaders[HttpHeaderContentLength];
+ std::string_view cl(clh.pRawValue, clh.RawValueLength);
+ uint64_t ContentLength = 0;
+ std::from_chars(cl.data(), cl.data() + cl.size(), ContentLength);
+ return ContentLength;
+};
+
+HttpContentType
+GetContentType(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderContentType];
+ return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
+};
+
+HttpContentType
+GetAcceptType(const HTTP_REQUEST* HttpRequest)
+{
+ const HTTP_KNOWN_HEADER& CtHdr = HttpRequest->Headers.KnownHeaders[HttpHeaderAccept];
+ return ParseContentType({CtHdr.pRawValue, CtHdr.RawValueLength});
+};
+
+/**
+ * @brief Base class for any pending or active HTTP transactions
+ */
+class HttpSysRequestHandler
+{
+public:
+ explicit HttpSysRequestHandler(HttpSysTransaction& Transaction) : m_Transaction(Transaction) {}
+ virtual ~HttpSysRequestHandler() = default;
+
+ virtual void IssueRequest(std::error_code& ErrorCode) = 0;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) = 0;
+ HttpSysTransaction& Transaction() { return m_Transaction; }
+
+ HttpSysRequestHandler(const HttpSysRequestHandler&) = delete;
+ HttpSysRequestHandler& operator=(const HttpSysRequestHandler&) = delete;
+
+private:
+ HttpSysTransaction& m_Transaction;
+};
+
+/**
+ * This is the handler for the initial HTTP I/O request which will receive the headers
+ * and however much of the remaining payload might fit in the embedded request buffer.
+ *
+ * It is also used to receive any entity body data relating to the request
+ *
+ */
+struct InitialRequestHandler : public HttpSysRequestHandler
+{
+ inline HTTP_REQUEST* HttpRequest() { return (HTTP_REQUEST*)m_RequestBuffer; }
+ inline uint32_t RequestBufferSize() const { return sizeof m_RequestBuffer; }
+ inline bool IsInitialRequest() const { return m_IsInitialRequest; }
+
+ InitialRequestHandler(HttpSysTransaction& InRequest);
+ ~InitialRequestHandler();
+
+ virtual void IssueRequest(std::error_code& ErrorCode) override final;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+
+ bool m_IsInitialRequest = true;
+ uint64_t m_CurrentPayloadOffset = 0;
+ uint64_t m_ContentLength = ~uint64_t(0);
+ IoBuffer m_PayloadBuffer;
+ UCHAR m_RequestBuffer[4096 + sizeof(HTTP_REQUEST)];
+};
+
+/**
+ * This is the class which request handlers use to interact with the server instance
+ */
+
+class HttpSysServerRequest : public HttpServerRequest
+{
+public:
+ HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer);
+ ~HttpSysServerRequest() = default;
+
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
+
+ virtual IoBuffer ReadPayload() override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override;
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override;
+ virtual bool TryGetRanges(HttpRanges& Ranges) override;
+
+ using HttpServerRequest::WriteResponse;
+
+ HttpSysServerRequest(const HttpSysServerRequest&) = delete;
+ HttpSysServerRequest& operator=(const HttpSysServerRequest&) = delete;
+
+ HttpSysTransaction& m_HttpTx;
+ HttpSysRequestHandler* m_NextCompletionHandler = nullptr;
+ IoBuffer m_PayloadBuffer;
+ ExtendableStringBuilder<128> m_UriUtf8;
+ ExtendableStringBuilder<128> m_QueryStringUtf8;
+};
+
+/** HTTP transaction
+
+ There will be an instance of this per pending and in-flight HTTP transaction
+
+ */
+class HttpSysTransaction final
+{
+public:
+ HttpSysTransaction(HttpSysServer& Server);
+ virtual ~HttpSysTransaction();
+
+ enum class Status
+ {
+ kDone,
+ kRequestPending
+ };
+
+ Status HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred);
+
+ static void __stdcall IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
+ PVOID pContext /* HttpSysServer */,
+ PVOID pOverlapped,
+ ULONG IoResult,
+ ULONG_PTR NumberOfBytesTransferred,
+ PTP_IO Io);
+
+ void IssueInitialRequest(std::error_code& ErrorCode);
+ bool IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler);
+
+ PTP_IO Iocp();
+ HANDLE RequestQueueHandle();
+ inline OVERLAPPED* Overlapped() { return &m_HttpOverlapped; }
+ inline HttpSysServer& Server() { return m_HttpServer; }
+ inline HTTP_REQUEST* HttpRequest() { return m_InitialHttpHandler.HttpRequest(); }
+
+ HttpSysServerRequest& InvokeRequestHandler(HttpService& Service, IoBuffer Payload);
+
+ HttpSysServerRequest& ServerRequest() { return m_HandlerRequest.value(); }
+
+private:
+ OVERLAPPED m_HttpOverlapped{};
+ HttpSysServer& m_HttpServer;
+
+ // Tracks which handler is due to handle the next I/O completion event
+ HttpSysRequestHandler* m_CompletionHandler = nullptr;
+ RwLock m_CompletionMutex;
+ InitialRequestHandler m_InitialHttpHandler{*this};
+ std::optional<HttpSysServerRequest> m_HandlerRequest;
+ Ref<IHttpPackageHandler> m_PackageHandler;
+};
+
+/**
+ * @brief HTTP request response I/O request handler
+ *
+ * Asynchronously streams out a response to an HTTP request via compound
+ * responses from memory or directly from file
+ */
+
+class HttpMessageResponseRequest : public HttpSysRequestHandler
+{
+public:
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ const void* Payload,
+ size_t PayloadSize);
+ HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ std::span<IoBuffer> Blobs);
+ ~HttpMessageResponseRequest();
+
+ virtual void IssueRequest(std::error_code& ErrorCode) override final;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+ void SuppressResponseBody(); // typically used for HEAD requests
+
+private:
+ std::vector<HTTP_DATA_CHUNK> m_HttpDataChunks;
+ uint64_t m_TotalDataSize = 0; // Sum of all chunk sizes
+ uint16_t m_ResponseCode = 0;
+ uint32_t m_NextDataChunkOffset = 0; // Cursor used for very large chunk lists
+ uint32_t m_RemainingChunkCount = 0; // Backlog for multi-call sends
+ bool m_IsInitialResponse = true;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ std::vector<IoBuffer> m_DataBuffers;
+
+ void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> Blobs);
+};
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode)
+: HttpSysRequestHandler(InRequest)
+{
+ std::array<IoBuffer, 0> EmptyBufferList;
+
+ InitializeForPayload(ResponseCode, EmptyBufferList);
+}
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest, uint16_t ResponseCode, std::string_view Message)
+: HttpSysRequestHandler(InRequest)
+, m_ContentType(HttpContentType::kText)
+{
+ IoBuffer MessageBuffer(IoBuffer::Wrap, Message.data(), Message.size());
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ InitializeForPayload(ResponseCode, SingleBufferList);
+}
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ const void* Payload,
+ size_t PayloadSize)
+: HttpSysRequestHandler(InRequest)
+, m_ContentType(ContentType)
+{
+ IoBuffer MessageBuffer(IoBuffer::Wrap, Payload, PayloadSize);
+ std::array<IoBuffer, 1> SingleBufferList({MessageBuffer});
+
+ InitializeForPayload(ResponseCode, SingleBufferList);
+}
+
+HttpMessageResponseRequest::HttpMessageResponseRequest(HttpSysTransaction& InRequest,
+ uint16_t ResponseCode,
+ HttpContentType ContentType,
+ std::span<IoBuffer> BlobList)
+: HttpSysRequestHandler(InRequest)
+, m_ContentType(ContentType)
+{
+ InitializeForPayload(ResponseCode, BlobList);
+}
+
+HttpMessageResponseRequest::~HttpMessageResponseRequest()
+{
+}
+
+void
+HttpMessageResponseRequest::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList)
+{
+ const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size());
+
+ m_HttpDataChunks.reserve(ChunkCount);
+ m_DataBuffers.reserve(ChunkCount);
+
+ for (IoBuffer& Buffer : BlobList)
+ {
+ m_DataBuffers.emplace_back(std::move(Buffer)).MakeOwned();
+ }
+
+ // Initialize the full array up front
+
+ uint64_t LocalDataSize = 0;
+
+ for (IoBuffer& Buffer : m_DataBuffers)
+ {
+ uint64_t BufferDataSize = Buffer.Size();
+
+ ZEN_ASSERT(BufferDataSize);
+
+ LocalDataSize += BufferDataSize;
+
+ IoBufferFileReference FileRef;
+ if (Buffer.GetFileReference(/* out */ FileRef))
+ {
+ // Use direct file transfer
+
+ m_HttpDataChunks.push_back({});
+ auto& Chunk = m_HttpDataChunks.back();
+
+ Chunk.DataChunkType = HttpDataChunkFromFileHandle;
+ Chunk.FromFileHandle.FileHandle = FileRef.FileHandle;
+ Chunk.FromFileHandle.ByteRange.StartingOffset.QuadPart = FileRef.FileChunkOffset;
+ Chunk.FromFileHandle.ByteRange.Length.QuadPart = BufferDataSize;
+ }
+ else
+ {
+ // Send from memory, need to make sure we chunk the buffer up since
+ // the underlying data structure only accepts 32-bit chunk sizes for
+ // memory chunks. When this happens the vector will be reallocated,
+ // which is fine since this will be a pretty rare case and sending
+ // the data is going to take a lot longer than a memory allocation :)
+
+ const uint8_t* WriteCursor = reinterpret_cast<const uint8_t*>(Buffer.Data());
+
+ while (BufferDataSize)
+ {
+ const ULONG ThisChunkSize = gsl::narrow<ULONG>(zen::Min(1 * 1024 * 1024 * 1024, BufferDataSize));
+
+ m_HttpDataChunks.push_back({});
+ auto& Chunk = m_HttpDataChunks.back();
+
+ Chunk.DataChunkType = HttpDataChunkFromMemory;
+ Chunk.FromMemory.pBuffer = (void*)WriteCursor;
+ Chunk.FromMemory.BufferLength = ThisChunkSize;
+
+ BufferDataSize -= ThisChunkSize;
+ WriteCursor += ThisChunkSize;
+ }
+ }
+ }
+
+ m_RemainingChunkCount = gsl::narrow<uint32_t>(m_HttpDataChunks.size());
+ m_TotalDataSize = LocalDataSize;
+
+ if (m_TotalDataSize == 0 && ResponseCode == 200)
+ {
+ // Some HTTP clients really don't like empty responses unless a 204 response is sent
+ m_ResponseCode = uint16_t(HttpResponseCode::NoContent);
+ }
+ else
+ {
+ m_ResponseCode = ResponseCode;
+ }
+}
+
+void
+HttpMessageResponseRequest::SuppressResponseBody()
+{
+ m_RemainingChunkCount = 0;
+ m_HttpDataChunks.clear();
+ m_DataBuffers.clear();
+}
+
+HttpSysRequestHandler*
+HttpMessageResponseRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ ZEN_UNUSED(NumberOfBytesTransferred);
+
+ if (IoResult != NO_ERROR)
+ {
+ ZEN_WARN("response aborted due to error: '{}'", GetSystemErrorAsString(IoResult));
+
+ // if one transmit failed there's really no need to go on
+ return nullptr;
+ }
+
+ if (m_RemainingChunkCount == 0)
+ {
+ return nullptr; // All done
+ }
+
+ return this;
+}
+
+void
+HttpMessageResponseRequest::IssueRequest(std::error_code& ErrorCode)
+{
+ HttpSysTransaction& Tx = Transaction();
+ HTTP_REQUEST* const HttpReq = Tx.HttpRequest();
+ PTP_IO const Iocp = Tx.Iocp();
+
+ StartThreadpoolIo(Iocp);
+
+ // Split payload into batches to play well with the underlying API
+
+ const int MaxChunksPerCall = 9999;
+
+ const int ThisRequestChunkCount = std::min<int>(m_RemainingChunkCount, MaxChunksPerCall);
+ const int ThisRequestChunkOffset = m_NextDataChunkOffset;
+
+ m_RemainingChunkCount -= ThisRequestChunkCount;
+ m_NextDataChunkOffset += ThisRequestChunkCount;
+
+ /* Should this code also use HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA?
+
+ From the docs:
+
+ This flag enables buffering of data in the kernel on a per-response basis. It should
+ be used by an application doing synchronous I/O, or by a an application doing
+ asynchronous I/O with no more than one send outstanding at a time.
+
+ Applications using asynchronous I/O which may have more than one send outstanding at
+ a time should not use this flag.
+
+ When this flag is set, it should be used consistently in calls to the
+ HttpSendHttpResponse function as well.
+ */
+
+ ULONG SendFlags = HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA;
+
+ if (m_RemainingChunkCount)
+ {
+ // We need to make more calls to send the full amount of data
+ SendFlags |= HTTP_SEND_RESPONSE_FLAG_MORE_DATA;
+ }
+
+ ULONG SendResult = 0;
+
+ if (m_IsInitialResponse)
+ {
+ // Populate response structure
+
+ HTTP_RESPONSE HttpResponse = {};
+
+ HttpResponse.EntityChunkCount = USHORT(ThisRequestChunkCount);
+ HttpResponse.pEntityChunks = m_HttpDataChunks.data() + ThisRequestChunkOffset;
+
+ // Server header
+ //
+ // By default this will also add a suffix " Microsoft-HTTPAPI/2.0" to this header
+ //
+ // This is controlled via a registry key 'DisableServerHeader', at:
+ //
+ // Computer\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\HTTP\Parameters
+ //
+ // Set DisableServerHeader to 1 to disable suffix, or 2 to disable the header altogether
+ // (only the latter appears to do anything in my testing, on Windows 10).
+ //
+ // (reference https://docs.microsoft.com/en-us/archive/blogs/dsnotes/wswcf-remove-server-header)
+ //
+
+ PHTTP_KNOWN_HEADER ServerHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderServer];
+ ServerHeader->pRawValue = "Zen";
+ ServerHeader->RawValueLength = (USHORT)3;
+
+ // Content-length header
+
+ char ContentLengthString[32];
+ _ui64toa_s(m_TotalDataSize, ContentLengthString, sizeof ContentLengthString, 10);
+
+ PHTTP_KNOWN_HEADER ContentLengthHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentLength];
+ ContentLengthHeader->pRawValue = ContentLengthString;
+ ContentLengthHeader->RawValueLength = (USHORT)strlen(ContentLengthString);
+
+ // Content-type header
+
+ PHTTP_KNOWN_HEADER ContentTypeHeader = &HttpResponse.Headers.KnownHeaders[HttpHeaderContentType];
+
+ std::string_view ContentTypeString = MapContentTypeToString(m_ContentType);
+
+ ContentTypeHeader->pRawValue = ContentTypeString.data();
+ ContentTypeHeader->RawValueLength = (USHORT)ContentTypeString.size();
+
+ std::string_view ReasonString = ReasonStringForHttpResultCode(m_ResponseCode);
+
+ HttpResponse.StatusCode = m_ResponseCode;
+ HttpResponse.pReason = ReasonString.data();
+ HttpResponse.ReasonLength = (USHORT)ReasonString.size();
+
+ // Cache policy
+
+ HTTP_CACHE_POLICY CachePolicy;
+
+ CachePolicy.Policy = HttpCachePolicyNocache; // HttpCachePolicyUserInvalidates;
+ CachePolicy.SecondsToLive = 0;
+
+ // Initial response API call
+
+ SendResult = HttpSendHttpResponse(Tx.RequestQueueHandle(),
+ HttpReq->RequestId,
+ SendFlags,
+ &HttpResponse,
+ &CachePolicy,
+ NULL,
+ NULL,
+ 0,
+ Tx.Overlapped(),
+ NULL);
+
+ m_IsInitialResponse = false;
+ }
+ else
+ {
+ // Subsequent response API calls
+
+ SendResult = HttpSendResponseEntityBody(Tx.RequestQueueHandle(),
+ HttpReq->RequestId,
+ SendFlags,
+ (USHORT)ThisRequestChunkCount, // EntityChunkCount
+ &m_HttpDataChunks[ThisRequestChunkOffset], // EntityChunks
+ NULL, // BytesSent
+ NULL, // Reserved1
+ 0, // Reserved2
+ Tx.Overlapped(), // Overlapped
+ NULL // LogData
+ );
+ }
+
+ if (SendResult == NO_ERROR)
+ {
+ // Synchronous completion, but the completion event will still be posted to IOCP
+
+ ErrorCode.clear();
+ }
+ else if (SendResult == ERROR_IO_PENDING)
+ {
+ // Asynchronous completion, a completion notification will be posted to IOCP
+
+ ErrorCode.clear();
+ }
+ else
+ {
+ // An error occurred, no completion will be posted to IOCP
+
+ CancelThreadpoolIo(Iocp);
+
+ ZEN_WARN("failed to send HTTP response (error: '{}'), request URL: '{}', request id: {}",
+ GetSystemErrorAsString(SendResult),
+ HttpReq->pRawUrl,
+ HttpReq->RequestId);
+
+ ErrorCode = MakeErrorCode(SendResult);
+ }
+}
+
+/** HTTP completion handler for async work
+
+ This is used to allow work to be taken off the request handler threads
+ and to support posting responses asynchronously.
+ */
+
+class HttpAsyncWorkRequest : public HttpSysRequestHandler
+{
+public:
+ HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response);
+ ~HttpAsyncWorkRequest();
+
+ virtual void IssueRequest(std::error_code& ErrorCode) override final;
+ virtual HttpSysRequestHandler* HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred) override;
+
+private:
+ struct AsyncWorkItem : public IWork
+ {
+ virtual void Execute() override;
+
+ AsyncWorkItem(HttpSysTransaction& InTx, std::function<void(HttpServerRequest&)>&& InHandler)
+ : Tx(InTx)
+ , Handler(std::move(InHandler))
+ {
+ }
+
+ HttpSysTransaction& Tx;
+ std::function<void(HttpServerRequest&)> Handler;
+ };
+
+ Ref<AsyncWorkItem> m_WorkItem;
+};
+
+HttpAsyncWorkRequest::HttpAsyncWorkRequest(HttpSysTransaction& Tx, std::function<void(HttpServerRequest&)>&& Response)
+: HttpSysRequestHandler(Tx)
+{
+ m_WorkItem = new AsyncWorkItem(Tx, std::move(Response));
+}
+
+HttpAsyncWorkRequest::~HttpAsyncWorkRequest()
+{
+}
+
+void
+HttpAsyncWorkRequest::IssueRequest(std::error_code& ErrorCode)
+{
+ ErrorCode.clear();
+
+ Transaction().Server().WorkPool().ScheduleWork(m_WorkItem);
+}
+
+HttpSysRequestHandler*
+HttpAsyncWorkRequest::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // This ought to not be called since there should be no outstanding I/O request
+ // when this completion handler is active
+
+ ZEN_UNUSED(IoResult, NumberOfBytesTransferred);
+
+ ZEN_WARN("Unexpected I/O completion during async work! IoResult: {}, NumberOfBytesTransferred: {}", IoResult, NumberOfBytesTransferred);
+
+ return this;
+}
+
+void
+HttpAsyncWorkRequest::AsyncWorkItem::Execute()
+{
+ try
+ {
+ HttpSysServerRequest& ThisRequest = Tx.ServerRequest();
+
+ ThisRequest.m_NextCompletionHandler = nullptr;
+
+ Handler(ThisRequest);
+
+ // TODO: should Handler be destroyed at this point to ensure there
+ // are no outstanding references into state which could be
+ // deleted asynchronously as a result of issuing the response?
+
+ if (HttpSysRequestHandler* NextHandler = ThisRequest.m_NextCompletionHandler)
+ {
+ return (void)Tx.IssueNextRequest(NextHandler);
+ }
+ else if (!ThisRequest.IsHandled())
+ {
+ return (void)Tx.IssueNextRequest(new HttpMessageResponseRequest(Tx, 404, "Not found"sv));
+ }
+ else
+ {
+ // "Handled" but no request handler? Shouldn't ever happen
+ return (void)Tx.IssueNextRequest(
+ new HttpMessageResponseRequest(Tx, 500, "Response generated but no request handler scheduled"sv));
+ }
+ }
+ catch (std::exception& Ex)
+ {
+ return (void)Tx.IssueNextRequest(
+ new HttpMessageResponseRequest(Tx, 500, fmt::format("Exception thrown in async work: '{}'", Ex.what())));
+ }
+}
+
+/**
+ _________
+ / _____/ ______________ __ ___________
+ \_____ \_/ __ \_ __ \ \/ // __ \_ __ \
+ / \ ___/| | \/\ /\ ___/| | \/
+ /_______ /\___ >__| \_/ \___ >__|
+ \/ \/ \/
+*/
+
+HttpSysServer::HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount)
+: m_Log(logging::Get("http"))
+, m_RequestLog(logging::Get("http_requests"))
+, m_ThreadPool(ThreadCount)
+, m_AsyncWorkPool(AsyncWorkThreadCount)
+{
+ ULONG Result = HttpInitialize(HTTPAPI_VERSION_2, HTTP_INITIALIZE_SERVER, nullptr);
+
+ if (Result != NO_ERROR)
+ {
+ return;
+ }
+
+ m_IsHttpInitialized = true;
+ m_IsOk = true;
+
+ ZEN_INFO("http.sys server started, using {} I/O threads and {} async worker threads", ThreadCount, AsyncWorkThreadCount);
+}
+
+HttpSysServer::~HttpSysServer()
+{
+ if (m_IsHttpInitialized)
+ {
+ Cleanup();
+
+ HttpTerminate(HTTP_INITIALIZE_SERVER, nullptr);
+ }
+}
+
+int
+HttpSysServer::InitializeServer(int BasePort)
+{
+ using namespace std::literals;
+
+ WideStringBuilder<64> WildcardUrlPath;
+ WildcardUrlPath << u8"http://*:"sv << int64_t(BasePort) << u8"/"sv;
+
+ m_IsOk = false;
+
+ ULONG Result = HttpCreateServerSession(HTTPAPI_VERSION_2, &m_HttpSessionId, 0);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to create server session for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+
+ return BasePort;
+ }
+
+ Result = HttpCreateUrlGroup(m_HttpSessionId, &m_HttpUrlGroupId, 0);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to create URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+
+ return BasePort;
+ }
+
+ int EffectivePort = BasePort;
+
+ Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ // Sharing violation implies the port is being used by another process
+ for (int PortOffset = 1; (Result == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ {
+ EffectivePort = BasePort + (PortOffset * 100);
+ WildcardUrlPath.Reset();
+ WildcardUrlPath << u8"http://*:"sv << int64_t(EffectivePort) << u8"/"sv;
+
+ Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, WildcardUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+ }
+
+ m_BaseUris.clear();
+ if (Result == NO_ERROR)
+ {
+ m_BaseUris.push_back(WildcardUrlPath.c_str());
+ }
+ else if (Result == ERROR_ACCESS_DENIED)
+ {
+ // If we can't register the wildcard path, we fall back to local paths
+ // This local paths allow requests originating locally to function, but will not allow
+ // remote origin requests to function. This can be remedied by using netsh
+ // during an install process to grant permissions to route public access to the appropriate
+ // port for the current user. eg:
+ // netsh http add urlacl url=http://*:1337/ user=<some_user>
+
+ ZEN_WARN("Unable to register handler using '{}' - falling back to local-only", WideToUtf8(WildcardUrlPath));
+
+ const std::u8string_view Hosts[] = {u8"[::1]"sv, u8"localhost"sv, u8"127.0.0.1"sv};
+
+ ULONG InternalResult = ERROR_SHARING_VIOLATION;
+ for (int PortOffset = 0; (InternalResult == ERROR_SHARING_VIOLATION) && (PortOffset < 10); ++PortOffset)
+ {
+ EffectivePort = BasePort + (PortOffset * 100);
+
+ for (const std::u8string_view Host : Hosts)
+ {
+ WideStringBuilder<64> LocalUrlPath;
+ LocalUrlPath << u8"http://"sv << Host << u8":"sv << int64_t(EffectivePort) << u8"/"sv;
+
+ InternalResult = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, LocalUrlPath.c_str(), HTTP_URL_CONTEXT(0), 0);
+
+ if (InternalResult == NO_ERROR)
+ {
+ ZEN_INFO("Registered local handler '{}'", WideToUtf8(LocalUrlPath));
+
+ m_BaseUris.push_back(LocalUrlPath.c_str());
+ }
+ else
+ {
+ break;
+ }
+ }
+ }
+ }
+
+ if (m_BaseUris.empty())
+ {
+ ZEN_ERROR("Failed to add base URL to URL group for '{}': {:#x}", WideToUtf8(WildcardUrlPath), Result);
+
+ return BasePort;
+ }
+
+ HTTP_BINDING_INFO HttpBindingInfo = {{0}, 0};
+
+ Result = HttpCreateRequestQueue(HTTPAPI_VERSION_2,
+ /* Name */ nullptr,
+ /* SecurityAttributes */ nullptr,
+ /* Flags */ 0,
+ &m_RequestQueueHandle);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to create request queue for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+
+ return EffectivePort;
+ }
+
+ HttpBindingInfo.Flags.Present = 1;
+ HttpBindingInfo.RequestQueueHandle = m_RequestQueueHandle;
+
+ Result = HttpSetUrlGroupProperty(m_HttpUrlGroupId, HttpServerBindingProperty, &HttpBindingInfo, sizeof(HttpBindingInfo));
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("Failed to set server binding property for '{}': {:#x}", WideToUtf8(m_BaseUris.front()), Result);
+
+ return EffectivePort;
+ }
+
+ // Create I/O completion port
+
+ std::error_code ErrorCode;
+ m_ThreadPool.CreateIocp(m_RequestQueueHandle, HttpSysTransaction::IoCompletionCallback, /* Context */ this, /* out */ ErrorCode);
+
+ if (ErrorCode)
+ {
+ ZEN_ERROR("Failed to create IOCP for '{}': {}", WideToUtf8(m_BaseUris.front()), ErrorCode.message());
+ }
+ else
+ {
+ m_IsOk = true;
+
+ ZEN_INFO("Started http.sys server at '{}'", WideToUtf8(m_BaseUris.front()));
+ }
+
+ return EffectivePort;
+}
+
+void
+HttpSysServer::Cleanup()
+{
+ ++m_IsShuttingDown;
+
+ if (m_RequestQueueHandle)
+ {
+ HttpCloseRequestQueue(m_RequestQueueHandle);
+ m_RequestQueueHandle = nullptr;
+ }
+
+ if (m_HttpUrlGroupId)
+ {
+ HttpCloseUrlGroup(m_HttpUrlGroupId);
+ m_HttpUrlGroupId = 0;
+ }
+
+ if (m_HttpSessionId)
+ {
+ HttpCloseServerSession(m_HttpSessionId);
+ m_HttpSessionId = 0;
+ }
+}
+
+void
+HttpSysServer::StartServer()
+{
+ const int InitialRequestCount = 32;
+
+ for (int i = 0; i < InitialRequestCount; ++i)
+ {
+ IssueNewRequestMaybe();
+ }
+}
+
+void
+HttpSysServer::Run(bool IsInteractive)
+{
+ if (IsInteractive)
+ {
+ zen::logging::ConsoleLog().info("Zen Server running. Press ESC or Q to quit");
+ }
+
+ do
+ {
+ // int WaitTimeout = -1;
+ int WaitTimeout = 100;
+
+ if (IsInteractive)
+ {
+ WaitTimeout = 1000;
+
+ if (_kbhit() != 0)
+ {
+ char c = (char)_getch();
+
+ if (c == 27 || c == 'Q' || c == 'q')
+ {
+ RequestApplicationExit(0);
+ }
+ }
+ }
+
+ m_ShutdownEvent.Wait(WaitTimeout);
+ UpdateLofreqTimerValue();
+ } while (!IsApplicationExitRequested());
+}
+
+void
+HttpSysServer::OnHandlingRequest()
+{
+ if (--m_PendingRequests > m_MinPendingRequests)
+ {
+ // We have more than the minimum number of requests pending, just let someone else
+ // enqueue new requests
+ return;
+ }
+
+ IssueNewRequestMaybe();
+}
+
+void
+HttpSysServer::IssueNewRequestMaybe()
+{
+ if (m_IsShuttingDown.load(std::memory_order::acquire))
+ {
+ return;
+ }
+
+ if (m_PendingRequests.load(std::memory_order::relaxed) >= m_MaxPendingRequests)
+ {
+ return;
+ }
+
+ std::unique_ptr<HttpSysTransaction> Request = std::make_unique<HttpSysTransaction>(*this);
+
+ std::error_code ErrorCode;
+ Request->IssueInitialRequest(ErrorCode);
+
+ if (ErrorCode)
+ {
+ // No request was actually issued. What is the appropriate response?
+
+ return;
+ }
+
+ // This may end up exceeding the MaxPendingRequests limit, but it's not
+ // really a problem. I'm doing it this way mostly to avoid dealing with
+ // exceptions here
+ ++m_PendingRequests;
+
+ Request.release();
+}
+
+void
+HttpSysServer::RegisterService(const char* UrlPath, HttpService& Service)
+{
+ if (UrlPath[0] == '/')
+ {
+ ++UrlPath;
+ }
+
+ const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
+ Service.SetUriPrefixLength(PathUtf16.size() + 1 /* leading slash */);
+
+ // Convert to wide string
+
+ for (const std::wstring& BaseUri : m_BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
+
+ ULONG Result = HttpAddUrlToUrlGroup(m_HttpUrlGroupId, Url16.c_str(), HTTP_URL_CONTEXT(&Service), 0 /* Reserved */);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpAddUrlToUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+
+ return;
+ }
+ }
+}
+
+void
+HttpSysServer::UnregisterService(const char* UrlPath, HttpService& Service)
+{
+ ZEN_UNUSED(Service);
+
+ if (UrlPath[0] == '/')
+ {
+ ++UrlPath;
+ }
+
+ const std::wstring PathUtf16 = UTF8_to_UTF16(UrlPath);
+
+ // Convert to wide string
+
+ for (const std::wstring& BaseUri : m_BaseUris)
+ {
+ std::wstring Url16 = BaseUri + PathUtf16;
+
+ ULONG Result = HttpRemoveUrlFromUrlGroup(m_HttpUrlGroupId, Url16.c_str(), 0);
+
+ if (Result != NO_ERROR)
+ {
+ ZEN_ERROR("HttpRemoveUrlFromUrlGroup failed with result: '{}'", GetSystemErrorAsString(Result));
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpSysTransaction::HttpSysTransaction(HttpSysServer& Server) : m_HttpServer(Server), m_CompletionHandler(&m_InitialHttpHandler)
+{
+}
+
+HttpSysTransaction::~HttpSysTransaction()
+{
+}
+
+PTP_IO
+HttpSysTransaction::Iocp()
+{
+ return m_HttpServer.m_ThreadPool.Iocp();
+}
+
+HANDLE
+HttpSysTransaction::RequestQueueHandle()
+{
+ return m_HttpServer.m_RequestQueueHandle;
+}
+
+void
+HttpSysTransaction::IssueInitialRequest(std::error_code& ErrorCode)
+{
+ m_InitialHttpHandler.IssueRequest(ErrorCode);
+}
+
+void
+HttpSysTransaction::IoCompletionCallback(PTP_CALLBACK_INSTANCE Instance,
+ PVOID pContext /* HttpSysServer */,
+ PVOID pOverlapped,
+ ULONG IoResult,
+ ULONG_PTR NumberOfBytesTransferred,
+ PTP_IO Io)
+{
+ UNREFERENCED_PARAMETER(Io);
+ UNREFERENCED_PARAMETER(Instance);
+ UNREFERENCED_PARAMETER(pContext);
+
+ // Note that for a given transaction we may be in this completion function on more
+ // than one thread at any given moment. This means we need to be careful about what
+ // happens in here
+
+ HttpSysTransaction* Transaction = CONTAINING_RECORD(pOverlapped, HttpSysTransaction, m_HttpOverlapped);
+
+ if (Transaction->HandleCompletion(IoResult, NumberOfBytesTransferred) == HttpSysTransaction::Status::kDone)
+ {
+ delete Transaction;
+ }
+}
+
+bool
+HttpSysTransaction::IssueNextRequest(HttpSysRequestHandler* NewCompletionHandler)
+{
+ HttpSysRequestHandler* CurrentHandler = m_CompletionHandler;
+ m_CompletionHandler = NewCompletionHandler;
+
+ auto _ = MakeGuard([this, CurrentHandler] {
+ if ((CurrentHandler != &m_InitialHttpHandler) && (CurrentHandler != m_CompletionHandler))
+ {
+ delete CurrentHandler;
+ }
+ });
+
+ if (NewCompletionHandler == nullptr)
+ {
+ return false;
+ }
+
+ try
+ {
+ std::error_code ErrorCode;
+ m_CompletionHandler->IssueRequest(ErrorCode);
+
+ if (!ErrorCode)
+ {
+ return true;
+ }
+
+ ZEN_WARN("IssueRequest() failed: '{}'", ErrorCode.message());
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_ERROR("exception caught in IssueNextRequest(): '{}'", Ex.what());
+ }
+
+ // something went wrong, no request is pending
+ m_CompletionHandler = nullptr;
+
+ return false;
+}
+
+HttpSysTransaction::Status
+HttpSysTransaction::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ // We use this to ensure sequential execution of completion handlers
+ // for any given transaction. It also ensures all member variables are
+ // in a consistent state for the current thread
+
+ RwLock::ExclusiveLockScope _(m_CompletionMutex);
+
+ bool IsRequestPending = false;
+
+ if (HttpSysRequestHandler* CurrentHandler = m_CompletionHandler)
+ {
+ if ((CurrentHandler == &m_InitialHttpHandler) && m_InitialHttpHandler.IsInitialRequest())
+ {
+ // Ensure we have a sufficient number of pending requests outstanding
+ m_HttpServer.OnHandlingRequest();
+ }
+
+ auto NewCompletionHandler = CurrentHandler->HandleCompletion(IoResult, NumberOfBytesTransferred);
+
+ IsRequestPending = IssueNextRequest(NewCompletionHandler);
+ }
+
+ // Ensure new requests are enqueued as necessary
+ m_HttpServer.IssueNewRequestMaybe();
+
+ if (IsRequestPending)
+ {
+ // There is another request pending on this transaction, so it needs to remain valid
+ return Status::kRequestPending;
+ }
+
+ if (m_HttpServer.m_IsRequestLoggingEnabled)
+ {
+ if (m_HandlerRequest.has_value())
+ {
+ m_HttpServer.m_RequestLog.info("{} {}", ToString(m_HandlerRequest->RequestVerb()), m_HandlerRequest->RelativeUri());
+ }
+ }
+
+ // Transaction done, caller should clean up (delete) this instance
+ return Status::kDone;
+}
+
+HttpSysServerRequest&
+HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
+{
+ HttpSysServerRequest& ThisRequest = m_HandlerRequest.emplace(*this, Service, Payload);
+
+ // Default request handling
+
+ if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ {
+ Service.HandleRequest(ThisRequest);
+ }
+
+ return ThisRequest;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpSysServerRequest::HttpSysServerRequest(HttpSysTransaction& Tx, HttpService& Service, IoBuffer PayloadBuffer)
+: m_HttpTx(Tx)
+, m_PayloadBuffer(std::move(PayloadBuffer))
+{
+ const HTTP_REQUEST* HttpRequestPtr = Tx.HttpRequest();
+
+ const int PrefixLength = Service.UriPrefixLength();
+ const int AbsPathLength = HttpRequestPtr->CookedUrl.AbsPathLength / sizeof(wchar_t);
+
+ HttpContentType AcceptContentType = HttpContentType::kUnknownContentType;
+
+ if (AbsPathLength >= PrefixLength)
+ {
+ // We convert the URI immediately because most of the code involved prefers to deal
+ // with utf8. This is overhead which I'd prefer to avoid but for now we just have
+ // to live with it
+
+ WideToUtf8({(wchar_t*)HttpRequestPtr->CookedUrl.pAbsPath + PrefixLength, gsl::narrow<size_t>(AbsPathLength - PrefixLength)},
+ m_UriUtf8);
+
+ std::string_view UriSuffix8{m_UriUtf8};
+
+ m_UriWithExtension = UriSuffix8; // Retain URI with extension for user access
+ m_Uri = UriSuffix8;
+
+ const size_t LastComponentIndex = UriSuffix8.find_last_of('/');
+
+ if (LastComponentIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastComponentIndex);
+ }
+
+ const size_t LastDotIndex = UriSuffix8.find_last_of('.');
+
+ if (LastDotIndex != std::string_view::npos)
+ {
+ UriSuffix8.remove_prefix(LastDotIndex + 1);
+
+ AcceptContentType = ParseContentType(UriSuffix8);
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_Uri.remove_suffix(UriSuffix8.size() + 1);
+ }
+ }
+ }
+ else
+ {
+ m_UriUtf8.Reset();
+ m_Uri = {};
+ m_UriWithExtension = {};
+ }
+
+ if (uint16_t QueryStringLength = HttpRequestPtr->CookedUrl.QueryStringLength)
+ {
+ --QueryStringLength; // We skip the leading question mark
+
+ WideToUtf8({(wchar_t*)(HttpRequestPtr->CookedUrl.pQueryString) + 1, QueryStringLength / sizeof(wchar_t)}, m_QueryStringUtf8);
+ }
+ else
+ {
+ m_QueryStringUtf8.Reset();
+ }
+
+ m_QueryString = std::string_view(m_QueryStringUtf8);
+ m_Verb = TranslateHttpVerb(HttpRequestPtr->Verb);
+ m_ContentLength = GetContentLength(HttpRequestPtr);
+ m_ContentType = GetContentType(HttpRequestPtr);
+
+ // It an explicit content type extension was specified then we'll use that over any
+ // Accept: header value that may be present
+
+ if (AcceptContentType != HttpContentType::kUnknownContentType)
+ {
+ m_AcceptType = AcceptContentType;
+ }
+ else
+ {
+ m_AcceptType = GetAcceptType(HttpRequestPtr);
+ }
+
+ if (m_Verb == HttpVerb::kHead)
+ {
+ SetSuppressResponseBody();
+ }
+}
+
+Oid
+HttpSysServerRequest::ParseSessionId() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+
+ for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i)
+ {
+ HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i];
+ std::string_view HeaderName{Header.pName, Header.NameLength};
+
+ if (HeaderName == "UE-Session"sv)
+ {
+ if (Header.RawValueLength == Oid::StringLength)
+ {
+ return Oid::FromHexString({Header.pRawValue, Header.RawValueLength});
+ }
+ }
+ }
+
+ return {};
+}
+
+uint32_t
+HttpSysServerRequest::ParseRequestId() const
+{
+ const HTTP_REQUEST* HttpRequestPtr = m_HttpTx.HttpRequest();
+
+ for (int i = 0; i < HttpRequestPtr->Headers.UnknownHeaderCount; ++i)
+ {
+ HTTP_UNKNOWN_HEADER& Header = HttpRequestPtr->Headers.pUnknownHeaders[i];
+ std::string_view HeaderName{Header.pName, Header.NameLength};
+
+ if (HeaderName == "UE-Request"sv)
+ {
+ std::string_view RequestValue{Header.pRawValue, Header.RawValueLength};
+ uint32_t RequestId = 0;
+ std::from_chars(RequestValue.data(), RequestValue.data() + RequestValue.size(), RequestId);
+ return RequestId;
+ }
+ }
+
+ return 0;
+}
+
+IoBuffer
+HttpSysServerRequest::ReadPayload()
+{
+ return m_PayloadBuffer;
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode)
+{
+ ZEN_ASSERT(IsHandled() == false);
+
+ auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode);
+
+ if (SuppressBody())
+ {
+ Response->SuppressResponseBody();
+ }
+
+ m_NextCompletionHandler = Response;
+
+ SetIsHandled();
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs)
+{
+ ZEN_ASSERT(IsHandled() == false);
+
+ auto Response = new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, Blobs);
+
+ if (SuppressBody())
+ {
+ Response->SuppressResponseBody();
+ }
+
+ m_NextCompletionHandler = Response;
+
+ SetIsHandled();
+}
+
+void
+HttpSysServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString)
+{
+ ZEN_ASSERT(IsHandled() == false);
+
+ auto Response =
+ new HttpMessageResponseRequest(m_HttpTx, (uint16_t)ResponseCode, ContentType, ResponseString.data(), ResponseString.size());
+
+ if (SuppressBody())
+ {
+ Response->SuppressResponseBody();
+ }
+
+ m_NextCompletionHandler = Response;
+
+ SetIsHandled();
+}
+
+void
+HttpSysServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler)
+{
+ if (m_HttpTx.Server().IsAsyncResponseEnabled())
+ {
+ m_NextCompletionHandler = new HttpAsyncWorkRequest(m_HttpTx, std::move(ContinuationHandler));
+ }
+ else
+ {
+ ContinuationHandler(m_HttpTx.ServerRequest());
+ }
+}
+
+bool
+HttpSysServerRequest::TryGetRanges(HttpRanges& Ranges)
+{
+ HTTP_REQUEST* Req = m_HttpTx.HttpRequest();
+ const HTTP_KNOWN_HEADER& RangeHeader = Req->Headers.KnownHeaders[HttpHeaderRange];
+
+ return TryParseHttpRangeHeader({RangeHeader.pRawValue, RangeHeader.RawValueLength}, Ranges);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+InitialRequestHandler::InitialRequestHandler(HttpSysTransaction& InRequest) : HttpSysRequestHandler(InRequest)
+{
+}
+
+InitialRequestHandler::~InitialRequestHandler()
+{
+}
+
+void
+InitialRequestHandler::IssueRequest(std::error_code& ErrorCode)
+{
+ HttpSysTransaction& Tx = Transaction();
+ PTP_IO Iocp = Tx.Iocp();
+ HTTP_REQUEST* HttpReq = Tx.HttpRequest();
+
+ StartThreadpoolIo(Iocp);
+
+ ULONG HttpApiResult;
+
+ if (IsInitialRequest())
+ {
+ HttpApiResult = HttpReceiveHttpRequest(Tx.RequestQueueHandle(),
+ HTTP_NULL_ID,
+ HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY,
+ HttpReq,
+ RequestBufferSize(),
+ NULL,
+ Tx.Overlapped());
+ }
+ else
+ {
+ // The http.sys team recommends limiting the size to 128KB
+ static const uint64_t kMaxBytesPerApiCall = 128 * 1024;
+
+ uint64_t BytesToRead = m_ContentLength - m_CurrentPayloadOffset;
+ const uint64_t BytesToReadThisCall = zen::Min(BytesToRead, kMaxBytesPerApiCall);
+ void* BufferWriteCursor = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData()) + m_CurrentPayloadOffset;
+
+ HttpApiResult = HttpReceiveRequestEntityBody(Tx.RequestQueueHandle(),
+ HttpReq->RequestId,
+ 0, /* Flags */
+ BufferWriteCursor,
+ gsl::narrow<ULONG>(BytesToReadThisCall),
+ nullptr, // BytesReturned
+ Tx.Overlapped());
+ }
+
+ if (HttpApiResult != ERROR_IO_PENDING && HttpApiResult != NO_ERROR)
+ {
+ CancelThreadpoolIo(Iocp);
+
+ ErrorCode = MakeErrorCode(HttpApiResult);
+
+ ZEN_WARN("HttpReceiveHttpRequest failed, error: '{}'", ErrorCode.message());
+
+ return;
+ }
+
+ ErrorCode.clear();
+}
+
+HttpSysRequestHandler*
+InitialRequestHandler::HandleCompletion(ULONG IoResult, ULONG_PTR NumberOfBytesTransferred)
+{
+ auto _ = MakeGuard([&] { m_IsInitialRequest = false; });
+
+ switch (IoResult)
+ {
+ default:
+ case ERROR_OPERATION_ABORTED:
+ return nullptr;
+
+ case ERROR_MORE_DATA: // Insufficient buffer space
+ case NO_ERROR:
+ break;
+ }
+
+ // Route request
+
+ try
+ {
+ HTTP_REQUEST* HttpReq = HttpRequest();
+
+# if 0
+ for (int i = 0; i < HttpReq->RequestInfoCount; ++i)
+ {
+ auto& ReqInfo = HttpReq->pRequestInfo[i];
+
+ switch (ReqInfo.InfoType)
+ {
+ case HttpRequestInfoTypeRequestTiming:
+ {
+ const HTTP_REQUEST_TIMING_INFO* TimingInfo = reinterpret_cast<HTTP_REQUEST_TIMING_INFO*>(ReqInfo.pInfo);
+
+ ZEN_INFO("");
+ }
+ break;
+ case HttpRequestInfoTypeAuth:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeChannelBind:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeSslProtocol:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeSslTokenBindingDraft:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeSslTokenBinding:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeTcpInfoV0:
+ {
+ const TCP_INFO_v0* TcpInfo = reinterpret_cast<const TCP_INFO_v0*>(ReqInfo.pInfo);
+
+ ZEN_INFO("");
+ }
+ break;
+ case HttpRequestInfoTypeRequestSizing:
+ {
+ const HTTP_REQUEST_SIZING_INFO* SizingInfo = reinterpret_cast<const HTTP_REQUEST_SIZING_INFO*>(ReqInfo.pInfo);
+ ZEN_INFO("");
+ }
+ break;
+ case HttpRequestInfoTypeQuicStats:
+ ZEN_INFO("");
+ break;
+ case HttpRequestInfoTypeTcpInfoV1:
+ {
+ const TCP_INFO_v1* TcpInfo = reinterpret_cast<const TCP_INFO_v1*>(ReqInfo.pInfo);
+
+ ZEN_INFO("");
+ }
+ break;
+ }
+ }
+# endif
+
+ if (HttpService* Service = reinterpret_cast<HttpService*>(HttpReq->UrlContext))
+ {
+ if (m_IsInitialRequest)
+ {
+ m_ContentLength = GetContentLength(HttpReq);
+ const HttpContentType ContentType = GetContentType(HttpReq);
+
+ if (m_ContentLength)
+ {
+ // Handle initial chunk read by copying any payload which has already been copied
+ // into our embedded request buffer
+
+ m_PayloadBuffer = IoBuffer(m_ContentLength);
+ m_PayloadBuffer.SetContentType(ContentType);
+
+ uint64_t BytesToRead = m_ContentLength;
+ uint8_t* const BufferBase = reinterpret_cast<uint8_t*>(m_PayloadBuffer.MutableData());
+ uint8_t* BufferWriteCursor = BufferBase;
+
+ const int EntityChunkCount = HttpReq->EntityChunkCount;
+
+ for (int i = 0; i < EntityChunkCount; ++i)
+ {
+ HTTP_DATA_CHUNK& EntityChunk = HttpReq->pEntityChunks[i];
+
+ ZEN_ASSERT(EntityChunk.DataChunkType == HttpDataChunkFromMemory);
+
+ const uint64_t BufferLength = EntityChunk.FromMemory.BufferLength;
+
+ ZEN_ASSERT(BufferLength <= BytesToRead);
+
+ memcpy(BufferWriteCursor, EntityChunk.FromMemory.pBuffer, BufferLength);
+
+ BufferWriteCursor += BufferLength;
+ BytesToRead -= BufferLength;
+ }
+
+ m_CurrentPayloadOffset = BufferWriteCursor - BufferBase;
+ }
+ }
+ else
+ {
+ m_CurrentPayloadOffset += NumberOfBytesTransferred;
+ }
+
+ if (m_CurrentPayloadOffset != m_ContentLength)
+ {
+ // Body not complete, issue another read request to receive more body data
+ return this;
+ }
+
+ // Request body received completely
+
+ m_PayloadBuffer.MakeImmutable();
+
+ HttpSysServerRequest& ThisRequest = Transaction().InvokeRequestHandler(*Service, m_PayloadBuffer);
+
+ if (HttpSysRequestHandler* Response = ThisRequest.m_NextCompletionHandler)
+ {
+ return Response;
+ }
+
+ if (!ThisRequest.IsHandled())
+ {
+ return new HttpMessageResponseRequest(Transaction(), 404, "Not found"sv);
+ }
+ }
+
+ // Unable to route
+ return new HttpMessageResponseRequest(Transaction(), 404, "No suitable route found"sv);
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_ERROR("Caught exception while handling request: '{}'", ex.what());
+
+ return new HttpMessageResponseRequest(Transaction(), 500, ex.what());
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// HttpServer interface implementation
+//
+
+int
+HttpSysServer::Initialize(int BasePort)
+{
+ int EffectivePort = InitializeServer(BasePort);
+ StartServer();
+ return EffectivePort;
+}
+
+void
+HttpSysServer::RequestExit()
+{
+ m_ShutdownEvent.Set();
+}
+void
+HttpSysServer::RegisterService(HttpService& Service)
+{
+ RegisterService(Service.BaseUri(), Service);
+}
+
+Ref<HttpServer>
+CreateHttpSysServer(int Concurrency, int BackgroundWorkerThreads)
+{
+ return Ref<HttpServer>(new HttpSysServer(Concurrency, BackgroundWorkerThreads));
+}
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/httpsys.h b/src/zenhttp/httpsys.h
new file mode 100644
index 000000000..d6bd34890
--- /dev/null
+++ b/src/zenhttp/httpsys.h
@@ -0,0 +1,90 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+#ifndef ZEN_WITH_HTTPSYS
+# if ZEN_PLATFORM_WINDOWS
+# define ZEN_WITH_HTTPSYS 1
+# else
+# define ZEN_WITH_HTTPSYS 0
+# endif
+#endif
+
+#if ZEN_WITH_HTTPSYS
+# define _WINSOCKAPI_
+# include <zencore/windows.h>
+# include <zencore/workthreadpool.h>
+# include "iothreadpool.h"
+
+# include <http.h>
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+/**
+ * @brief Windows implementation of HTTP server based on http.sys
+ *
+ * This requires elevation to function
+ */
+class HttpSysServer : public HttpServer
+{
+ friend class HttpSysTransaction;
+
+public:
+ explicit HttpSysServer(unsigned int ThreadCount, unsigned int AsyncWorkThreadCount);
+ ~HttpSysServer();
+
+ // HttpServer interface implementation
+
+ virtual int Initialize(int BasePort) override;
+ virtual void Run(bool TestMode) override;
+ virtual void RequestExit() override;
+ virtual void RegisterService(HttpService& Service) override;
+
+ WorkerThreadPool& WorkPool() { return m_AsyncWorkPool; }
+
+ inline bool IsOk() const { return m_IsOk; }
+ inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
+
+private:
+ int InitializeServer(int BasePort);
+ void Cleanup();
+
+ void StartServer();
+ void OnHandlingRequest();
+ void IssueNewRequestMaybe();
+
+ void RegisterService(const char* Endpoint, HttpService& Service);
+ void UnregisterService(const char* Endpoint, HttpService& Service);
+
+private:
+ spdlog::logger& m_Log;
+ spdlog::logger& m_RequestLog;
+ spdlog::logger& Log() { return m_Log; }
+
+ bool m_IsOk = false;
+ bool m_IsHttpInitialized = false;
+ bool m_IsRequestLoggingEnabled = false;
+ bool m_IsAsyncResponseEnabled = true;
+
+ WinIoThreadPool m_ThreadPool;
+ WorkerThreadPool m_AsyncWorkPool;
+
+ std::vector<std::wstring> m_BaseUris; // eg: http://*:nnnn/
+ HTTP_SERVER_SESSION_ID m_HttpSessionId = 0;
+ HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0;
+ HANDLE m_RequestQueueHandle = 0;
+ std::atomic_int32_t m_PendingRequests{0};
+ std::atomic_int32_t m_IsShuttingDown{0};
+ int32_t m_MinPendingRequests = 16;
+ int32_t m_MaxPendingRequests = 128;
+ Event m_ShutdownEvent;
+};
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/include/zenhttp/httpclient.h b/src/zenhttp/include/zenhttp/httpclient.h
new file mode 100644
index 000000000..8316a9b9f
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpclient.h
@@ -0,0 +1,47 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zencore/iobuffer.h>
+#include <zencore/uid.h>
+#include <zenhttp/httpcommon.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CbPackage;
+
+/** HTTP client implementation for Zen use cases
+
+ Currently simple and synchronous, should become lean and asynchronous
+ */
+class HttpClient
+{
+public:
+ HttpClient(std::string_view BaseUri);
+ ~HttpClient();
+
+ struct Response
+ {
+ int StatusCode = 0;
+ IoBuffer ResponsePayload; // Note: this also includes the content type
+ };
+
+ [[nodiscard]] Response Put(std::string_view Url, IoBuffer Payload);
+ [[nodiscard]] Response Get(std::string_view Url);
+ [[nodiscard]] Response TransactPackage(std::string_view Url, CbPackage Package);
+ [[nodiscard]] Response Delete(std::string_view Url);
+
+private:
+ std::string m_BaseUri;
+ std::string m_SessionId;
+};
+
+} // namespace zen
+
+void httpclient_forcelink(); // internal
diff --git a/src/zenhttp/include/zenhttp/httpcommon.h b/src/zenhttp/include/zenhttp/httpcommon.h
new file mode 100644
index 000000000..19fda8db4
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpcommon.h
@@ -0,0 +1,181 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+
+#include <string_view>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+using HttpContentType = ZenContentType;
+
+class IoBuffer;
+class CbObject;
+class CbPackage;
+class StringBuilderBase;
+
+struct HttpRange
+{
+ uint32_t Start = ~uint32_t(0);
+ uint32_t End = ~uint32_t(0);
+};
+
+using HttpRanges = std::vector<HttpRange>;
+
+std::string_view MapContentTypeToString(HttpContentType ContentType);
+extern HttpContentType (*ParseContentType)(const std::string_view& ContentTypeString);
+std::string_view ReasonStringForHttpResultCode(int HttpCode);
+bool TryParseHttpRangeHeader(std::string_view RangeHeader, HttpRanges& Ranges);
+
+[[nodiscard]] inline bool
+IsHttpSuccessCode(int HttpCode)
+{
+ return (HttpCode >= 200) && (HttpCode < 300);
+}
+
+enum class HttpVerb : uint8_t
+{
+ kGet = 1 << 0,
+ kPut = 1 << 1,
+ kPost = 1 << 2,
+ kDelete = 1 << 3,
+ kHead = 1 << 4,
+ kCopy = 1 << 5,
+ kOptions = 1 << 6
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(HttpVerb);
+
+const std::string_view ToString(HttpVerb Verb);
+
+enum class HttpResponseCode
+{
+ // 1xx - Informational
+
+ Continue = 100, //!< Indicates that the initial part of a request has been received and has not yet been rejected by the server.
+ SwitchingProtocols = 101, //!< Indicates that the server understands and is willing to comply with the client's request, via the
+ //!< Upgrade header field, for a change in the application protocol being used on this connection.
+ Processing = 102, //!< Is an interim response used to inform the client that the server has accepted the complete request, but has not
+ //!< yet completed it.
+ EarlyHints = 103, //!< Indicates to the client that the server is likely to send a final response with the header fields included in
+ //!< the informational response.
+
+ // 2xx - Successful
+
+ OK = 200, //!< Indicates that the request has succeeded.
+ Created = 201, //!< Indicates that the request has been fulfilled and has resulted in one or more new resources being created.
+ Accepted = 202, //!< Indicates that the request has been accepted for processing, but the processing has not been completed.
+ NonAuthoritativeInformation = 203, //!< Indicates that the request was successful but the enclosed payload has been modified from that
+ //!< of the origin server's 200 (OK) response by a transforming proxy.
+ NoContent = 204, //!< Indicates that the server has successfully fulfilled the request and that there is no additional content to send
+ //!< in the response payload body.
+ ResetContent = 205, //!< Indicates that the server has fulfilled the request and desires that the user agent reset the \"document
+ //!< view\", which caused the request to be sent, to its original state as received from the origin server.
+ PartialContent = 206, //!< Indicates that the server is successfully fulfilling a range request for the target resource by transferring
+ //!< one or more parts of the selected representation that correspond to the satisfiable ranges found in the
+ //!< requests's Range header field.
+ MultiStatus = 207, //!< Provides status for multiple independent operations.
+ AlreadyReported = 208, //!< Used inside a DAV:propstat response element to avoid enumerating the internal members of multiple bindings
+ //!< to the same collection repeatedly. [RFC 5842]
+ IMUsed = 226, //!< The server has fulfilled a GET request for the resource, and the response is a representation of the result of one
+ //!< or more instance-manipulations applied to the current instance.
+
+ // 3xx - Redirection
+
+ MultipleChoices = 300, //!< Indicates that the target resource has more than one representation, each with its own more specific
+ //!< identifier, and information about the alternatives is being provided so that the user (or user agent) can
+ //!< select a preferred representation by redirecting its request to one or more of those identifiers.
+ MovedPermanently = 301, //!< Indicates that the target resource has been assigned a new permanent URI and any future references to this
+ //!< resource ought to use one of the enclosed URIs.
+ Found = 302, //!< Indicates that the target resource resides temporarily under a different URI.
+ SeeOther = 303, //!< Indicates that the server is redirecting the user agent to a different resource, as indicated by a URI in the
+ //!< Location header field, that is intended to provide an indirect response to the original request.
+ NotModified = 304, //!< Indicates that a conditional GET request has been received and would have resulted in a 200 (OK) response if it
+ //!< were not for the fact that the condition has evaluated to false.
+ UseProxy = 305, //!< \deprecated \parblock Due to security concerns regarding in-band configuration of a proxy. \endparblock
+ //!< The requested resource MUST be accessed through the proxy given by the Location field.
+ TemporaryRedirect = 307, //!< Indicates that the target resource resides temporarily under a different URI and the user agent MUST NOT
+ //!< change the request method if it performs an automatic redirection to that URI.
+ PermanentRedirect = 308, //!< The target resource has been assigned a new permanent URI and any future references to this resource
+ //!< ought to use one of the enclosed URIs. [...] This status code is similar to 301 Moved Permanently
+ //!< (Section 7.3.2 of rfc7231), except that it does not allow rewriting the request method from POST to GET.
+
+ // 4xx - Client Error
+ BadRequest = 400, //!< Indicates that the server cannot or will not process the request because the received syntax is invalid,
+ //!< nonsensical, or exceeds some limitation on what the server is willing to process.
+ Unauthorized = 401, //!< Indicates that the request has not been applied because it lacks valid authentication credentials for the
+ //!< target resource.
+ PaymentRequired = 402, //!< *Reserved*
+ Forbidden = 403, //!< Indicates that the server understood the request but refuses to authorize it.
+ NotFound = 404, //!< Indicates that the origin server did not find a current representation for the target resource or is not willing
+ //!< to disclose that one exists.
+ MethodNotAllowed = 405, //!< Indicates that the method specified in the request-line is known by the origin server but not supported by
+ //!< the target resource.
+ NotAcceptable = 406, //!< Indicates that the target resource does not have a current representation that would be acceptable to the
+ //!< user agent, according to the proactive negotiation header fields received in the request, and the server is
+ //!< unwilling to supply a default representation.
+ ProxyAuthenticationRequired =
+ 407, //!< Is similar to 401 (Unauthorized), but indicates that the client needs to authenticate itself in order to use a proxy.
+ RequestTimeout =
+ 408, //!< Indicates that the server did not receive a complete request message within the time that it was prepared to wait.
+ Conflict = 409, //!< Indicates that the request could not be completed due to a conflict with the current state of the resource.
+ Gone = 410, //!< Indicates that access to the target resource is no longer available at the origin server and that this condition is
+ //!< likely to be permanent.
+ LengthRequired = 411, //!< Indicates that the server refuses to accept the request without a defined Content-Length.
+ PreconditionFailed =
+ 412, //!< Indicates that one or more preconditions given in the request header fields evaluated to false when tested on the server.
+ PayloadTooLarge = 413, //!< Indicates that the server is refusing to process a request because the request payload is larger than the
+ //!< server is willing or able to process.
+ URITooLong = 414, //!< Indicates that the server is refusing to service the request because the request-target is longer than the
+ //!< server is willing to interpret.
+ UnsupportedMediaType = 415, //!< Indicates that the origin server is refusing to service the request because the payload is in a format
+ //!< not supported by the target resource for this method.
+ RangeNotSatisfiable = 416, //!< Indicates that none of the ranges in the request's Range header field overlap the current extent of the
+ //!< selected resource or that the set of ranges requested has been rejected due to invalid ranges or an
+ //!< excessive request of small or overlapping ranges.
+ ExpectationFailed = 417, //!< Indicates that the expectation given in the request's Expect header field could not be met by at least
+ //!< one of the inbound servers.
+ ImATeapot = 418, //!< Any attempt to brew coffee with a teapot should result in the error code 418 I'm a teapot.
+ UnprocessableEntity = 422, //!< Means the server understands the content type of the request entity (hence a 415(Unsupported Media
+ //!< Type) status code is inappropriate), and the syntax of the request entity is correct (thus a 400 (Bad
+ //!< Request) status code is inappropriate) but was unable to process the contained instructions.
+ Locked = 423, //!< Means the source or destination resource of a method is locked.
+ FailedDependency = 424, //!< Means that the method could not be performed on the resource because the requested action depended on
+ //!< another action and that action failed.
+ UpgradeRequired = 426, //!< Indicates that the server refuses to perform the request using the current protocol but might be willing to
+ //!< do so after the client upgrades to a different protocol.
+ PreconditionRequired = 428, //!< Indicates that the origin server requires the request to be conditional.
+ TooManyRequests = 429, //!< Indicates that the user has sent too many requests in a given amount of time (\"rate limiting\").
+ RequestHeaderFieldsTooLarge =
+ 431, //!< Indicates that the server is unwilling to process the request because its header fields are too large.
+ UnavailableForLegalReasons =
+ 451, //!< This status code indicates that the server is denying access to the resource in response to a legal demand.
+
+ // 5xx - Server Error
+
+ InternalServerError =
+ 500, //!< Indicates that the server encountered an unexpected condition that prevented it from fulfilling the request.
+ NotImplemented = 501, //!< Indicates that the server does not support the functionality required to fulfill the request.
+ BadGateway = 502, //!< Indicates that the server, while acting as a gateway or proxy, received an invalid response from an inbound
+ //!< server it accessed while attempting to fulfill the request.
+ ServiceUnavailable = 503, //!< Indicates that the server is currently unable to handle the request due to a temporary overload or
+ //!< scheduled maintenance, which will likely be alleviated after some delay.
+ GatewayTimeout = 504, //!< Indicates that the server, while acting as a gateway or proxy, did not receive a timely response from an
+ //!< upstream server it needed to access in order to complete the request.
+ HTTPVersionNotSupported = 505, //!< Indicates that the server does not support, or refuses to support, the protocol version that was
+ //!< used in the request message.
+ VariantAlsoNegotiates =
+ 506, //!< Indicates that the server has an internal configuration error: the chosen variant resource is configured to engage in
+ //!< transparent content negotiation itself, and is therefore not a proper end point in the negotiation process.
+ InsufficientStorage = 507, //!< Means the method could not be performed on the resource because the server is unable to store the
+ //!< representation needed to successfully complete the request.
+ LoopDetected = 508, //!< Indicates that the server terminated an operation because it encountered an infinite loop while processing a
+ //!< request with "Depth: infinity". [RFC 5842]
+ NotExtended = 510, //!< The policy for accessing the resource has not been met in the request. [RFC 2774]
+ NetworkAuthenticationRequired = 511, //!< Indicates that the client needs to authenticate to gain network access.
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
new file mode 100644
index 000000000..3b9fa50b4
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -0,0 +1,315 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenhttp.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/enumflags.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/refcount.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+#include <zenhttp/httpcommon.h>
+
+#include <functional>
+#include <gsl/gsl-lite.hpp>
+#include <list>
+#include <map>
+#include <regex>
+#include <span>
+#include <unordered_map>
+
+namespace zen {
+
+/** HTTP Server Request
+ */
+class HttpServerRequest
+{
+public:
+ HttpServerRequest();
+ ~HttpServerRequest();
+
+ // Synchronous operations
+
+ [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix
+ [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
+ [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; }
+
+ struct QueryParams
+ {
+ std::vector<std::pair<std::string_view, std::string_view>> KvPairs;
+
+ std::string_view GetValue(std::string_view ParamName) const
+ {
+ for (const auto& Kv : KvPairs)
+ {
+ const std::string_view& Key = Kv.first;
+
+ if (Key.size() == ParamName.size())
+ {
+ if (0 == StrCaseCompare(Key.data(), ParamName.data(), Key.size()))
+ {
+ return Kv.second;
+ }
+ }
+ }
+
+ return std::string_view();
+ }
+ };
+
+ virtual bool TryGetRanges(HttpRanges&) { return false; }
+
+ QueryParams GetQueryParams();
+
+ inline HttpVerb RequestVerb() const { return m_Verb; }
+ inline HttpContentType RequestContentType() { return m_ContentType; }
+ inline HttpContentType AcceptContentType() { return m_AcceptType; }
+
+ inline uint64_t ContentLength() const { return m_ContentLength; }
+ Oid SessionId() const;
+ uint32_t RequestId() const;
+
+ inline bool IsHandled() const { return !!(m_Flags & kIsHandled); }
+ inline bool SuppressBody() const { return !!(m_Flags & kSuppressBody); }
+ inline void SetSuppressResponseBody() { m_Flags |= kSuppressBody; }
+
+ /** Read POST/PUT payload for request body, which is always available without delay
+ */
+ virtual IoBuffer ReadPayload() = 0;
+
+ ZENCORE_API CbObject ReadPayloadObject();
+ ZENCORE_API CbPackage ReadPayloadPackage();
+
+ /** Respond with payload
+
+ No data will have been sent when any of these functions return. Instead, the response will be transmitted
+ asynchronously, after returning from a request handler function.
+
+ Note that this is destructive in the sense that the IoBuffer instances referred to by Blobs will be
+ moved into our response handler array where they are kept alive, in order to reduce ref-counting storms
+ */
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) = 0;
+ virtual void WriteResponse(HttpResponseCode ResponseCode) = 0;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) = 0;
+ virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, CompositeBuffer& Payload);
+
+ void WriteResponse(HttpResponseCode ResponseCode, CbObject Data);
+ void WriteResponse(HttpResponseCode ResponseCode, CbArray Array);
+ void WriteResponse(HttpResponseCode ResponseCode, CbPackage Package);
+ void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::string_view ResponseString);
+ void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, IoBuffer Blob);
+
+ virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) = 0;
+
+protected:
+ enum
+ {
+ kIsHandled = 1 << 0,
+ kSuppressBody = 1 << 1,
+ kHaveRequestId = 1 << 2,
+ kHaveSessionId = 1 << 3,
+ };
+
+ mutable uint32_t m_Flags = 0;
+ HttpVerb m_Verb = HttpVerb::kGet;
+ HttpContentType m_ContentType = HttpContentType::kBinary;
+ HttpContentType m_AcceptType = HttpContentType::kUnknownContentType;
+ uint64_t m_ContentLength = ~0ull;
+ std::string_view m_Uri;
+ std::string_view m_UriWithExtension;
+ std::string_view m_QueryString;
+ mutable uint32_t m_RequestId = ~uint32_t(0);
+ mutable Oid m_SessionId = Oid::Zero;
+
+ inline void SetIsHandled() { m_Flags |= kIsHandled; }
+
+ virtual Oid ParseSessionId() const = 0;
+ virtual uint32_t ParseRequestId() const = 0;
+};
+
+class IHttpPackageHandler : public RefCounted
+{
+public:
+ virtual void FilterOffer(std::vector<IoHash>& OfferCids) = 0;
+ virtual void OnRequestBegin() = 0;
+ virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) = 0;
+ virtual void OnRequestComplete() = 0;
+};
+
+/**
+ * Base class for implementing an HTTP "service"
+ *
+ * A service exposes one or more endpoints with a certain URI prefix
+ *
+ */
+
+class HttpService
+{
+public:
+ HttpService() = default;
+ virtual ~HttpService() = default;
+
+ virtual const char* BaseUri() const = 0;
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) = 0;
+ virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest);
+
+ // Internals
+
+ inline void SetUriPrefixLength(size_t PrefixLength) { m_UriPrefixLength = (int)PrefixLength; }
+ inline int UriPrefixLength() const { return m_UriPrefixLength; }
+
+private:
+ int m_UriPrefixLength = 0;
+};
+
+/** HTTP server
+ *
+ * Implements the main event loop to service HTTP requests, and handles routing
+ * requests to the appropriate handler as registered via RegisterService
+ */
+class HttpServer : public RefCounted
+{
+public:
+ virtual void RegisterService(HttpService& Service) = 0;
+ virtual int Initialize(int BasePort) = 0;
+ virtual void Run(bool IsInteractiveSession) = 0;
+ virtual void RequestExit() = 0;
+};
+
+Ref<HttpServer> CreateHttpServer(std::string_view ServerClass);
+
+//////////////////////////////////////////////////////////////////////////
+
+class HttpRouterRequest
+{
+public:
+ HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {}
+
+ ZENCORE_API std::string GetCapture(uint32_t Index) const;
+ inline HttpServerRequest& ServerRequest() { return m_HttpRequest; }
+
+private:
+ using MatchResults_t = std::match_results<std::string_view::const_iterator>;
+
+ HttpServerRequest& m_HttpRequest;
+ MatchResults_t m_Match;
+
+ friend class HttpRequestRouter;
+};
+
+inline std::string
+HttpRouterRequest::GetCapture(uint32_t Index) const
+{
+ ZEN_ASSERT(Index < m_Match.size());
+
+ return m_Match[Index];
+}
+
+/** HTTP request router helper
+ *
+ * This helper class allows a service implementer to register one or more
+ * endpoints using pattern matching (currently using regex matching)
+ *
+ * This is intended to be initialized once only, there is no thread
+ * safety so you can absolutely not add or remove endpoints once the handler
+ * goes live
+ */
+
+class HttpRequestRouter
+{
+public:
+ typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t;
+
+ /**
+ * @brief Add pattern which can be referenced by name, commonly used for URL components
+ * @param Id String used to identify patterns for replacement
+ * @param Regex String which will replace the Id string in any registered URL paths
+ */
+ void AddPattern(const char* Id, const char* Regex);
+
+ /**
+ * @brief Register a an endpoint handler for the given route
+ * @param Regex Regular expression used to match the handler to a request. This may
+ * contain pattern aliases registered via AddPattern
+ * @param HandlerFunc Handler function to call for any matching request
+ * @param SupportedVerbs Supported HTTP verbs for this handler
+ */
+ void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs);
+
+ void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex);
+
+ /**
+ * @brief HTTP request handling function - this should be called to route the
+ * request to a registered handler
+ * @param Request Request to route to a handler
+ * @return Function returns true if the request was routed successfully
+ */
+ bool HandleRequest(zen::HttpServerRequest& Request);
+
+private:
+ struct HandlerEntry
+ {
+ HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern)
+ : RegEx(Regex, std::regex::icase | std::regex::ECMAScript)
+ , Verbs(SupportedVerbs)
+ , Handler(std::move(Handler))
+ , Pattern(Pattern)
+ {
+ }
+
+ ~HandlerEntry() = default;
+
+ std::regex RegEx;
+ HttpVerb Verbs;
+ HandlerFunc_t Handler;
+ const char* Pattern;
+
+ private:
+ HandlerEntry& operator=(const HandlerEntry&) = delete;
+ HandlerEntry(const HandlerEntry&) = delete;
+ };
+
+ std::list<HandlerEntry> m_Handlers;
+ std::unordered_map<std::string, std::string> m_PatternMap;
+};
+
+/** HTTP RPC request helper
+ */
+
+class RpcResult
+{
+ RpcResult(CbObject Result) : m_Result(std::move(Result)) {}
+
+private:
+ CbObject m_Result;
+};
+
+class HttpRpcHandler
+{
+public:
+ HttpRpcHandler();
+ ~HttpRpcHandler();
+
+ HttpRpcHandler(const HttpRpcHandler&) = delete;
+ HttpRpcHandler operator=(const HttpRpcHandler&) = delete;
+
+ void AddRpc(std::string_view RpcId, std::function<void(CbObject& RpcArgs)> HandlerFunction);
+
+private:
+ struct RpcFunction
+ {
+ std::function<void(CbObject& RpcArgs)> Function;
+ std::string Identifier;
+ };
+
+ std::map<std::string, RpcFunction> m_Functions;
+};
+
+bool HandlePackageOffers(HttpService& Service, HttpServerRequest& Request, Ref<IHttpPackageHandler>& PackageHandlerRef);
+
+void http_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/httpshared.h b/src/zenhttp/include/zenhttp/httpshared.h
new file mode 100644
index 000000000..d335572c5
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/httpshared.h
@@ -0,0 +1,163 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+
+#include <functional>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+class IoBuffer;
+class CbPackage;
+class CompositeBuffer;
+
+/** _____ _ _____ _
+ / ____| | | __ \ | |
+ | | | |__ | |__) |_ _ ___| | ____ _ __ _ ___
+ | | | '_ \| ___/ _` |/ __| |/ / _` |/ _` |/ _ \
+ | |____| |_) | | | (_| | (__| < (_| | (_| | __/
+ \_____|_.__/|_| \__,_|\___|_|\_\__,_|\__, |\___|
+ __/ |
+ |___/
+
+ Structures and code related to handling CbPackage transactions
+
+ CbPackage instances are marshaled across the wire using a distinct message
+ format. We don't use the CbPackage serialization format provided by the
+ CbPackage implementation itself since that does not provide much flexibility
+ in how the attachment payloads are transmitted. The scheme below separates
+ metadata cleanly from payloads and this enables us to more efficiently
+ transmit them either via sendfile/TransmitFile like mechanisms, or by
+ reference/memory mapping in the local case.
+ */
+
+struct CbPackageHeader
+{
+ uint32_t HeaderMagic;
+ uint32_t AttachmentCount; // TODO: should add ability to opt out of implicit root document?
+ uint32_t Reserved1;
+ uint32_t Reserved2;
+};
+
+static_assert(sizeof(CbPackageHeader) == 16);
+
+enum : uint32_t
+{
+ kCbPkgMagic = 0xaa77aacc
+};
+
+struct CbAttachmentEntry
+{
+ uint64_t PayloadSize; // Size of the associated payload data in the message
+ uint32_t Flags; // See flags below
+ IoHash AttachmentHash; // Content Id for the attachment
+
+ enum
+ {
+ kIsCompressed = (1u << 0), // Is marshaled using compressed buffer storage format
+ kIsObject = (1u << 1), // Is compact binary object
+ kIsError = (1u << 2), // Is error (compact binary formatted) object
+ kIsLocalRef = (1u << 3), // Is "local reference"
+ };
+};
+
+struct CbAttachmentReferenceHeader
+{
+ uint64_t PayloadByteOffset = 0;
+ uint64_t PayloadByteSize = ~0u;
+ uint16_t AbsolutePathLength = 0;
+
+ // This header will be followed by UTF8 encoded absolute path to backing file
+};
+
+static_assert(sizeof(CbAttachmentEntry) == 32);
+
+enum class FormatFlags
+{
+ kDefault = 0,
+ kAllowLocalReferences = (1u << 0),
+ kDenyPartialLocalReferences = (1u << 1)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(FormatFlags);
+
+enum class RpcAcceptOptions : uint16_t
+{
+ kNone = 0,
+ kAllowLocalReferences = (1u << 0),
+ kAllowPartialLocalReferences = (1u << 1)
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(RpcAcceptOptions);
+
+std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0);
+CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, FormatFlags Flags, int TargetProcessPid = 0);
+CbPackage ParsePackageMessage(
+ IoBuffer Payload,
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer = [](const IoHash&, uint64_t Size) -> IoBuffer {
+ return IoBuffer{Size};
+ });
+bool IsPackageMessage(IoBuffer Payload);
+
+bool ParsePackageMessageWithLegacyFallback(const IoBuffer& Response, CbPackage& OutPackage);
+
+std::vector<IoBuffer> FormatPackageMessage(const CbPackage& Data, int TargetProcessPid = 0);
+CompositeBuffer FormatPackageMessageBuffer(const CbPackage& Data, int TargetProcessPid = 0);
+
+/** Streaming reader for compact binary packages
+
+ The goal is to ultimately support zero-copy I/O, but for now there'll be some
+ copying involved on some platforms at least.
+
+ This approach to deserializing CbPackage data is more efficient than
+ `ParsePackageMessage` since it does not require the entire message to
+ be resident in a memory buffer
+
+ */
+class CbPackageReader
+{
+public:
+ CbPackageReader();
+ ~CbPackageReader();
+
+ void SetPayloadBufferCreator(std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> CreateBuffer);
+
+ /** Process compact binary package data stream
+
+ The data stream must be in the serialization format produced by FormatPackageMessage
+
+ \return How many bytes must be fed to this function in the next call
+ */
+ uint64_t ProcessPackageHeaderData(const void* Data, uint64_t DataBytes);
+
+ void Finalize();
+ const std::vector<CbAttachment>& GetAttachments() { return m_Attachments; }
+ CbObject GetRootObject() { return m_RootObject; }
+ std::span<IoBuffer> GetPayloadBuffers() { return m_PayloadBuffers; }
+
+private:
+ enum class State
+ {
+ kInitialState,
+ kReadingHeader,
+ kReadingAttachmentEntries,
+ kReadingBuffers
+ } m_CurrentState = State::kInitialState;
+
+ std::function<IoBuffer(const IoHash& Cid, uint64_t Size)> m_CreateBuffer;
+ std::vector<IoBuffer> m_PayloadBuffers;
+ std::vector<CbAttachmentEntry> m_AttachmentEntries;
+ std::vector<CbAttachment> m_Attachments;
+ CbObject m_RootObject;
+ CbPackageHeader m_PackageHeader;
+
+ IoBuffer MarshalLocalChunkReference(IoBuffer AttachmentBuffer);
+};
+
+void forcelink_httpshared();
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/websocket.h b/src/zenhttp/include/zenhttp/websocket.h
new file mode 100644
index 000000000..adca7e988
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/websocket.h
@@ -0,0 +1,256 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/compactbinarypackage.h>
+#include <zencore/memory.h>
+
+#include <compare>
+#include <functional>
+#include <future>
+#include <memory>
+#include <optional>
+
+#pragma once
+
+namespace asio {
+class io_context;
+}
+
+namespace zen {
+
+class BinaryWriter;
+
+/**
+ * A unique socket ID.
+ */
+class WebSocketId
+{
+ static std::atomic_uint32_t NextId;
+
+public:
+ WebSocketId() = default;
+
+ uint32_t Value() const { return m_Value; }
+
+ auto operator<=>(const WebSocketId&) const = default;
+
+ static WebSocketId New() { return WebSocketId(NextId.fetch_add(1)); }
+
+private:
+ WebSocketId(uint32_t Value) : m_Value(Value) {}
+
+ uint32_t m_Value{};
+};
+
+/**
+ * Type of web socket message.
+ */
+enum class WebSocketMessageType : uint8_t
+{
+ kInvalid,
+ kNotification,
+ kRequest,
+ kStreamRequest,
+ kResponse,
+ kStreamResponse,
+ kStreamCompleteResponse,
+ kCount
+};
+
+inline std::string_view
+ToString(WebSocketMessageType Type)
+{
+ switch (Type)
+ {
+ case WebSocketMessageType::kInvalid:
+ return std::string_view("Invalid");
+ case WebSocketMessageType::kNotification:
+ return std::string_view("Notification");
+ case WebSocketMessageType::kRequest:
+ return std::string_view("Request");
+ case WebSocketMessageType::kStreamRequest:
+ return std::string_view("StreamRequest");
+ case WebSocketMessageType::kResponse:
+ return std::string_view("Response");
+ case WebSocketMessageType::kStreamResponse:
+ return std::string_view("StreamResponse");
+ case WebSocketMessageType::kStreamCompleteResponse:
+ return std::string_view("StreamCompleteResponse");
+ default:
+ return std::string_view("Unknown");
+ };
+}
+
+/**
+ * Web socket message.
+ */
+class WebSocketMessage
+{
+ struct Header
+ {
+ static constexpr uint32_t ExpectedMagic = 0x7a776d68; // zwmh
+
+ uint64_t MessageSize{};
+ uint32_t Magic{ExpectedMagic};
+ uint32_t CorrelationId{};
+ uint32_t StatusCode{200u};
+ WebSocketMessageType MessageType{};
+ uint8_t Reserved[3] = {0};
+
+ bool IsValid() const;
+ };
+
+ static_assert(sizeof(Header) == 24);
+
+ static std::atomic_uint32_t NextCorrelationId;
+
+public:
+ static constexpr size_t HeaderSize = sizeof(Header);
+
+ WebSocketMessage() = default;
+
+ WebSocketId SocketId() const { return m_SocketId; }
+ void SetSocketId(WebSocketId Id) { m_SocketId = Id; }
+ uint64_t MessageSize() const { return m_Header.MessageSize; }
+ void SetMessageType(WebSocketMessageType MessageType);
+ void SetCorrelationId(uint32_t Id) { m_Header.CorrelationId = Id; }
+ uint32_t CorrelationId() const { return m_Header.CorrelationId; }
+ uint32_t StatusCode() const { return m_Header.StatusCode; }
+ void SetStatusCode(uint32_t StatusCode) { m_Header.StatusCode = StatusCode; }
+ WebSocketMessageType MessageType() const { return m_Header.MessageType; }
+
+ const CbPackage& Body() const { return m_Body.value(); }
+ void SetBody(CbPackage&& Body);
+ void SetBody(CbObject&& Body);
+ bool HasBody() const { return m_Body.has_value(); }
+
+ void Save(BinaryWriter& Writer);
+ bool TryLoadHeader(MemoryView Memory);
+
+ bool IsValid() const { return m_Header.MessageType != WebSocketMessageType::kInvalid; }
+
+private:
+ Header m_Header{};
+ WebSocketId m_SocketId{};
+ std::optional<CbPackage> m_Body;
+};
+
+class WebSocketServer;
+
+/**
+ * Base class for handling web socket requests and notifications from connected client(s).
+ */
+class WebSocketService
+{
+public:
+ virtual ~WebSocketService() = default;
+
+ void Configure(WebSocketServer& Server);
+
+ virtual bool HandleRequest(const WebSocketMessage&) { ZEN_ASSERT(false); }
+ virtual void HandleNotification(const WebSocketMessage&) { ZEN_ASSERT(false); }
+
+protected:
+ WebSocketService() = default;
+
+ virtual void RegisterHandlers(WebSocketServer& Server) = 0;
+ void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete);
+ void SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete);
+
+ WebSocketServer& SocketServer()
+ {
+ ZEN_ASSERT(m_SocketServer);
+ return *m_SocketServer;
+ }
+
+private:
+ WebSocketServer* m_SocketServer{};
+};
+
+/**
+ * Server options.
+ */
+struct WebSocketServerOptions
+{
+ uint16_t Port = 2337;
+ uint32_t ThreadCount = 1;
+};
+
+/**
+ * The web socket server manages client connections and routing of requests and notifications.
+ */
+class WebSocketServer
+{
+public:
+ virtual ~WebSocketServer() = default;
+
+ virtual bool Run() = 0;
+ virtual void Shutdown() = 0;
+
+ virtual void RegisterService(WebSocketService& Service) = 0;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) = 0;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) = 0;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) = 0;
+ virtual void SendResponse(WebSocketMessage&& Response) = 0;
+
+ static std::unique_ptr<WebSocketServer> Create(const WebSocketServerOptions& Options);
+};
+
+/**
+ * The state of the web socket.
+ */
+enum class WebSocketState : uint32_t
+{
+ kNone,
+ kHandshaking,
+ kConnected,
+ kDisconnected,
+ kError
+};
+
+/**
+ * Type of web socket client event.
+ */
+enum class WebSocketEvent : uint32_t
+{
+ kConnected,
+ kDisconnected,
+ kError
+};
+
+/**
+ * Web socket client connection info.
+ */
+struct WebSocketConnectInfo
+{
+ std::string Host;
+ int16_t Port{8848};
+ std::string Endpoint;
+ std::vector<std::string> Protocols;
+ uint16_t Version{13};
+};
+
+/**
+ * A connection to a web socket server for sending requests and listening for notifications.
+ */
+class WebSocketClient
+{
+public:
+ using EventCallback = std::function<void()>;
+ using NotificationCallback = std::function<void(WebSocketMessage&&)>;
+
+ virtual ~WebSocketClient() = default;
+
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) = 0;
+ virtual void Disconnect() = 0;
+ virtual bool IsConnected() const = 0;
+ virtual WebSocketState State() const = 0;
+
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) = 0;
+ virtual void OnNotification(NotificationCallback&& Cb) = 0;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) = 0;
+
+ static std::shared_ptr<WebSocketClient> Create(asio::io_context& IoCtx);
+};
+
+} // namespace zen
diff --git a/src/zenhttp/include/zenhttp/zenhttp.h b/src/zenhttp/include/zenhttp/zenhttp.h
new file mode 100644
index 000000000..59c64b31f
--- /dev/null
+++ b/src/zenhttp/include/zenhttp/zenhttp.h
@@ -0,0 +1,21 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#ifndef ZEN_WITH_HTTPSYS
+# if ZEN_PLATFORM_WINDOWS
+# define ZEN_WITH_HTTPSYS 1
+# else
+# define ZEN_WITH_HTTPSYS 0
+# endif
+#endif
+
+#define ZENHTTP_API // Placeholder to allow DLL configs in the future
+
+namespace zen {
+
+ZENHTTP_API void zenhttp_forcelinktests();
+
+}
diff --git a/src/zenhttp/iothreadpool.cpp b/src/zenhttp/iothreadpool.cpp
new file mode 100644
index 000000000..6087e69ec
--- /dev/null
+++ b/src/zenhttp/iothreadpool.cpp
@@ -0,0 +1,49 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "iothreadpool.h"
+
+#include <zencore/except.h>
+
+#if ZEN_PLATFORM_WINDOWS
+
+namespace zen {
+
+WinIoThreadPool::WinIoThreadPool(int InThreadCount)
+{
+ // Thread pool setup
+
+ m_ThreadPool = CreateThreadpool(NULL);
+
+ SetThreadpoolThreadMinimum(m_ThreadPool, InThreadCount);
+ SetThreadpoolThreadMaximum(m_ThreadPool, InThreadCount * 2);
+
+ InitializeThreadpoolEnvironment(&m_CallbackEnvironment);
+
+ m_CleanupGroup = CreateThreadpoolCleanupGroup();
+
+ SetThreadpoolCallbackPool(&m_CallbackEnvironment, m_ThreadPool);
+
+ SetThreadpoolCallbackCleanupGroup(&m_CallbackEnvironment, m_CleanupGroup, NULL);
+}
+
+WinIoThreadPool::~WinIoThreadPool()
+{
+ CloseThreadpool(m_ThreadPool);
+}
+
+void
+WinIoThreadPool::CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode)
+{
+ ZEN_ASSERT(!m_ThreadPoolIo);
+
+ m_ThreadPoolIo = CreateThreadpoolIo(IoHandle, Callback, Context, &m_CallbackEnvironment);
+
+ if (!m_ThreadPoolIo)
+ {
+ ErrorCode = MakeErrorCodeFromLastError();
+ }
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenhttp/iothreadpool.h b/src/zenhttp/iothreadpool.h
new file mode 100644
index 000000000..8333964c3
--- /dev/null
+++ b/src/zenhttp/iothreadpool.h
@@ -0,0 +1,37 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+
+# include <system_error>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Thread pool. Implemented in terms of Windows thread pool right now, will
+// need a cross-platform implementation eventually
+//
+
+class WinIoThreadPool
+{
+public:
+ WinIoThreadPool(int InThreadCount);
+ ~WinIoThreadPool();
+
+ void CreateIocp(HANDLE IoHandle, PTP_WIN32_IO_CALLBACK Callback, void* Context, std::error_code& ErrorCode);
+ inline PTP_IO Iocp() const { return m_ThreadPoolIo; }
+
+private:
+ PTP_POOL m_ThreadPool = nullptr;
+ PTP_CLEANUP_GROUP m_CleanupGroup = nullptr;
+ PTP_IO m_ThreadPoolIo = nullptr;
+ TP_CALLBACK_ENVIRON m_CallbackEnvironment;
+};
+
+} // namespace zen
+#endif
diff --git a/src/zenhttp/websocketasio.cpp b/src/zenhttp/websocketasio.cpp
new file mode 100644
index 000000000..bbe7e1ad8
--- /dev/null
+++ b/src/zenhttp/websocketasio.cpp
@@ -0,0 +1,1613 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/websocket.h>
+
+#include <zencore/base64.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/intmath.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/sha1.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/trace.h>
+
+#include <chrono>
+#include <optional>
+#include <shared_mutex>
+#include <span>
+#include <system_error>
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <http_parser.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# include <mstcpip.h>
+#endif
+
+namespace zen::websocket {
+
+using namespace std::literals;
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWebSocket, "websocket"sv);
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogWsClient, "ws-client"sv);
+
+using Clock = std::chrono::steady_clock;
+using TimePoint = Clock::time_point;
+
+///////////////////////////////////////////////////////////////////////////////
+namespace http_header {
+ static constexpr std::string_view SecWebSocketKey = "Sec-WebSocket-Key"sv;
+ static constexpr std::string_view SecWebSocketOrigin = "Sec-WebSocket-Origin"sv;
+ static constexpr std::string_view SecWebSocketProtocol = "Sec-WebSocket-Protocol"sv;
+ static constexpr std::string_view SecWebSocketVersion = "Sec-WebSocket-Version"sv;
+ static constexpr std::string_view SecWebSocketAccept = "Sec-WebSocket-Accept"sv;
+ static constexpr std::string_view Upgrade = "Upgrade"sv;
+} // namespace http_header
+
+///////////////////////////////////////////////////////////////////////////////
+enum class ParseMessageStatus : uint32_t
+{
+ kError,
+ kContinue,
+ kDone,
+};
+
+struct ParseMessageResult
+{
+ ParseMessageStatus Status{};
+ size_t ByteCount{};
+ std::optional<std::string> Reason;
+};
+
+class MessageParser
+{
+public:
+ virtual ~MessageParser() = default;
+
+ ParseMessageResult ParseMessage(MemoryView Msg);
+ void Reset();
+
+protected:
+ MessageParser() = default;
+
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) = 0;
+ virtual void OnReset() = 0;
+
+ BinaryWriter m_Stream;
+};
+
+ParseMessageResult
+MessageParser::ParseMessage(MemoryView Msg)
+{
+ return OnParseMessage(Msg);
+}
+
+void
+MessageParser::Reset()
+{
+ OnReset();
+
+ m_Stream.Reset();
+}
+
+///////////////////////////////////////////////////////////////////////////////
+enum class HttpMessageParserType
+{
+ kRequest,
+ kResponse,
+ kBoth
+};
+
+class HttpMessageParser final : public MessageParser
+{
+public:
+ using HttpHeaders = std::unordered_map<std::string_view, std::string_view>;
+
+ HttpMessageParser(HttpMessageParserType Type) : MessageParser(), m_Type(Type) { Initialize(); }
+
+ virtual ~HttpMessageParser() = default;
+
+ int32_t StatusCode() const { return m_Parser.status_code; }
+ bool IsUpgrade() const { return m_Parser.upgrade != 0; }
+ HttpHeaders& Headers() { return m_Headers; }
+ MemoryView Body() const { return MemoryView(m_Stream.Data() + m_BodyEntry.Offset, m_BodyEntry.Size); }
+
+ std::string_view StatusText() const
+ {
+ return std::string_view(reinterpret_cast<const char*>(m_Stream.Data() + m_StatusEntry.Offset), m_StatusEntry.Size);
+ }
+
+ bool ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason);
+
+private:
+ void Initialize();
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
+ int OnMessageBegin();
+ int OnUrl(MemoryView Url);
+ int OnStatus(MemoryView Status);
+ int OnHeaderField(MemoryView HeaderField);
+ int OnHeaderValue(MemoryView HeaderValue);
+ int OnHeadersComplete();
+ int OnBody(MemoryView Body);
+ int OnMessageComplete();
+
+ struct StreamEntry
+ {
+ uint64_t Offset{};
+ uint64_t Size{};
+ };
+
+ struct HeaderStreamEntry
+ {
+ StreamEntry Field{};
+ StreamEntry Value{};
+ };
+
+ HttpMessageParserType m_Type;
+ http_parser m_Parser;
+ StreamEntry m_UrlEntry;
+ StreamEntry m_StatusEntry;
+ StreamEntry m_BodyEntry;
+ HeaderStreamEntry m_CurrentHeader;
+ std::vector<HeaderStreamEntry> m_HeaderEntries;
+ HttpHeaders m_Headers;
+ bool m_IsMsgComplete{false};
+
+ static http_parser_settings ParserSettings;
+};
+
+http_parser_settings HttpMessageParser::ParserSettings = {
+ .on_message_begin = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageBegin(); },
+
+ .on_url = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnUrl(MemoryView(Data, Size)); },
+
+ .on_status = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnStatus(MemoryView(Data, Size)); },
+
+ .on_header_field = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderField(MemoryView(Data, Size)); },
+
+ .on_header_value = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeaderValue(MemoryView(Data, Size)); },
+
+ .on_headers_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnHeadersComplete(); },
+
+ .on_body = [](http_parser* P,
+ const char* Data,
+ size_t Size) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnBody(MemoryView(Data, Size)); },
+
+ .on_message_complete = [](http_parser* P) { return reinterpret_cast<HttpMessageParser*>(P->data)->OnMessageComplete(); }};
+
+void
+HttpMessageParser::Initialize()
+{
+ http_parser_init(&m_Parser,
+ m_Type == HttpMessageParserType::kRequest ? HTTP_REQUEST
+ : m_Type == HttpMessageParserType::kResponse ? HTTP_RESPONSE
+ : HTTP_BOTH);
+ m_Parser.data = this;
+
+ m_UrlEntry = {};
+ m_StatusEntry = {};
+ m_CurrentHeader = {};
+ m_BodyEntry = {};
+
+ m_IsMsgComplete = false;
+
+ m_HeaderEntries.clear();
+}
+
+ParseMessageResult
+HttpMessageParser::OnParseMessage(MemoryView Msg)
+{
+ const size_t ByteCount = http_parser_execute(&m_Parser, &ParserSettings, reinterpret_cast<const char*>(Msg.GetData()), Msg.GetSize());
+
+ auto Status = m_IsMsgComplete ? ParseMessageStatus::kDone : ParseMessageStatus::kContinue;
+
+ if (m_Parser.http_errno != 0)
+ {
+ Status = ParseMessageStatus::kError;
+ }
+
+ return {.Status = Status, .ByteCount = uint64_t(ByteCount)};
+}
+
+void
+HttpMessageParser::OnReset()
+{
+ Initialize();
+}
+
+int
+HttpMessageParser::OnMessageBegin()
+{
+ ZEN_ASSERT(m_IsMsgComplete == false);
+ ZEN_ASSERT(m_HeaderEntries.empty());
+ ZEN_ASSERT(m_Headers.empty());
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnStatus(MemoryView Status)
+{
+ m_StatusEntry = {m_Stream.CurrentOffset(), Status.GetSize()};
+
+ m_Stream.Write(Status);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnUrl(MemoryView Url)
+{
+ m_UrlEntry = {m_Stream.CurrentOffset(), Url.GetSize()};
+
+ m_Stream.Write(Url);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeaderField(MemoryView HeaderField)
+{
+ if (m_CurrentHeader.Value.Size > 0)
+ {
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
+ }
+
+ if (m_CurrentHeader.Field.Size == 0)
+ {
+ m_CurrentHeader.Field.Offset = m_Stream.CurrentOffset();
+ }
+
+ m_CurrentHeader.Field.Size += HeaderField.GetSize();
+
+ m_Stream.Write(HeaderField);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeaderValue(MemoryView HeaderValue)
+{
+ if (m_CurrentHeader.Value.Size == 0)
+ {
+ m_CurrentHeader.Value.Offset = m_Stream.CurrentOffset();
+ }
+
+ m_CurrentHeader.Value.Size += HeaderValue.GetSize();
+
+ m_Stream.Write(HeaderValue);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnHeadersComplete()
+{
+ if (m_CurrentHeader.Value.Size > 0)
+ {
+ m_HeaderEntries.push_back(m_CurrentHeader);
+ m_CurrentHeader = {};
+ }
+
+ m_Headers.clear();
+ m_Headers.reserve(m_HeaderEntries.size());
+
+ const char* StreamData = reinterpret_cast<const char*>(m_Stream.Data());
+
+ for (const auto& Entry : m_HeaderEntries)
+ {
+ auto Field = std::string_view(StreamData + Entry.Field.Offset, Entry.Field.Size);
+ auto Value = std::string_view(StreamData + Entry.Value.Offset, Entry.Value.Size);
+
+ m_Headers.try_emplace(std::move(Field), std::move(Value));
+ }
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnBody(MemoryView Body)
+{
+ m_BodyEntry = {m_Stream.CurrentOffset(), Body.GetSize()};
+
+ m_Stream.Write(Body);
+
+ return 0;
+}
+
+int
+HttpMessageParser::OnMessageComplete()
+{
+ m_IsMsgComplete = true;
+
+ return 0;
+}
+
+bool
+HttpMessageParser::ValidateWebSocketHandshake(std::string& OutAcceptHash, std::string& OutReason)
+{
+ static constexpr std::string_view WebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"sv;
+
+ OutAcceptHash = std::string();
+
+ if (m_Headers.contains(http_header::SecWebSocketKey) == false)
+ {
+ OutReason = "Missing header Sec-WebSocket-Key";
+ return false;
+ }
+
+ if (m_Headers.contains(http_header::Upgrade) == false)
+ {
+ OutReason = "Missing header Upgrade";
+ return false;
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << m_Headers[http_header::SecWebSocketKey] << WebSocketGuid;
+
+ SHA1Stream HashStream;
+ HashStream.Append(Sb.Data(), Sb.Size());
+
+ SHA1 Hash = HashStream.GetHash();
+
+ OutAcceptHash.resize(Base64::GetEncodedDataSize(sizeof(SHA1::Hash)));
+ Base64::Encode(Hash.Hash, sizeof(SHA1::Hash), OutAcceptHash.data());
+
+ return true;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WebSocketMessageParser final : public MessageParser
+{
+public:
+ WebSocketMessageParser() : MessageParser() {}
+
+ WebSocketMessage ConsumeMessage();
+
+private:
+ virtual ParseMessageResult OnParseMessage(MemoryView Msg) override;
+ virtual void OnReset() override;
+
+ WebSocketMessage m_Message;
+};
+
+ParseMessageResult
+WebSocketMessageParser::OnParseMessage(MemoryView Msg)
+{
+ ZEN_TRACE_CPU("WebSocketMessageParser::OnParseMessage");
+
+ const uint64_t PrevOffset = m_Stream.CurrentOffset();
+
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
+ {
+ const uint64_t RemaingHeaderSize = WebSocketMessage::HeaderSize - m_Stream.CurrentOffset();
+
+ m_Stream.Write(Msg.Left(RemaingHeaderSize));
+ Msg += RemaingHeaderSize;
+
+ if (m_Stream.CurrentOffset() < WebSocketMessage::HeaderSize)
+ {
+ return {.Status = ParseMessageStatus::kContinue, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
+
+ const bool IsValidHeader = m_Message.TryLoadHeader(m_Stream.GetView());
+
+ if (IsValidHeader == false)
+ {
+ OnReset();
+
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message header")};
+ }
+
+ if (m_Message.MessageSize() == 0)
+ {
+ return {.Status = ParseMessageStatus::kDone, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+ }
+ }
+
+ ZEN_ASSERT(m_Stream.CurrentOffset() >= WebSocketMessage::HeaderSize);
+
+ if (Msg.IsEmpty() == false)
+ {
+ const uint64_t RemaingMessageSize = (WebSocketMessage::HeaderSize + m_Message.MessageSize()) - m_Stream.CurrentOffset();
+ m_Stream.Write(Msg.Left(RemaingMessageSize));
+ }
+
+ auto Status = ParseMessageStatus::kContinue;
+
+ if (m_Stream.CurrentOffset() == WebSocketMessage::HeaderSize + m_Message.MessageSize())
+ {
+ Status = ParseMessageStatus::kDone;
+
+ BinaryReader Reader(m_Stream.GetView().RightChop(WebSocketMessage::HeaderSize));
+
+ CbPackage Pkg;
+ if (Pkg.TryLoad(Reader) == false)
+ {
+ return {.Status = ParseMessageStatus::kError,
+ .ByteCount = m_Stream.CurrentOffset() - PrevOffset,
+ .Reason = std::string("Invalid websocket message")};
+ }
+
+ m_Message.SetBody(std::move(Pkg));
+ }
+
+ return {.Status = Status, .ByteCount = m_Stream.CurrentOffset() - PrevOffset};
+}
+
+void
+WebSocketMessageParser::OnReset()
+{
+ m_Message = WebSocketMessage();
+}
+
+WebSocketMessage
+WebSocketMessageParser::ConsumeMessage()
+{
+ WebSocketMessage Msg = std::move(m_Message);
+ m_Message = WebSocketMessage();
+
+ return Msg;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsConnection : public std::enable_shared_from_this<WsConnection>
+{
+public:
+ WsConnection(WebSocketId Id, std::unique_ptr<asio::ip::tcp::socket> Socket)
+ : m_Id(Id)
+ , m_Socket(std::move(Socket))
+ , m_StartTime(Clock::now())
+ , m_State()
+ {
+ }
+
+ ~WsConnection() = default;
+
+ std::shared_ptr<WsConnection> AsShared() { return shared_from_this(); }
+
+ WebSocketId Id() const { return m_Id; }
+ asio::ip::tcp::socket& Socket() { return *m_Socket; }
+ TimePoint StartTime() const { return m_StartTime; }
+ WebSocketState State() const { return static_cast<WebSocketState>(m_State.load(std::memory_order_relaxed)); }
+ std::string RemoteAddr() const { return m_Socket->remote_endpoint().address().to_string(); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ WebSocketState Close();
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ std::mutex& WriteMutex() { return m_WriteMutex; }
+
+private:
+ WebSocketId m_Id;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ TimePoint m_StartTime;
+ std::atomic_uint32_t m_State;
+ std::unique_ptr<MessageParser> m_MsgParser;
+ asio::streambuf m_ReadBuffer;
+ std::mutex m_WriteMutex;
+};
+
+WebSocketState
+WsConnection::Close()
+{
+ const auto PrevState = SetState(WebSocketState::kDisconnected);
+
+ if (PrevState != WebSocketState::kDisconnected && m_Socket->is_open())
+ {
+ m_Socket->close();
+ }
+
+ return PrevState;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsThreadPool
+{
+public:
+ WsThreadPool(asio::io_service& IoSvc) : m_IoSvc(IoSvc) {}
+ void Start(uint32_t ThreadCount);
+ void Stop();
+
+private:
+ asio::io_service& m_IoSvc;
+ std::vector<std::thread> m_Threads;
+ std::atomic_bool m_Running{false};
+};
+
+void
+WsThreadPool::Start(uint32_t ThreadCount)
+{
+ ZEN_ASSERT(m_Threads.empty());
+
+ ZEN_LOG_DEBUG(LogWebSocket, "starting '{}' websocket I/O thread(s)", ThreadCount);
+
+ m_Running = true;
+
+ for (uint32_t Idx = 0; Idx < ThreadCount; Idx++)
+ {
+ m_Threads.emplace_back([this, ThreadId = Idx + 1] {
+ for (;;)
+ {
+ if (m_Running == false)
+ {
+ break;
+ }
+
+ try
+ {
+ m_IoSvc.run();
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_ERROR(LogWebSocket, "process websocket I/O FAILED, reason '{}'", Err.what());
+ }
+ }
+
+ ZEN_LOG_TRACE(LogWebSocket, "websocket I/O thread '{}' exiting", ThreadId);
+ });
+ }
+}
+
+void
+WsThreadPool::Stop()
+{
+ if (m_Running)
+ {
+ m_Running = false;
+
+ for (std::thread& Thread : m_Threads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+
+ m_Threads.clear();
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsServer final : public WebSocketServer
+{
+public:
+ WsServer(const WebSocketServerOptions& Options) : m_Options(Options) {}
+ virtual ~WsServer() { Shutdown(); }
+
+ virtual bool Run() override;
+ virtual void Shutdown() override;
+
+ virtual void RegisterService(WebSocketService& Service) override;
+ virtual void RegisterNotificationHandler(std::string_view Key, WebSocketService& Service) override;
+ virtual void RegisterRequestHandler(std::string_view Key, WebSocketService& Service) override;
+
+ virtual void SendNotification(WebSocketMessage&& Notification) override;
+ virtual void SendResponse(WebSocketMessage&& Response) override;
+
+private:
+ friend class WsConnection;
+
+ void AcceptConnection();
+ void CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec);
+
+ void ReadMessage(std::shared_ptr<WsConnection> Connection);
+ void RouteMessage(WebSocketMessage&& Msg);
+ void SendMessage(WebSocketMessage&& Msg);
+
+ struct IdHasher
+ {
+ size_t operator()(WebSocketId Id) const { return size_t(Id.Value()); }
+ };
+
+ using ConnectionMap = std::unordered_map<WebSocketId, std::shared_ptr<WsConnection>, IdHasher>;
+ using RequestHandlerMap = std::unordered_map<std::string_view, WebSocketService*>;
+ using NotificationHandlerMap = std::unordered_map<std::string_view, std::vector<WebSocketService*>>;
+
+ WebSocketServerOptions m_Options;
+ asio::io_service m_IoSvc;
+ std::unique_ptr<asio::ip::tcp::acceptor> m_Acceptor;
+ std::unique_ptr<WsThreadPool> m_ThreadPool;
+ ConnectionMap m_Connections;
+ std::shared_mutex m_ConnMutex;
+ std::vector<WebSocketService*> m_Services;
+ RequestHandlerMap m_RequestHandlers;
+ NotificationHandlerMap m_NotificationHandlers;
+ std::atomic_bool m_Running{};
+};
+
+void
+WsServer::RegisterService(WebSocketService& Service)
+{
+ m_Services.push_back(&Service);
+
+ Service.Configure(*this);
+}
+
+bool
+WsServer::Run()
+{
+ static constexpr size_t ReceiveBufferSize = 256 << 10;
+ static constexpr size_t SendBufferSize = 256 << 10;
+
+ m_Acceptor = std::make_unique<asio::ip::tcp::acceptor>(m_IoSvc, asio::ip::tcp::v6());
+
+ m_Acceptor->set_option(asio::ip::v6_only(false));
+ m_Acceptor->set_option(asio::socket_base::reuse_address(true));
+ m_Acceptor->set_option(asio::ip::tcp::no_delay(true));
+ m_Acceptor->set_option(asio::socket_base::receive_buffer_size(ReceiveBufferSize));
+ m_Acceptor->set_option(asio::socket_base::send_buffer_size(SendBufferSize));
+
+#if ZEN_PLATFORM_WINDOWS
+ // On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
+ // This must be used by both the client and server side, and is only effective in the absence of
+ // Windows Filtering Platform (WFP) callouts which can be installed by security software.
+ // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
+ SOCKET NativeSocket = m_Acceptor->native_handle();
+ int LoopbackOptionValue = 1;
+ DWORD OptionNumberOfBytesReturned = 0;
+ WSAIoctl(NativeSocket,
+ SIO_LOOPBACK_FAST_PATH,
+ &LoopbackOptionValue,
+ sizeof(LoopbackOptionValue),
+ NULL,
+ 0,
+ &OptionNumberOfBytesReturned,
+ 0,
+ 0);
+#endif
+
+ asio::error_code Ec;
+ m_Acceptor->bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), m_Options.Port), Ec);
+
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(LogWebSocket, "failed to bind websocket endpoint, error code '{}'", Ec.value());
+
+ return false;
+ }
+
+ m_Acceptor->listen();
+ m_Running = true;
+
+ ZEN_LOG_INFO(LogWebSocket, "web socket server running on port '{}'", m_Options.Port);
+
+ AcceptConnection();
+
+ m_ThreadPool = std::make_unique<WsThreadPool>(m_IoSvc);
+ m_ThreadPool->Start(m_Options.ThreadCount);
+
+ return true;
+}
+
+void
+WsServer::Shutdown()
+{
+ if (m_Running)
+ {
+ ZEN_LOG_INFO(LogWebSocket, "websocket server shutting down");
+
+ m_Running = false;
+
+ m_Acceptor->close();
+ m_Acceptor.reset();
+ m_IoSvc.stop();
+
+ m_ThreadPool->Stop();
+ }
+}
+
+void
+WsServer::RegisterNotificationHandler(std::string_view Key, WebSocketService& Service)
+{
+ auto Result = m_NotificationHandlers.try_emplace(Key, std::vector<WebSocketService*>());
+ Result.first->second.push_back(&Service);
+}
+
+void
+WsServer::RegisterRequestHandler(std::string_view Key, WebSocketService& Service)
+{
+ m_RequestHandlers[Key] = &Service;
+}
+
+void
+WsServer::SendNotification(WebSocketMessage&& Notification)
+{
+ ZEN_ASSERT(Notification.MessageType() == WebSocketMessageType::kNotification);
+
+ SendMessage(std::move(Notification));
+}
+void
+WsServer::SendResponse(WebSocketMessage&& Response)
+{
+ ZEN_ASSERT(Response.MessageType() == WebSocketMessageType::kResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamResponse ||
+ Response.MessageType() == WebSocketMessageType::kStreamCompleteResponse);
+
+ ZEN_ASSERT(Response.CorrelationId() != 0);
+
+ SendMessage(std::move(Response));
+}
+
+void
+WsServer::AcceptConnection()
+{
+ auto Socket = std::make_unique<asio::ip::tcp::socket>(m_IoSvc);
+ asio::ip::tcp::socket& SocketRef = *Socket.get();
+
+ m_Acceptor->async_accept(SocketRef, [this, ConnectedSocket = std::move(Socket)](const asio::error_code& Ec) mutable {
+ if (m_Running)
+ {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "accept connection FAILED, reason '{}'", Ec.message());
+ }
+ else
+ {
+ auto Connection = std::make_shared<WsConnection>(WebSocketId::New(), std::move(ConnectedSocket));
+
+ ZEN_LOG_DEBUG(LogWebSocket, "accept connection '#{} {}' OK", Connection->Id().Value(), Connection->RemoteAddr());
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ m_Connections[Connection->Id()] = Connection;
+ }
+
+ Connection->SetParser(std::make_unique<HttpMessageParser>(HttpMessageParserType::kRequest));
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ }
+
+ AcceptConnection();
+ }
+ });
+}
+
+void
+WsServer::CloseConnection(std::shared_ptr<WsConnection> Connection, const std::error_code& Ec)
+{
+ if (const auto State = Connection->Close(); State != WebSocketState::kDisconnected)
+ {
+ if (Ec)
+ {
+ ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed, reason '{} ({})'", Connection->Id().Value(), Ec.message(), Ec.value());
+ }
+ else
+ {
+ ZEN_LOG_INFO(LogWebSocket, "connection '{}' closed", Connection->Id().Value());
+ }
+ }
+
+ const WebSocketId Id = Connection->Id();
+
+ {
+ std::unique_lock _(m_ConnMutex);
+ if (m_Connections.contains(Id))
+ {
+ m_Connections.erase(Id);
+ }
+ }
+}
+
+void
+WsServer::ReadMessage(std::shared_ptr<WsConnection> Connection)
+{
+ Connection->ReadBuffer().prepare(64 << 10);
+
+ asio::async_read(
+ Connection->Socket(),
+ Connection->ReadBuffer(),
+ asio::transfer_at_least(1),
+ [this, Connection](const asio::error_code& ReadEc, std::size_t) mutable {
+ if (ReadEc)
+ {
+ return CloseConnection(Connection, ReadEc);
+ }
+
+ switch (Connection->State())
+ {
+ case WebSocketState::kHandshaking:
+ {
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Connection->Parser());
+ asio::const_buffer Buffer = Connection->ReadBuffer().data();
+
+ ParseMessageResult Result = Parser.ParseMessage(MemoryView(Buffer.data(), Buffer.size()));
+
+ Connection->ReadBuffer().consume(Result.ByteCount);
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return ReadMessage(Connection);
+ }
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'HTTP parse error'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ if (Parser.IsUpgrade() == false)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '#{} {}' FAILED, reason 'invalid HTTP upgrade request'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 426 Upgrade Required\n\r\n\r"sv;
+
+ return async_write(Connection->Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ Connection->Parser()->Reset();
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ });
+ }
+
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+
+ std::string AcceptHash;
+ std::string Reason;
+ const bool ValidHandshake = Parser.ValidateWebSocketHandshake(AcceptHash, Reason);
+
+ if (ValidHandshake == false)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ Reason);
+
+ constexpr auto UpgradeRequiredResponse = "HTTP/1.1 400 Bad Request\n\r\n\r"sv;
+
+ return async_write(Connection->Socket(),
+ asio::buffer(UpgradeRequiredResponse),
+ [this, &Connection](const asio::error_code& WriteEc, std::size_t) {
+ if (WriteEc)
+ {
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ Connection->Parser()->Reset();
+ Connection->SetState(WebSocketState::kHandshaking);
+
+ ReadMessage(Connection);
+ });
+ }
+
+ ExtendableStringBuilder<128> Sb;
+
+ Sb << "HTTP/1.1 101 Switching Protocols\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: Upgrade\r\n"sv;
+
+ // TODO: Verify protocol
+ if (Parser.Headers().contains(http_header::SecWebSocketProtocol))
+ {
+ Sb << http_header::SecWebSocketProtocol << ": " << Parser.Headers()[http_header::SecWebSocketProtocol]
+ << "\r\n";
+ }
+
+ Sb << http_header::SecWebSocketAccept << ": " << AcceptHash << "\r\n";
+ Sb << "\r\n"sv;
+
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "accepting handshake from connection '#{} {}'",
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ std::string Response = Sb.ToString();
+ Buffer = asio::buffer(Response);
+
+ async_write(Connection->Socket(),
+ Buffer,
+ [this, Connection, _ = std::move(Response)](const asio::error_code& WriteEc, std::size_t ByteCount) {
+ if (WriteEc)
+ {
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake with connection '{}' FAILED, reason '{}'",
+ Connection->Id().Value(),
+ WriteEc.message());
+
+ return CloseConnection(Connection, WriteEc);
+ }
+
+ ZEN_LOG_DEBUG(LogWebSocket,
+ "handshake ({}B) with connection '#{} {}' OK",
+ ByteCount,
+ Connection->Id().Value(),
+ Connection->RemoteAddr());
+
+ Connection->SetParser(std::make_unique<WebSocketMessageParser>());
+ Connection->SetState(WebSocketState::kConnected);
+
+ ReadMessage(Connection);
+ });
+ }
+ break;
+
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Connection->Parser());
+
+ uint64_t RemainingBytes = Connection->ReadBuffer().size();
+
+ while (RemainingBytes > 0)
+ {
+ MemoryView MessageData = MemoryView(Connection->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Connection->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Connection->ReadBuffer().size();
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+
+ return CloseConnection(Connection, std::error_code());
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
+ }
+
+ WebSocketMessage Message = Parser.ConsumeMessage();
+ Parser.Reset();
+
+ Message.SetSocketId(Connection->Id());
+
+ RouteMessage(std::move(Message));
+ }
+
+ ReadMessage(Connection);
+ }
+ break;
+
+ default:
+ break;
+ };
+ });
+}
+
+void
+WsServer::RouteMessage(WebSocketMessage&& RoutedMessage)
+{
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kRequest:
+ case WebSocketMessageType::kStreamRequest:
+ {
+ CbObjectView Request = RoutedMessage.Body().GetObject();
+ std::string_view Method = Request["Method"].AsString();
+ bool Handled = false;
+ bool Error = false;
+ std::exception Exception;
+
+ if (auto It = m_RequestHandlers.find(Method); It != m_RequestHandlers.end())
+ {
+ WebSocketService* Service = It->second;
+ ZEN_ASSERT(Service);
+
+ try
+ {
+ Handled = Service->HandleRequest(std::move(RoutedMessage));
+ }
+ catch (std::exception& Err)
+ {
+ Exception = std::move(Err);
+ Error = true;
+ }
+ }
+
+ if (Error || Handled == false)
+ {
+ std::string ErrorText = Error ? Exception.what() : fmt::format("'{}' Not Found", Method);
+
+ ZEN_LOG_WARN(LogWebSocket, "route request message FAILED, reason '{}'", ErrorText);
+
+ CbObjectWriter Response;
+ Response << "Error"sv << ErrorText;
+
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RoutedMessage.CorrelationId());
+ ResponseMsg.SetSocketId(RoutedMessage.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SendResponse(std::move(ResponseMsg));
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ CbObjectView Notification = RoutedMessage.Body().GetObject();
+ std::string_view Message = Notification["Message"].AsString();
+
+ if (auto It = m_NotificationHandlers.find(Message); It != m_NotificationHandlers.end())
+ {
+ std::vector<WebSocketService*>& Handlers = It->second;
+
+ for (WebSocketService* Handler : Handlers)
+ {
+ Handler->HandleNotification(RoutedMessage);
+ }
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWebSocket, "route notification message FAILED, unknown notification '{}'", Message);
+ }
+ }
+ break;
+
+ default:
+ break;
+ };
+}
+
+void
+WsServer::SendMessage(WebSocketMessage&& Msg)
+{
+ std::shared_ptr<WsConnection> Connection;
+
+ {
+ std::unique_lock _(m_ConnMutex);
+
+ if (auto It = m_Connections.find(Msg.SocketId()); It != m_Connections.end())
+ {
+ Connection = It->second;
+ }
+ }
+
+ if (Connection.get() == nullptr)
+ {
+ ZEN_LOG_WARN(LogWebSocket, "send message FAILED, reason 'unknown socket ID ({})'", Msg.SocketId().Value());
+ return;
+ }
+
+ if (Connection.get() != nullptr)
+ {
+ BinaryWriter Writer;
+ Msg.Save(Writer);
+
+ ZEN_LOG_TRACE(LogWebSocket,
+ "sending '{}' message, receiver '{}', size '{}', ID '{}', total size {}",
+ ToString(Msg.MessageType()),
+ Connection->Id().Value(),
+ Msg.MessageSize(),
+ Msg.CorrelationId(),
+ NiceBytes(Writer.Size()));
+
+ {
+ ZEN_TRACE_CPU("WS::SendMessage");
+ std::unique_lock _(Connection->WriteMutex());
+ ZEN_TRACE_CPU("WS::WriteSocketData");
+ asio::write(Connection->Socket(), asio::buffer(Writer.Data(), Writer.Size()), asio::transfer_exactly(Writer.Size()));
+ }
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////
+class WsClient final : public WebSocketClient, public std::enable_shared_from_this<WsClient>
+{
+public:
+ WsClient(asio::io_context& IoCtx) : m_IoCtx(IoCtx), m_Id(WebSocketId::New()) {}
+
+ virtual ~WsClient() { Disconnect(); }
+
+ std::shared_ptr<WsClient> AsShared() { return shared_from_this(); }
+
+ virtual std::future<bool> Connect(const WebSocketConnectInfo& Info) override;
+ virtual void Disconnect() override;
+ virtual bool IsConnected() const override { return false; }
+ virtual WebSocketState State() const override { return static_cast<WebSocketState>(m_State.load()); }
+
+ virtual std::future<WebSocketMessage> SendRequest(WebSocketMessage&& Request) override;
+ virtual void OnNotification(NotificationCallback&& Cb) override;
+ virtual void OnEvent(WebSocketEvent Evt, EventCallback&& Cb) override;
+
+private:
+ WebSocketState SetState(WebSocketState NewState) { return static_cast<WebSocketState>(m_State.exchange(uint32_t(NewState))); }
+ MessageParser* Parser() { return m_MsgParser.get(); }
+ void SetParser(std::unique_ptr<MessageParser>&& Parser) { m_MsgParser = std::move(Parser); }
+ asio::streambuf& ReadBuffer() { return m_ReadBuffer; }
+ void TriggerEvent(WebSocketEvent Evt);
+ void ReadMessage();
+ void RouteMessage(WebSocketMessage&& RoutedMessage);
+
+ using PendingRequestMap = std::unordered_map<uint32_t, std::promise<WebSocketMessage>>;
+
+ asio::io_context& m_IoCtx;
+ WebSocketId m_Id;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<MessageParser> m_MsgParser;
+ asio::streambuf m_ReadBuffer;
+ EventCallback m_EventCallbacks[3];
+ NotificationCallback m_NotificationCallback;
+ PendingRequestMap m_PendingRequests;
+ std::mutex m_RequestMutex;
+ std::promise<bool> m_ConnectPromise;
+ std::atomic_uint32_t m_State;
+ std::string m_Host;
+ int16_t m_Port{};
+};
+
+std::future<bool>
+WsClient::Connect(const WebSocketConnectInfo& Info)
+{
+ if (State() == WebSocketState::kHandshaking || State() == WebSocketState::kConnected)
+ {
+ return m_ConnectPromise.get_future();
+ }
+
+ SetState(WebSocketState::kHandshaking);
+
+ try
+ {
+ asio::ip::tcp::endpoint Endpoint(asio::ip::address::from_string(Info.Host), Info.Port);
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoCtx, Endpoint.protocol());
+
+ m_Socket->connect(Endpoint);
+
+ m_Host = m_Socket->remote_endpoint().address().to_string();
+ m_Port = Info.Port;
+
+ ZEN_LOG_INFO(LogWsClient, "connected to websocket server '{}:{}'", m_Host, m_Port);
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_LOG_WARN(LogWsClient, "connect to websocket server '{}:{}' FAILED, reason '{}'", Info.Host, Info.Port, Err.what());
+
+ SetState(WebSocketState::kError);
+ m_Socket.reset();
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+
+ m_ConnectPromise.set_value(false);
+
+ return m_ConnectPromise.get_future();
+ }
+
+ ExtendableStringBuilder<128> Sb;
+ Sb << "GET " << Info.Endpoint << " HTTP/1.1\r\n"sv;
+ Sb << "Host: " << Info.Host << "\r\n"sv;
+ Sb << "Upgrade: websocket\r\n"sv;
+ Sb << "Connection: upgrade\r\n"sv;
+ Sb << "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"sv;
+
+ if (Info.Protocols.empty() == false)
+ {
+ Sb << "Sec-WebSocket-Protocol: "sv;
+ for (size_t Idx = 0; const auto& Protocol : Info.Protocols)
+ {
+ if (Idx++)
+ {
+ Sb << ", ";
+ }
+ Sb << Protocol;
+ }
+ }
+
+ Sb << "Sec-WebSocket-Version: "sv << Info.Version << "\r\n"sv;
+ Sb << "\r\n";
+
+ std::string HandshakeRequest = Sb.ToString();
+ asio::const_buffer Buffer = asio::buffer(HandshakeRequest);
+
+ ZEN_LOG_DEBUG(LogWsClient, "handshaking with '{}:{}'", m_Host, m_Port);
+
+ m_MsgParser = std::make_unique<HttpMessageParser>(HttpMessageParserType::kResponse);
+ m_MsgParser->Reset();
+
+ async_write(*m_Socket, Buffer, [Self = AsShared(), _ = std::move(HandshakeRequest)](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_ERROR(LogWsClient, "write data FAILED, reason '{}'", Ec.message());
+
+ Self->Disconnect();
+ }
+ else
+ {
+ Self->ReadMessage();
+ }
+ });
+
+ return m_ConnectPromise.get_future();
+}
+
+void
+WsClient::Disconnect()
+{
+ if (auto PrevState = SetState(WebSocketState::kDisconnected); PrevState != WebSocketState::kDisconnected)
+ {
+ ZEN_LOG_INFO(LogWsClient, "closing connection to '{}:{}'", m_Host, m_Port);
+
+ if (m_Socket && m_Socket->is_open())
+ {
+ m_Socket->close();
+ m_Socket.reset();
+ }
+
+ TriggerEvent(WebSocketEvent::kDisconnected);
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ for (auto& Kv : m_PendingRequests)
+ {
+ Kv.second.set_value(WebSocketMessage());
+ }
+
+ m_PendingRequests.clear();
+ }
+ }
+}
+
+std::future<WebSocketMessage>
+WsClient::SendRequest(WebSocketMessage&& Request)
+{
+ ZEN_ASSERT(Request.MessageType() == WebSocketMessageType::kRequest);
+
+ BinaryWriter Writer;
+ Request.Save(Writer);
+
+ std::future<WebSocketMessage> FutureResponse;
+
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ auto Result = m_PendingRequests.try_emplace(Request.CorrelationId(), std::promise<WebSocketMessage>());
+ ZEN_ASSERT(Result.second);
+
+ auto It = Result.first;
+ FutureResponse = It->second.get_future();
+ }
+
+ IoBuffer Buffer = IoBufferBuilder::MakeCloneFromMemory(Writer.Data(), Writer.Size());
+
+ async_write(*m_Socket, asio::buffer(Buffer.Data(), Buffer.Size()), [Self = AsShared()](const std::error_code& Ec, size_t) {
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWsClient, "send request message FAILED, reason '{}'", Ec.message());
+
+ Self->Disconnect();
+ }
+ });
+
+ return FutureResponse;
+}
+
+void
+WsClient::OnNotification(NotificationCallback&& Cb)
+{
+ m_NotificationCallback = std::move(Cb);
+}
+
+void
+WsClient::OnEvent(WebSocketEvent Evt, WebSocketClient::EventCallback&& Cb)
+{
+ m_EventCallbacks[static_cast<uint32_t>(Evt)] = std::move(Cb);
+}
+
+void
+WsClient::TriggerEvent(WebSocketEvent Evt)
+{
+ const uint32_t Index = static_cast<uint32_t>(Evt);
+
+ if (m_EventCallbacks[Index])
+ {
+ m_EventCallbacks[Index]();
+ }
+}
+
+void
+WsClient::ReadMessage()
+{
+ m_ReadBuffer.prepare(64 << 10);
+
+ async_read(*m_Socket,
+ m_ReadBuffer,
+ asio::transfer_at_least(1),
+ [Self = AsShared()](const asio::error_code& Ec, std::size_t ByteCount) mutable {
+ const WebSocketState State = Self->State();
+
+ if (State == WebSocketState::kDisconnected)
+ {
+ return;
+ }
+
+ if (Ec)
+ {
+ ZEN_LOG_WARN(LogWsClient, "read message FAILED, reason '{}'", Ec.message());
+
+ return Self->Disconnect();
+ }
+
+ switch (State)
+ {
+ case WebSocketState::kHandshaking:
+ {
+ HttpMessageParser& Parser = *reinterpret_cast<HttpMessageParser*>(Self->Parser());
+
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), ByteCount);
+
+ ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Self->ReadBuffer().consume(size_t(Result.ByteCount));
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "handshake FAILED, status code '{}'", Parser.StatusCode());
+
+ Self->m_ConnectPromise.set_value(false);
+
+ return Self->Disconnect();
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ return Self->ReadMessage();
+ }
+
+ ZEN_ASSERT(Result.Status == ParseMessageStatus::kDone);
+
+ if (Parser.StatusCode() != 101)
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "handshake FAILED, status '{}', status code '{}'",
+ Parser.StatusText(),
+ Parser.StatusCode());
+
+ Self->m_ConnectPromise.set_value(false);
+
+ return Self->Disconnect();
+ }
+
+ ZEN_LOG_INFO(LogWsClient, "handshake OK, status '{}'", Parser.StatusText());
+
+ Self->SetParser(std::make_unique<WebSocketMessageParser>());
+ Self->SetState(WebSocketState::kConnected);
+ Self->ReadMessage();
+ Self->TriggerEvent(WebSocketEvent::kConnected);
+
+ Self->m_ConnectPromise.set_value(true);
+ }
+ break;
+
+ case WebSocketState::kConnected:
+ {
+ WebSocketMessageParser& Parser = *reinterpret_cast<WebSocketMessageParser*>(Self->Parser());
+
+ uint64_t RemainingBytes = Self->ReadBuffer().size();
+
+ while (RemainingBytes > 0)
+ {
+ MemoryView MessageData = MemoryView(Self->ReadBuffer().data().data(), RemainingBytes);
+ const ParseMessageResult Result = Parser.ParseMessage(MessageData);
+
+ Self->ReadBuffer().consume(Result.ByteCount);
+ RemainingBytes = Self->ReadBuffer().size();
+
+ if (Result.Status == ParseMessageStatus::kError)
+ {
+ ZEN_LOG_WARN(LogWsClient, "parse websocket message FAILED, reason '{}'", Result.Reason.value());
+
+ Parser.Reset();
+ continue;
+ }
+
+ if (Result.Status == ParseMessageStatus::kContinue)
+ {
+ ZEN_ASSERT(RemainingBytes == 0);
+ continue;
+ }
+
+ WebSocketMessage Message = Parser.ConsumeMessage();
+ Parser.Reset();
+
+ Self->RouteMessage(std::move(Message));
+ }
+
+ Self->ReadMessage();
+ }
+ break;
+
+ default:
+ break;
+ }
+ });
+}
+
+void
+WsClient::RouteMessage(WebSocketMessage&& RoutedMessage)
+{
+ switch (RoutedMessage.MessageType())
+ {
+ case WebSocketMessageType::kResponse:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (auto It = m_PendingRequests.find(RoutedMessage.CorrelationId()); It != m_PendingRequests.end())
+ {
+ It->second.set_value(std::move(RoutedMessage));
+ m_PendingRequests.erase(It);
+ }
+ else
+ {
+ ZEN_LOG_WARN(LogWsClient,
+ "route request message FAILED, reason 'unknown correlation ID ({})'",
+ RoutedMessage.CorrelationId());
+ }
+ }
+ break;
+
+ case WebSocketMessageType::kNotification:
+ {
+ std::unique_lock _(m_RequestMutex);
+
+ if (m_NotificationCallback)
+ {
+ m_NotificationCallback(std::move(RoutedMessage));
+ }
+ }
+ break;
+
+ default:
+ ZEN_LOG_WARN(LogWsClient, "route message FAILED, reason 'invalid message type ({})'", uint8_t(RoutedMessage.MessageType()));
+ break;
+ };
+}
+
+} // namespace zen::websocket
+
+namespace zen {
+
+std::atomic_uint32_t WebSocketId::NextId{1};
+
+bool
+WebSocketMessage::Header::IsValid() const
+{
+ return Magic == ExpectedMagic && StatusCode > 0 && uint8_t(MessageType) > uint8_t(WebSocketMessageType::kInvalid) &&
+ uint8_t(MessageType) < uint8_t(WebSocketMessageType::kCount);
+}
+
+std::atomic_uint32_t WebSocketMessage::NextCorrelationId{1};
+
+void
+WebSocketMessage::SetMessageType(WebSocketMessageType MessageType)
+{
+ m_Header.MessageType = MessageType;
+}
+
+void
+WebSocketMessage::SetBody(CbPackage&& Body)
+{
+ m_Body = std::move(Body);
+}
+void
+WebSocketMessage::SetBody(CbObject&& Body)
+{
+ CbPackage Pkg;
+ Pkg.SetObject(Body);
+
+ SetBody(std::move(Pkg));
+}
+
+void
+WebSocketMessage::Save(BinaryWriter& Writer)
+{
+ Writer.Write(&m_Header, HeaderSize);
+
+ if (m_Body.has_value())
+ {
+ const CbObject& Obj = m_Body.value().GetObject();
+ MemoryView View = Obj.GetBuffer().GetView();
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(View, CbValidateMode::All);
+ ZEN_ASSERT(ValidationResult == CbValidateError::None);
+
+ m_Body.value().Save(Writer);
+ }
+
+ if (m_Header.CorrelationId == 0 && MessageType() == WebSocketMessageType::kRequest)
+ {
+ m_Header.CorrelationId = NextCorrelationId.fetch_add(1);
+ }
+
+ m_Header.MessageSize = Writer.Size() - HeaderSize;
+
+ Writer.GetMutableView().CopyFrom(MakeMemoryView(&m_Header, HeaderSize));
+}
+
+bool
+WebSocketMessage::TryLoadHeader(MemoryView Memory)
+{
+ if (Memory.GetSize() < HeaderSize)
+ {
+ return false;
+ }
+
+ MutableMemoryView HeaderView(&m_Header, HeaderSize);
+
+ HeaderView.CopyFrom(Memory);
+
+ return m_Header.IsValid();
+}
+
+void
+WebSocketService::Configure(WebSocketServer& Server)
+{
+ ZEN_ASSERT(m_SocketServer == nullptr);
+
+ m_SocketServer = &Server;
+
+ RegisterHandlers(Server);
+}
+
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbPackage&& StreamResponse, bool IsStreamComplete)
+{
+ WebSocketMessage Message;
+
+ Message.SetMessageType(IsStreamComplete ? WebSocketMessageType::kStreamCompleteResponse : WebSocketMessageType::kStreamResponse);
+ Message.SetCorrelationId(CorrelationId);
+ Message.SetSocketId(SocketId);
+ Message.SetBody(std::move(StreamResponse));
+
+ SocketServer().SendResponse(std::move(Message));
+}
+
+void
+WebSocketService::SendStreamResponse(WebSocketId SocketId, uint32_t CorrelationId, CbObject&& StreamResponse, bool IsStreamComplete)
+{
+ CbPackage Response;
+ Response.SetObject(std::move(StreamResponse));
+
+ SendStreamResponse(SocketId, CorrelationId, std::move(Response), IsStreamComplete);
+}
+
+std::unique_ptr<WebSocketServer>
+WebSocketServer::Create(const WebSocketServerOptions& Options)
+{
+ return std::make_unique<websocket::WsServer>(Options);
+}
+
+std::shared_ptr<WebSocketClient>
+WebSocketClient::Create(asio::io_context& IoCtx)
+{
+ return std::make_shared<websocket::WsClient>(IoCtx);
+}
+
+} // namespace zen
diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua
new file mode 100644
index 000000000..b0dbdbc79
--- /dev/null
+++ b/src/zenhttp/xmake.lua
@@ -0,0 +1,14 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zenhttp')
+ set_kind("static")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_files("httpsys.cpp", {unity_ignored=true})
+ add_includedirs("include", {public=true})
+ add_deps("zencore")
+ add_packages(
+ "vcpkg::gsl-lite",
+ "vcpkg::http-parser"
+ )
+ add_options("httpsys")
diff --git a/src/zenhttp/zenhttp.cpp b/src/zenhttp/zenhttp.cpp
new file mode 100644
index 000000000..4bd6a5697
--- /dev/null
+++ b/src/zenhttp/zenhttp.cpp
@@ -0,0 +1,22 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/zenhttp.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zenhttp/httpclient.h>
+# include <zenhttp/httpserver.h>
+# include <zenhttp/httpshared.h>
+
+namespace zen {
+
+void
+zenhttp_forcelinktests()
+{
+ http_forcelink();
+ forcelink_httpshared();
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zenserver-test/cachepolicy-tests.cpp b/src/zenserver-test/cachepolicy-tests.cpp
new file mode 100644
index 000000000..79d78e522
--- /dev/null
+++ b/src/zenserver-test/cachepolicy-tests.cpp
@@ -0,0 +1,153 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/zencore.h>
+
+#if ZEN_WITH_TESTS
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/string.h>
+# include <zencore/testing.h>
+# include <zencore/uid.h>
+# include <zenutil/cache/cachepolicy.h>
+
+namespace zen::tests {
+
+using namespace std::literals;
+
+TEST_CASE("cachepolicy")
+{
+ SUBCASE("atomics serialization")
+ {
+ CachePolicy SomeAtomics[] = {CachePolicy::None,
+ CachePolicy::QueryLocal,
+ CachePolicy::StoreRemote,
+ CachePolicy::SkipData,
+ CachePolicy::KeepAlive};
+ for (CachePolicy Atomic : SomeAtomics)
+ {
+ CHECK(ParseCachePolicy(WriteToString<128>(Atomic)) == Atomic);
+ }
+ // Also verify that we ignore unrecognized bits
+ for (CachePolicy Atomic : SomeAtomics)
+ {
+ CHECK(ParseCachePolicy(WriteToString<128>(Atomic | (CachePolicy)0x10000000)) == Atomic);
+ }
+ }
+ SUBCASE("aliases serialization")
+ {
+ CachePolicy SomeAliases[] = {CachePolicy::Query, CachePolicy::Local};
+ for (CachePolicy Alias : SomeAliases)
+ {
+ CHECK(ParseCachePolicy(WriteToString<128>(Alias)) == Alias);
+ }
+ // Also verify that we ignore unrecognized bits
+ for (CachePolicy Alias : SomeAliases)
+ {
+ CHECK(ParseCachePolicy(WriteToString<128>(Alias | (CachePolicy)0x10000000)) == Alias);
+ }
+ }
+ SUBCASE("aliases take priority over atomics")
+ {
+ CHECK(WriteToString<128>(CachePolicy::Default).ToView() == "Default"sv);
+ CHECK(WriteToString<128>(CachePolicy::Query).ToView() == "Query"sv);
+ CHECK(WriteToString<128>(CachePolicy::Local).ToView() == "Local"sv);
+ }
+ SUBCASE("policies requiring multiple strings work")
+ {
+ char Delimiter = ',';
+ CachePolicy Combination = CachePolicy::SkipData | CachePolicy::QueryLocal;
+ CHECK(WriteToString<128>(Combination).ToView().find(Delimiter) != std::string_view::npos);
+ CHECK(ParseCachePolicy(WriteToString<128>(Combination)) == Combination);
+ }
+ SUBCASE("parsing invalid text")
+ {
+ CHECK(ParseCachePolicy(",,,") == CachePolicy::None);
+ CHECK(ParseCachePolicy("fee,fie,foo,fum") == CachePolicy::None);
+ CHECK(ParseCachePolicy("fee,KeepAlive,foo,fum") == CachePolicy::KeepAlive);
+ }
+}
+
+TEST_CASE("cacherecordpolicy")
+{
+ SUBCASE("policy with no values")
+ {
+ CachePolicy Policy = CachePolicy::SkipData | CachePolicy::QueryLocal | CachePolicy::PartialRecord;
+ CachePolicy ValuePolicy = Policy & CacheValuePolicy::PolicyMask;
+ CacheRecordPolicy RecordPolicy;
+ CacheRecordPolicyBuilder Builder(Policy);
+ RecordPolicy = Builder.Build();
+ SUBCASE("construct")
+ {
+ CHECK(RecordPolicy.IsUniform());
+ CHECK(RecordPolicy.GetRecordPolicy() == Policy);
+ CHECK(RecordPolicy.GetBasePolicy() == Policy);
+ CHECK(RecordPolicy.GetValuePolicy(Oid::NewOid()) == ValuePolicy);
+ CHECK(RecordPolicy.GetValuePolicies().size() == 0);
+ }
+ SUBCASE("saveload")
+ {
+ CbWriter Writer;
+ RecordPolicy.Save(Writer);
+ CbObject Saved = Writer.Save()->AsObject();
+ CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get();
+ CHECK(Loaded.IsUniform());
+ CHECK(Loaded.GetRecordPolicy() == Policy);
+ CHECK(Loaded.GetBasePolicy() == Policy);
+ CHECK(Loaded.GetValuePolicy(Oid::NewOid()) == ValuePolicy);
+ CHECK(Loaded.GetValuePolicies().size() == 0);
+ }
+ }
+
+ SUBCASE("policy with values")
+ {
+ CachePolicy DefaultPolicy = CachePolicy::StoreRemote | CachePolicy::QueryLocal | CachePolicy::PartialRecord;
+ CachePolicy DefaultValuePolicy = DefaultPolicy & CacheValuePolicy::PolicyMask;
+ CachePolicy PartialOverlap = CachePolicy::StoreRemote;
+ CachePolicy NoOverlap = CachePolicy::QueryRemote;
+ CachePolicy UnionPolicy = DefaultPolicy | PartialOverlap | NoOverlap | CachePolicy::PartialRecord;
+
+ CacheRecordPolicy RecordPolicy;
+ CacheRecordPolicyBuilder Builder(DefaultPolicy);
+ Oid PartialOid = Oid::NewOid();
+ Oid NoOverlapOid = Oid::NewOid();
+ Oid OtherOid = Oid::NewOid();
+ Builder.AddValuePolicy(PartialOid, PartialOverlap);
+ Builder.AddValuePolicy(NoOverlapOid, NoOverlap);
+ RecordPolicy = Builder.Build();
+ SUBCASE("construct")
+ {
+ CHECK(!RecordPolicy.IsUniform());
+ CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy);
+ CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy);
+ CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap);
+ CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap);
+ CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy);
+ CHECK(RecordPolicy.GetValuePolicies().size() == 2);
+ }
+ SUBCASE("saveload")
+ {
+ CbWriter Writer;
+ RecordPolicy.Save(Writer);
+ CbObject Saved = Writer.Save()->AsObject();
+ CacheRecordPolicy Loaded = CacheRecordPolicy::Load(Saved).Get();
+ CHECK(!RecordPolicy.IsUniform());
+ CHECK(RecordPolicy.GetRecordPolicy() == UnionPolicy);
+ CHECK(RecordPolicy.GetBasePolicy() == DefaultPolicy);
+ CHECK(RecordPolicy.GetValuePolicy(PartialOid) == PartialOverlap);
+ CHECK(RecordPolicy.GetValuePolicy(NoOverlapOid) == NoOverlap);
+ CHECK(RecordPolicy.GetValuePolicy(OtherOid) == DefaultValuePolicy);
+ CHECK(RecordPolicy.GetValuePolicies().size() == 2);
+ }
+ }
+
+ SUBCASE("parsing invalid text")
+ {
+ OptionalCacheRecordPolicy Loaded = CacheRecordPolicy::Load(CbObject());
+ CHECK(Loaded.IsNull());
+ }
+}
+
+} // namespace zen::tests
+
+#endif
diff --git a/src/zenserver-test/projectclient.cpp b/src/zenserver-test/projectclient.cpp
new file mode 100644
index 000000000..597838e0d
--- /dev/null
+++ b/src/zenserver-test/projectclient.cpp
@@ -0,0 +1,164 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "projectclient.h"
+
+#if 0
+
+# include <zencore/compactbinary.h>
+# include <zencore/logging.h>
+# include <zencore/sharedbuffer.h>
+# include <zencore/string.h>
+# include <zencore/zencore.h>
+
+# include <asio.hpp>
+# include <gsl/gsl-lite.hpp>
+
+# if ZEN_PLATFORM_WINDOWS
+# include <atlbase.h>
+# endif
+
+namespace zen {
+
+struct ProjectClientConnection
+{
+ ProjectClientConnection(int BasePort) { Connect(BasePort); }
+
+ void Connect(int BasePort)
+ {
+ ZEN_UNUSED(BasePort);
+
+ WideStringBuilder<64> PipeName;
+ PipeName << "\\\\.\\pipe\\zenprj"; // TODO: this should use an instance-specific identifier!
+
+ HANDLE hPipe = CreateFileW(PipeName.c_str(),
+ GENERIC_READ | GENERIC_WRITE,
+ 0, // Sharing doesn't make any sense
+ nullptr, // No security attributes
+ OPEN_EXISTING, // Open existing pipe
+ 0, // Attributes
+ nullptr // Template file
+ );
+
+ if (hPipe == INVALID_HANDLE_VALUE)
+ {
+ ZEN_WARN("failed while creating named pipe {}", WideToUtf8(PipeName));
+
+ throw std::system_error(GetLastError(), std::system_category(), fmt::format("Failed to open named pipe '{}'", WideToUtf8(PipeName)));
+ }
+
+ // Change to message mode
+ DWORD dwMode = PIPE_READMODE_MESSAGE;
+ BOOL Success = SetNamedPipeHandleState(hPipe, &dwMode, nullptr, nullptr);
+
+ if (!Success)
+ {
+ throw std::system_error(GetLastError(),
+ std::system_category(),
+ fmt::format("Failed to change named pipe '{}' to message mode", WideToUtf8(PipeName)));
+ }
+
+ m_hPipe.Attach(hPipe); // This now owns the handle and will close it
+ }
+
+ ~ProjectClientConnection() {}
+
+ CbObject MessageTransaction(CbObject Request)
+ {
+ DWORD dwWrittenBytes = 0;
+
+ MemoryView View = Request.GetView();
+
+ BOOL Success = ::WriteFile(m_hPipe, View.GetData(), gsl::narrow_cast<DWORD>(View.GetSize()), &dwWrittenBytes, nullptr);
+
+ if (!Success)
+ {
+ throw std::system_error(GetLastError(), std::system_category(), "Failed to write pipe message");
+ }
+
+ ZEN_ASSERT(dwWrittenBytes == View.GetSize());
+
+ DWORD dwReadBytes = 0;
+
+ Success = ReadFile(m_hPipe, m_Buffer, sizeof m_Buffer, &dwReadBytes, nullptr);
+
+ if (!Success)
+ {
+ DWORD ErrorCode = GetLastError();
+
+ if (ERROR_MORE_DATA == ErrorCode)
+ {
+ // Response message is larger than our buffer - handle it by allocating a larger
+ // buffer on the heap and read the remainder into that buffer
+
+ DWORD dwBytesAvail = 0, dwLeftThisMessage = 0;
+
+ Success = PeekNamedPipe(m_hPipe, nullptr, 0, nullptr, &dwBytesAvail, &dwLeftThisMessage);
+
+ if (Success)
+ {
+ UniqueBuffer MessageBuffer = UniqueBuffer::Alloc(dwReadBytes + dwLeftThisMessage);
+
+ memcpy(MessageBuffer.GetData(), m_Buffer, dwReadBytes);
+
+ Success = ReadFile(m_hPipe,
+ reinterpret_cast<uint8_t*>(MessageBuffer.GetData()) + dwReadBytes,
+ dwLeftThisMessage,
+ &dwReadBytes,
+ nullptr);
+
+ if (Success)
+ {
+ return CbObject(SharedBuffer(std::move(MessageBuffer)));
+ }
+ }
+ }
+
+ throw std::system_error(GetLastError(), std::system_category(), "Failed to read pipe message");
+ }
+
+ return CbObject(SharedBuffer::MakeView(MakeMemoryView(m_Buffer)));
+ }
+
+private:
+ static const int kEmbeddedBufferSize = 512 - 16;
+
+ CHandle m_hPipe;
+ uint8_t m_Buffer[kEmbeddedBufferSize];
+};
+
+struct LocalProjectClient::ClientImpl
+{
+ ClientImpl(int BasePort) : m_BasePort(BasePort) {}
+ ~ClientImpl() {}
+
+ void Start() {}
+ void Stop() {}
+
+ inline int BasePort() const { return m_BasePort; }
+
+private:
+ int m_BasePort = 0;
+};
+
+LocalProjectClient::LocalProjectClient(int BasePort)
+{
+ m_Impl = std::make_unique<ClientImpl>(BasePort);
+ m_Impl->Start();
+}
+
+LocalProjectClient::~LocalProjectClient()
+{
+ m_Impl->Stop();
+}
+
+CbObject
+LocalProjectClient::MessageTransaction(CbObject Request)
+{
+ ProjectClientConnection Cx(m_Impl->BasePort());
+
+ return Cx.MessageTransaction(Request);
+}
+
+} // namespace zen
+
+#endif // 0
diff --git a/src/zenserver-test/projectclient.h b/src/zenserver-test/projectclient.h
new file mode 100644
index 000000000..1865dd67d
--- /dev/null
+++ b/src/zenserver-test/projectclient.h
@@ -0,0 +1,32 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <memory>
+
+#include <zencore/compactbinary.h>
+#include <zencore/refcount.h>
+
+namespace zen {
+
+/**
+ * Client for communication with local project service
+ *
+ * This is WIP and not yet functional!
+ */
+
+class LocalProjectClient : public RefCounted
+{
+public:
+ LocalProjectClient(int BasePort = 0);
+ ~LocalProjectClient();
+
+ CbObject MessageTransaction(CbObject Request);
+
+private:
+ struct ClientImpl;
+
+ std::unique_ptr<ClientImpl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenserver-test/xmake.lua b/src/zenserver-test/xmake.lua
new file mode 100644
index 000000000..f0b34f6ca
--- /dev/null
+++ b/src/zenserver-test/xmake.lua
@@ -0,0 +1,16 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("zenserver-test")
+ set_kind("binary")
+ add_headerfiles("**.h")
+ add_files("*.cpp")
+ add_files("zenserver-test.cpp", {unity_ignored = true })
+ add_deps("zencore", "zenutil", "zenhttp")
+ add_deps("zenserver", {inherit=false})
+ add_packages("vcpkg::http-parser", "vcpkg::mimalloc")
+
+ if is_plat("macosx") then
+ add_ldflags("-framework CoreFoundation")
+ add_ldflags("-framework Security")
+ add_ldflags("-framework SystemConfiguration")
+ end
diff --git a/src/zenserver-test/zenserver-test.cpp b/src/zenserver-test/zenserver-test.cpp
new file mode 100644
index 000000000..3195181d1
--- /dev/null
+++ b/src/zenserver-test/zenserver-test.cpp
@@ -0,0 +1,3323 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#define _SILENCE_CXX17_C_HEADER_DEPRECATION_WARNING
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compress.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/refcount.h>
+#include <zencore/stream.h>
+#include <zencore/string.h>
+#include <zencore/testutils.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+#include <zencore/xxhash.h>
+#include <zenhttp/httpclient.h>
+#include <zenhttp/httpshared.h>
+#include <zenhttp/websocket.h>
+#include <zenhttp/zenhttp.h>
+#include <zenutil/cache/cache.h>
+#include <zenutil/cache/cacherequests.h>
+#include <zenutil/zenserverprocess.h>
+
+#if ZEN_USE_MIMALLOC
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <mimalloc.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+#endif
+
+#include <http_parser.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "Crypt32.lib")
+# pragma comment(lib, "Wldap32.lib")
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#undef GetObject
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <atomic>
+#include <filesystem>
+#include <map>
+#include <random>
+#include <span>
+#include <thread>
+#include <typeindex>
+#include <unordered_map>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <ppl.h>
+# include <atlbase.h>
+# include <process.h>
+#endif
+
+#include <asio.hpp>
+
+//////////////////////////////////////////////////////////////////////////
+
+#include "projectclient.h"
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+# define ZEN_TEST_WITH_RUNNER 1
+# include <zencore/testing.h>
+# include <zencore/workthreadpool.h>
+#endif
+
+using namespace std::literals;
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+struct Concurrency
+{
+ template<typename... T>
+ static void parallel_invoke(T&&... t)
+ {
+ constexpr size_t NumTs = sizeof...(t);
+ std::thread Threads[NumTs] = {
+ std::thread(std::forward<T>(t))...,
+ };
+
+ for (std::thread& Thread : Threads)
+ {
+ Thread.join();
+ }
+ }
+};
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Custom logging -- test code, this should be tweaked
+//
+
+namespace logging {
+using namespace spdlog;
+using namespace spdlog::details;
+using namespace std::literals;
+
+class full_test_formatter final : public spdlog::formatter
+{
+public:
+ full_test_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId)
+ {
+ }
+
+ virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<full_test_formatter>(m_LogId, m_Epoch); }
+
+ static constexpr bool UseDate = false;
+
+ virtual void format(const details::log_msg& msg, memory_buf_t& dest) override
+ {
+ using std::chrono::duration_cast;
+ using std::chrono::milliseconds;
+ using std::chrono::seconds;
+
+ if constexpr (UseDate)
+ {
+ auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch());
+ if (secs != m_LastLogSecs)
+ {
+ m_CachedTm = os::localtime(log_clock::to_time_t(msg.time));
+ m_LastLogSecs = secs;
+ }
+ }
+
+ const auto& tm_time = m_CachedTm;
+
+ // cache the date/time part for the next second.
+ auto duration = msg.time - m_Epoch;
+ auto secs = duration_cast<seconds>(duration);
+
+ if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0)
+ {
+ m_CachedDatetime.clear();
+ m_CachedDatetime.push_back('[');
+
+ if constexpr (UseDate)
+ {
+ fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime);
+ m_CachedDatetime.push_back(' ');
+
+ fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime);
+ }
+ else
+ {
+ int Count = int(secs.count());
+
+ const int LogSecs = Count % 60;
+ Count /= 60;
+
+ const int LogMins = Count % 60;
+ Count /= 60;
+
+ const int LogHours = Count;
+
+ fmt_helper::pad2(LogHours, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+ fmt_helper::pad2(LogMins, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+ fmt_helper::pad2(LogSecs, m_CachedDatetime);
+ }
+
+ m_CachedDatetime.push_back('.');
+
+ m_CacheTimestamp = secs;
+ }
+
+ dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
+
+ auto millis = fmt_helper::time_fraction<milliseconds>(msg.time);
+ fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+
+ if (!m_LogId.empty())
+ {
+ dest.push_back('[');
+ fmt_helper::append_string_view(m_LogId, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ // append logger name if exists
+ if (msg.logger_name.size() > 0)
+ {
+ dest.push_back('[');
+ fmt_helper::append_string_view(msg.logger_name, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ dest.push_back('[');
+ // wrap the level name with color
+ msg.color_range_start = dest.size();
+ fmt_helper::append_string_view(level::to_string_view(msg.level), dest);
+ msg.color_range_end = dest.size();
+ dest.push_back(']');
+ dest.push_back(' ');
+
+ // add source location if present
+ if (!msg.source.empty())
+ {
+ dest.push_back('[');
+ const char* filename = details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename);
+ fmt_helper::append_string_view(filename, dest);
+ dest.push_back(':');
+ fmt_helper::append_int(msg.source.line, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ fmt_helper::append_string_view(msg.payload, dest);
+ fmt_helper::append_string_view("\n"sv, dest);
+ }
+
+private:
+ std::chrono::time_point<std::chrono::system_clock> m_Epoch;
+ std::tm m_CachedTm;
+ std::chrono::seconds m_LastLogSecs;
+ std::chrono::seconds m_CacheTimestamp{0};
+ memory_buf_t m_CachedDatetime;
+ std::string m_LogId;
+};
+} // namespace logging
+
+//////////////////////////////////////////////////////////////////////////
+
+#if 0
+
+int
+main()
+{
+ mi_version();
+
+ zen::Sleep(1000);
+
+ zen::Stopwatch timer;
+
+ const int RequestCount = 100000;
+
+ cpr::Session Sessions[10];
+
+ for (auto& Session : Sessions)
+ {
+ Session.SetUrl(cpr::Url{"http://localhost:1337/test/hello"});
+ //Session.SetUrl(cpr::Url{ "http://arn-wd-l0182:1337/test/hello" });
+ }
+
+ auto Run = [](cpr::Session& Session) {
+ for (int i = 0; i < 10000; ++i)
+ {
+ cpr::Response Result = Session.Get();
+
+ if (Result.status_code != 200)
+ {
+ ZEN_WARN("request response: {}", Result.status_code);
+ }
+ }
+ };
+
+ Concurrency::parallel_invoke([&] { Run(Sessions[0]); },
+ [&] { Run(Sessions[1]); },
+ [&] { Run(Sessions[2]); },
+ [&] { Run(Sessions[3]); },
+ [&] { Run(Sessions[4]); },
+ [&] { Run(Sessions[5]); },
+ [&] { Run(Sessions[6]); },
+ [&] { Run(Sessions[7]); },
+ [&] { Run(Sessions[8]); },
+ [&] { Run(Sessions[9]); });
+
+ // cpr::Response r = cpr::Get(cpr::Url{ "http://localhost:1337/test/hello" });
+
+ ZEN_INFO("{} requests in {} ({})",
+ RequestCount,
+ zen::NiceTimeSpanMs(timer.GetElapsedTimeMs()),
+ zen::NiceRate(RequestCount, (uint32_t)timer.GetElapsedTimeMs(), "req"));
+
+ return 0;
+}
+#elif 0
+// #include <restinio/all.hpp>
+
+int
+main()
+{
+ mi_version();
+ restinio::run(restinio::on_thread_pool(32).port(8080).request_handler(
+ [](auto req) { return req->create_response().set_body("Hello, World!").done(); }));
+ return 0;
+}
+#elif ZEN_WITH_TESTS
+
+zen::ZenServerEnvironment TestEnv;
+
+int
+main(int argc, char** argv)
+{
+ using namespace std::literals;
+
+# if ZEN_USE_MIMALLOC
+ mi_version();
+# endif
+
+ zen::zencore_forcelinktests();
+ zen::zenhttp_forcelinktests();
+ zen::cacherequests_forcelink();
+
+ zen::logging::InitializeLogging();
+
+ spdlog::set_level(spdlog::level::debug);
+ spdlog::set_formatter(std::make_unique< ::logging::full_test_formatter>("test", std::chrono::system_clock::now()));
+
+ std::filesystem::path ProgramBaseDir = std::filesystem::path(argv[0]).parent_path();
+ std::filesystem::path TestBaseDir = ProgramBaseDir.parent_path().parent_path() / ".test";
+
+ // This is pretty janky because we're passing most of the options through to the test
+ // framework, so we can't just use cxxopts (I think). This should ideally be cleaned up
+ // somehow in the future
+
+ std::string ServerClass;
+
+ for (int i = 1; i < argc; ++i)
+ {
+ if (argv[i] == "--http"sv)
+ {
+ if ((i + 1) < argc)
+ {
+ ServerClass = argv[++i];
+ }
+ }
+ }
+
+ TestEnv.InitializeForTest(ProgramBaseDir, TestBaseDir, ServerClass);
+
+ ZEN_INFO("Running tests...(base dir: '{}')", TestBaseDir);
+
+ zen::testing::TestRunner Runner;
+ Runner.ApplyCommandLine(argc, argv);
+
+ return Runner.Run();
+}
+
+namespace zen::tests {
+
+TEST_CASE("default.single")
+{
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+
+ ZenServerInstance Instance(TestEnv);
+ Instance.SetTestDir(TestDir);
+ Instance.SpawnServer(13337);
+
+ ZEN_INFO("Waiting...");
+
+ Instance.WaitUntilReady();
+
+ std::atomic<uint64_t> RequestCount{0};
+ std::atomic<uint64_t> BatchCounter{0};
+
+ ZEN_INFO("Running single server test...");
+
+ auto IssueTestRequests = [&] {
+ const uint64_t BatchNo = BatchCounter.fetch_add(1);
+ const int ThreadId = zen::GetCurrentThreadId();
+
+ ZEN_INFO("query batch {} started (thread {})", BatchNo, ThreadId);
+ cpr::Session cli;
+ cli.SetUrl(cpr::Url{"http://localhost:13337/test/hello"});
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ auto res = cli.Get();
+ ++RequestCount;
+ }
+ ZEN_INFO("query batch {} ended (thread {})", BatchNo, ThreadId);
+ };
+
+ zen::Stopwatch timer;
+
+ Concurrency::parallel_invoke(IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests,
+ IssueTestRequests);
+
+ uint64_t Elapsed = timer.GetElapsedTimeMs();
+
+ ZEN_INFO("{} requests in {} ({})", RequestCount, zen::NiceTimeSpanMs(Elapsed), zen::NiceRate(RequestCount, (uint32_t)Elapsed, "req"));
+}
+
+TEST_CASE("multi.basic")
+{
+ ZenServerInstance Instance1(TestEnv);
+ std::filesystem::path TestDir1 = TestEnv.CreateNewTestDir();
+ Instance1.SetTestDir(TestDir1);
+ Instance1.SpawnServer(13337);
+
+ ZenServerInstance Instance2(TestEnv);
+ std::filesystem::path TestDir2 = TestEnv.CreateNewTestDir();
+ Instance2.SetTestDir(TestDir2);
+ Instance2.SpawnServer(13338);
+
+ ZEN_INFO("Waiting...");
+
+ Instance1.WaitUntilReady();
+ Instance2.WaitUntilReady();
+
+ std::atomic<uint64_t> RequestCount{0};
+ std::atomic<uint64_t> BatchCounter{0};
+
+ auto IssueTestRequests = [&](int PortNumber) {
+ const uint64_t BatchNo = BatchCounter.fetch_add(1);
+ const int ThreadId = zen::GetCurrentThreadId();
+
+ ZEN_INFO("query batch {} started (thread {}) for port {}", BatchNo, ThreadId, PortNumber);
+
+ cpr::Session cli;
+ cli.SetUrl(cpr::Url{fmt::format("http://localhost:{}/test/hello", PortNumber)});
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ auto res = cli.Get();
+ ++RequestCount;
+ }
+ ZEN_INFO("query batch {} ended (thread {})", BatchNo, ThreadId);
+ };
+
+ zen::Stopwatch timer;
+
+ ZEN_INFO("Running multi-server test...");
+
+ Concurrency::parallel_invoke([&] { IssueTestRequests(13337); },
+ [&] { IssueTestRequests(13338); },
+ [&] { IssueTestRequests(13337); },
+ [&] { IssueTestRequests(13338); });
+
+ uint64_t Elapsed = timer.GetElapsedTimeMs();
+
+ ZEN_INFO("{} requests in {} ({})", RequestCount, zen::NiceTimeSpanMs(Elapsed), zen::NiceRate(RequestCount, (uint32_t)Elapsed, "req"));
+}
+
+TEST_CASE("project.basic")
+{
+ using namespace std::literals;
+
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+
+ const uint16_t PortNumber = 13337;
+
+ ZenServerInstance Instance1(TestEnv);
+ Instance1.SetTestDir(TestDir);
+ Instance1.SpawnServer(PortNumber);
+ Instance1.WaitUntilReady();
+
+ std::atomic<uint64_t> RequestCount{0};
+
+ zen::Stopwatch timer;
+
+ std::mt19937_64 mt;
+
+ zen::StringBuilder<64> BaseUri;
+ BaseUri << fmt::format("http://localhost:{}/prj/test", PortNumber);
+
+ std::filesystem::path BinPath = zen::GetRunningExecutablePath();
+ std::filesystem::path RootPath = BinPath.parent_path().parent_path();
+ BinPath = BinPath.lexically_relative(RootPath);
+
+ SUBCASE("build store init")
+ {
+ {
+ {
+ zen::CbObjectWriter Body;
+ Body << "id"
+ << "test";
+ Body << "root" << RootPath.c_str();
+ Body << "project"
+ << "/zooom";
+ Body << "engine"
+ << "/zooom";
+
+ zen::BinaryWriter MemOut;
+ Body.Save(MemOut);
+
+ auto Response = cpr::Post(cpr::Url{BaseUri.c_str()}, cpr::Body{(const char*)MemOut.Data(), MemOut.Size()});
+ CHECK(Response.status_code == 201);
+ }
+
+ {
+ auto Response = cpr::Get(cpr::Url{BaseUri.c_str()});
+ CHECK(Response.status_code == 200);
+
+ zen::CbObjectView ResponseObject = zen::CbFieldView(Response.text.data()).AsObjectView();
+
+ CHECK(ResponseObject["id"].AsString() == "test"sv);
+ CHECK(ResponseObject["root"].AsString() == PathToUtf8(RootPath.c_str()));
+ }
+ }
+
+ BaseUri << "/oplog/foobar";
+
+ {
+ {
+ zen::StringBuilder<64> PostUri;
+ PostUri << BaseUri;
+ auto Response = cpr::Post(cpr::Url{PostUri.c_str()});
+ CHECK(Response.status_code == 201);
+ }
+
+ {
+ auto Response = cpr::Get(cpr::Url{BaseUri.c_str()});
+ CHECK(Response.status_code == 200);
+
+ zen::CbObjectView ResponseObject = zen::CbFieldView(Response.text.data()).AsObjectView();
+
+ CHECK(ResponseObject["id"].AsString() == "foobar"sv);
+ CHECK(ResponseObject["project"].AsString() == "test"sv);
+ }
+ }
+
+ SUBCASE("build store persistence")
+ {
+ uint8_t AttachData[] = {1, 2, 3};
+
+ zen::CompressedBuffer Attachment = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone(zen::MemoryView{AttachData, 3}));
+ zen::CbAttachment Attach{Attachment, Attachment.DecodeRawHash()};
+
+ zen::CbObjectWriter OpWriter;
+ OpWriter << "key"
+ << "foo"
+ << "attachment" << Attach;
+
+ const std::string_view ChunkId{
+ "00000000"
+ "00000000"
+ "00010000"};
+ auto FileOid = zen::Oid::FromHexString(ChunkId);
+
+ OpWriter.BeginArray("files");
+ OpWriter.BeginObject();
+ OpWriter << "id" << FileOid;
+ OpWriter << "clientpath"
+ << "/{engine}/client/side/path";
+ OpWriter << "serverpath" << BinPath.c_str();
+ OpWriter.EndObject();
+ OpWriter.EndArray();
+
+ zen::CbObject Op = OpWriter.Save();
+
+ zen::CbPackage OpPackage(Op);
+ OpPackage.AddAttachment(Attach);
+
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
+
+ {
+ zen::StringBuilder<64> PostUri;
+ PostUri << BaseUri << "/new";
+ auto Response = cpr::Post(cpr::Url{PostUri.c_str()}, cpr::Body{(const char*)MemOut.Data(), MemOut.Size()});
+
+ REQUIRE(!Response.error);
+ CHECK(Response.status_code == 201);
+ }
+
+ // Read file data
+
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << BaseUri << "/" << ChunkId;
+ auto Response = cpr::Get(cpr::Url{ChunkGetUri.c_str()});
+
+ REQUIRE(!Response.error);
+ CHECK(Response.status_code == 200);
+ }
+
+ {
+ zen::StringBuilder<128> ChunkGetUri;
+ ChunkGetUri << BaseUri << "/" << ChunkId << "?offset=1&size=10";
+ auto Response = cpr::Get(cpr::Url{ChunkGetUri.c_str()});
+
+ REQUIRE(!Response.error);
+ CHECK(Response.status_code == 200);
+ CHECK(Response.text.size() == 10);
+ }
+
+ ZEN_INFO("+++++++");
+ }
+ SUBCASE("build store op commit") { ZEN_INFO("-------"); }
+ SUBCASE("test chunk not found error")
+ {
+ for (size_t I = 0; I < 65; I++)
+ {
+ zen::StringBuilder<128> PostUri;
+ PostUri << BaseUri << "/f77c781846caead318084604/info";
+ auto Response = cpr::Get(cpr::Url{PostUri.c_str()});
+
+ REQUIRE(!Response.error);
+ CHECK(Response.status_code == 404);
+ }
+ }
+ }
+
+ const uint64_t Elapsed = timer.GetElapsedTimeMs();
+
+ ZEN_INFO("{} requests in {} ({})", RequestCount, zen::NiceTimeSpanMs(Elapsed), zen::NiceRate(RequestCount, (uint32_t)Elapsed, "req"));
+}
+
+# if 0 // this is extremely WIP
+TEST_CASE("project.pipe")
+{
+ using namespace std::literals;
+
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+
+ const uint16_t PortNumber = 13337;
+
+ ZenServerInstance Instance1(TestEnv);
+ Instance1.SetTestDir(TestDir);
+ Instance1.SpawnServer(PortNumber);
+ Instance1.WaitUntilReady();
+
+ zen::LocalProjectClient LocalClient(PortNumber);
+
+ zen::CbObjectWriter Cbow;
+ Cbow << "hey" << 42;
+
+ zen::CbObject Response = LocalClient.MessageTransaction(Cbow.Save());
+}
+# endif
+
+namespace utils {
+
+ struct ZenConfig
+ {
+ std::filesystem::path DataDir;
+ uint16_t Port;
+ std::string BaseUri;
+ std::string Args;
+
+ static ZenConfig New(uint16_t Port = 13337, std::string Args = "")
+ {
+ return ZenConfig{.DataDir = TestEnv.CreateNewTestDir(),
+ .Port = Port,
+ .BaseUri = fmt::format("http://localhost:{}/z$", Port),
+ .Args = std::move(Args)};
+ }
+
+ static ZenConfig NewWithUpstream(uint16_t UpstreamPort)
+ {
+ return New(13337, fmt::format("--debug --upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", UpstreamPort));
+ }
+
+ static ZenConfig NewWithThreadedUpstreams(std::span<uint16_t> UpstreamPorts, bool Debug)
+ {
+ std::string Args = Debug ? "--debug" : "";
+ for (uint16_t Port : UpstreamPorts)
+ {
+ Args = fmt::format("{}{}--upstream-zen-url=http://localhost:{}", Args, Args.length() > 0 ? " " : "", Port);
+ }
+ return New(13337, Args);
+ }
+
+ void Spawn(ZenServerInstance& Inst)
+ {
+ Inst.SetTestDir(DataDir);
+ Inst.SpawnServer(Port, Args);
+ Inst.WaitUntilReady();
+ }
+ };
+
+ void SpawnServer(ZenServerInstance& Server, ZenConfig& Cfg)
+ {
+ Server.SetTestDir(Cfg.DataDir);
+ Server.SpawnServer(Cfg.Port, Cfg.Args);
+ Server.WaitUntilReady();
+ }
+
+} // namespace utils
+
+TEST_CASE("zcache.basic")
+{
+ using namespace std::literals;
+
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+
+ const uint16_t PortNumber = 13337;
+
+ const int kIterationCount = 100;
+ const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber);
+
+ auto HashKey = [](int i) -> zen::IoHash { return zen::IoHash::HashBuffer(&i, sizeof i); };
+
+ {
+ ZenServerInstance Instance1(TestEnv);
+ Instance1.SetTestDir(TestDir);
+ Instance1.SpawnServer(PortNumber);
+ Instance1.WaitUntilReady();
+
+ // Populate with some simple data
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ zen::CbObjectWriter Cbo;
+ Cbo << "index" << i;
+
+ zen::BinaryWriter MemOut;
+ Cbo.Save(MemOut);
+
+ zen::IoHash Key = HashKey(i);
+
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "test", Key)},
+ cpr::Body{(const char*)MemOut.Data(), MemOut.Size()},
+ cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+
+ CHECK(Result.status_code == 201);
+ }
+
+ // Retrieve data
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ zen::IoHash Key = zen::IoHash::HashBuffer(&i, sizeof i);
+
+ cpr::Response Result =
+ cpr::Get(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "test", Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+
+ CHECK(Result.status_code == 200);
+ }
+
+ // Ensure bad bucket identifiers are rejected
+
+ {
+ zen::CbObjectWriter Cbo;
+ Cbo << "index" << 42;
+
+ zen::BinaryWriter MemOut;
+ Cbo.Save(MemOut);
+
+ zen::IoHash Key = HashKey(442);
+
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "te!st", Key)},
+ cpr::Body{(const char*)MemOut.Data(), MemOut.Size()},
+ cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+
+ CHECK(Result.status_code == 400);
+ }
+ }
+
+ // Verify that the data persists between process runs (the previous server has exited at this point)
+
+ {
+ ZenServerInstance Instance1(TestEnv);
+ Instance1.SetTestDir(TestDir);
+ Instance1.SpawnServer(PortNumber);
+ Instance1.WaitUntilReady();
+
+ // Retrieve data again
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ zen::IoHash Key = HashKey(i);
+
+ cpr::Response Result =
+ cpr::Get(cpr::Url{fmt::format("{}/{}/{}", BaseUri, "test", Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+
+ CHECK(Result.status_code == 200);
+ }
+ }
+}
+
+TEST_CASE("zcache.cbpackage")
+{
+ using namespace std::literals;
+
+ auto CreateTestPackage = [](zen::IoHash& OutAttachmentKey) -> zen::CbPackage {
+ auto Data = zen::SharedBuffer::Clone(zen::MakeMemoryView<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9}));
+ auto CompressedData = zen::CompressedBuffer::Compress(Data);
+
+ OutAttachmentKey = CompressedData.DecodeRawHash();
+
+ zen::CbWriter Obj;
+ Obj.BeginObject("obj"sv);
+ Obj.AddBinaryAttachment("data", OutAttachmentKey);
+ Obj.EndObject();
+
+ zen::CbPackage Package;
+ Package.SetObject(Obj.Save().AsObject());
+ Package.AddAttachment(zen::CbAttachment(CompressedData, OutAttachmentKey));
+
+ return Package;
+ };
+
+ auto SerializeToBuffer = [](zen::CbPackage Package) -> zen::IoBuffer {
+ zen::BinaryWriter MemStream;
+
+ Package.Save(MemStream);
+
+ return zen::IoBuffer(zen::IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ };
+
+ auto IsEqual = [](zen::CbPackage Lhs, zen::CbPackage Rhs) -> bool {
+ std::span<const zen::CbAttachment> LhsAttachments = Lhs.GetAttachments();
+ std::span<const zen::CbAttachment> RhsAttachments = Rhs.GetAttachments();
+
+ if (LhsAttachments.size() != RhsAttachments.size())
+ {
+ return false;
+ }
+
+ for (const zen::CbAttachment& LhsAttachment : LhsAttachments)
+ {
+ const zen::CbAttachment* RhsAttachment = Rhs.FindAttachment(LhsAttachment.GetHash());
+ CHECK(RhsAttachment);
+
+ zen::SharedBuffer LhsBuffer = LhsAttachment.AsCompressedBinary().Decompress();
+ CHECK(!LhsBuffer.IsNull());
+
+ zen::SharedBuffer RhsBuffer = RhsAttachment->AsCompressedBinary().Decompress();
+ CHECK(!RhsBuffer.IsNull());
+
+ if (!LhsBuffer.GetView().EqualBytes(RhsBuffer.GetView()))
+ {
+ return false;
+ }
+ }
+
+ return true;
+ };
+
+ SUBCASE("PUT/GET returns correct package")
+ {
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const uint16_t PortNumber = 13337;
+ const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber);
+
+ ZenServerInstance Instance1(TestEnv);
+ Instance1.SetTestDir(TestDir);
+ Instance1.SpawnServer(PortNumber);
+ Instance1.WaitUntilReady();
+
+ const std::string_view Bucket = "mosdef"sv;
+ zen::IoHash Key;
+ zen::CbPackage ExpectedPackage = CreateTestPackage(Key);
+
+ // PUT
+ {
+ zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)Body.Data(), Body.Size()},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ // GET
+ {
+ cpr::Response Result =
+ cpr::Get(cpr::Url{fmt::format("{}/{}/{}", BaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+
+ zen::IoBuffer Response(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size());
+
+ zen::CbPackage Package;
+ const bool Ok = Package.TryLoad(Response);
+ CHECK(Ok);
+ CHECK(IsEqual(Package, ExpectedPackage));
+ }
+ }
+
+ SUBCASE("PUT propagates upstream")
+ {
+ // Setup local and remote server
+ std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir();
+ std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir();
+ const uint16_t LocalPortNumber = 13337;
+ const uint16_t RemotePortNumber = 13338;
+
+ const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber);
+ const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber);
+
+ ZenServerInstance RemoteInstance(TestEnv);
+ RemoteInstance.SetTestDir(RemoteDataDir);
+ RemoteInstance.SpawnServer(RemotePortNumber);
+ RemoteInstance.WaitUntilReady();
+
+ ZenServerInstance LocalInstance(TestEnv);
+ LocalInstance.SetTestDir(LocalDataDir);
+ LocalInstance.SpawnServer(LocalPortNumber,
+ fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber));
+ LocalInstance.WaitUntilReady();
+
+ const std::string_view Bucket = "mosdef"sv;
+ zen::IoHash Key;
+ zen::CbPackage ExpectedPackage = CreateTestPackage(Key);
+
+ // Store the cache record package in the local instance
+ {
+ zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", LocalBaseUri, Bucket, Key)},
+ cpr::Body{(const char*)Body.Data(), Body.Size()},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}});
+
+ CHECK(Result.status_code == 201);
+ }
+
+ // The cache record can be retrieved as a package from the local instance
+ {
+ cpr::Response Result =
+ cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalBaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+
+ zen::IoBuffer Body(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size());
+ zen::CbPackage Package;
+ const bool Ok = Package.TryLoad(Body);
+ CHECK(Ok);
+ CHECK(IsEqual(Package, ExpectedPackage));
+ }
+
+ // The cache record can be retrieved as a package from the remote instance
+ {
+ cpr::Response Result =
+ cpr::Get(cpr::Url{fmt::format("{}/{}/{}", RemoteBaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+
+ zen::IoBuffer Body(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size());
+ zen::CbPackage Package;
+ const bool Ok = Package.TryLoad(Body);
+ CHECK(Ok);
+ CHECK(IsEqual(Package, ExpectedPackage));
+ }
+ }
+
+ SUBCASE("GET finds upstream when missing in local")
+ {
+ // Setup local and remote server
+ std::filesystem::path LocalDataDir = TestEnv.CreateNewTestDir();
+ std::filesystem::path RemoteDataDir = TestEnv.CreateNewTestDir();
+ const uint16_t LocalPortNumber = 13337;
+ const uint16_t RemotePortNumber = 13338;
+
+ const auto LocalBaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber);
+ const auto RemoteBaseUri = fmt::format("http://localhost:{}/z$", RemotePortNumber);
+
+ ZenServerInstance RemoteInstance(TestEnv);
+ RemoteInstance.SetTestDir(RemoteDataDir);
+ RemoteInstance.SpawnServer(RemotePortNumber);
+ RemoteInstance.WaitUntilReady();
+
+ ZenServerInstance LocalInstance(TestEnv);
+ LocalInstance.SetTestDir(LocalDataDir);
+ LocalInstance.SpawnServer(LocalPortNumber,
+ fmt::format("--upstream-thread-count=0 --upstream-zen-url=http://localhost:{}", RemotePortNumber));
+ LocalInstance.WaitUntilReady();
+
+ const std::string_view Bucket = "mosdef"sv;
+ zen::IoHash Key;
+ zen::CbPackage ExpectedPackage = CreateTestPackage(Key);
+
+ // Store the cache record package in upstream cache
+ {
+ zen::IoBuffer Body = SerializeToBuffer(ExpectedPackage);
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", RemoteBaseUri, Bucket, Key)},
+ cpr::Body{(const char*)Body.Data(), Body.Size()},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}});
+
+ CHECK(Result.status_code == 201);
+ }
+
+ // The cache record can be retrieved as a package from the local cache
+ {
+ cpr::Response Result =
+ cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalBaseUri, Bucket, Key)}, cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+
+ zen::IoBuffer Body(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size());
+ zen::CbPackage Package;
+ const bool Ok = Package.TryLoad(Body);
+ CHECK(Ok);
+ CHECK(IsEqual(Package, ExpectedPackage));
+ }
+ }
+}
+
+TEST_CASE("zcache.policy")
+{
+ using namespace std::literals;
+ using namespace utils;
+
+ auto GenerateData = [](uint64_t Size, zen::IoHash& OutHash) -> zen::UniqueBuffer {
+ auto Buf = zen::UniqueBuffer::Alloc(Size);
+ uint8_t* Data = reinterpret_cast<uint8_t*>(Buf.GetData());
+ for (uint64_t Idx = 0; Idx < Size; Idx++)
+ {
+ Data[Idx] = Idx % 256;
+ }
+ OutHash = zen::IoHash::HashBuffer(Data, Size);
+ return Buf;
+ };
+
+ auto GeneratePackage = [](zen::IoHash& OutRecordKey, zen::IoHash& OutAttachmentKey) -> zen::CbPackage {
+ auto Data = zen::SharedBuffer::Clone(zen::MakeMemoryView<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9}));
+ auto CompressedData = zen::CompressedBuffer::Compress(Data);
+ OutAttachmentKey = CompressedData.DecodeRawHash();
+
+ zen::CbWriter Writer;
+ Writer.BeginObject("obj"sv);
+ Writer.AddBinaryAttachment("data", OutAttachmentKey);
+ Writer.EndObject();
+ CbObject CacheRecord = Writer.Save().AsObject();
+
+ OutRecordKey = IoHash::HashBuffer(CacheRecord.GetBuffer().GetView());
+
+ zen::CbPackage Package;
+ Package.SetObject(CacheRecord);
+ Package.AddAttachment(zen::CbAttachment(CompressedData, OutAttachmentKey));
+
+ return Package;
+ };
+
+ auto ToBuffer = [](zen::CbPackage Package) -> zen::IoBuffer {
+ zen::BinaryWriter MemStream;
+ Package.Save(MemStream);
+
+ return zen::IoBuffer(zen::IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ };
+
+ SUBCASE("query - 'local' does not query upstream (binary)")
+ {
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamInst(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalInst(TestEnv);
+ const auto Bucket = "legacy"sv;
+
+ UpstreamCfg.Spawn(UpstreamInst);
+ LocalCfg.Spawn(LocalInst);
+
+ zen::IoHash Key;
+ auto BinaryValue = GenerateData(1024, Key);
+
+ // Store binary cache value upstream
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()},
+ cpr::Header{{"Content-Type", "application/octet-stream"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=QueryLocal,Store", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/octet-stream"}});
+ CHECK(Result.status_code == 404);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/octet-stream"}});
+ CHECK(Result.status_code == 200);
+ }
+ }
+
+ SUBCASE("store - 'local' does not store upstream (binary)")
+ {
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamInst(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalInst(TestEnv);
+ const auto Bucket = "legacy"sv;
+
+ UpstreamCfg.Spawn(UpstreamInst);
+ LocalCfg.Spawn(LocalInst);
+
+ zen::IoHash Key;
+ auto BinaryValue = GenerateData(1024, Key);
+
+ // Store binary cache value locally
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,StoreLocal", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()},
+ cpr::Header{{"Content-Type", "application/octet-stream"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/octet-stream"}});
+ CHECK(Result.status_code == 404);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/octet-stream"}});
+ CHECK(Result.status_code == 200);
+ }
+ }
+
+ SUBCASE("store - 'local/remote' stores local and upstream (binary)")
+ {
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamInst(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalInst(TestEnv);
+ const auto Bucket = "legacy"sv;
+
+ UpstreamCfg.Spawn(UpstreamInst);
+ LocalCfg.Spawn(LocalInst);
+
+ zen::IoHash Key;
+ auto BinaryValue = GenerateData(1024, Key);
+
+ // Store binary cache value locally and upstream
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()},
+ cpr::Header{{"Content-Type", "application/octet-stream"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/octet-stream"}});
+ CHECK(Result.status_code == 200);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/octet-stream"}});
+ CHECK(Result.status_code == 200);
+ }
+ }
+
+ SUBCASE("query - 'local' does not query upstream (cppackage)")
+ {
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamInst(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalInst(TestEnv);
+ const auto Bucket = "legacy"sv;
+
+ UpstreamCfg.Spawn(UpstreamInst);
+ LocalCfg.Spawn(LocalInst);
+
+ zen::IoHash Key;
+ zen::IoHash PayloadId;
+ zen::CbPackage Package = GeneratePackage(Key, PayloadId);
+ auto Buf = ToBuffer(Package);
+
+ // Store package upstream
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=QueryLocal,Store", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 404);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+ }
+ }
+
+ SUBCASE("store - 'local' does not store upstream (cbpackge)")
+ {
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamInst(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalInst(TestEnv);
+ const auto Bucket = "legacy"sv;
+
+ UpstreamCfg.Spawn(UpstreamInst);
+ LocalCfg.Spawn(LocalInst);
+
+ zen::IoHash Key;
+ zen::IoHash PayloadId;
+ zen::CbPackage Package = GeneratePackage(Key, PayloadId);
+ auto Buf = ToBuffer(Package);
+
+ // Store packge locally
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,StoreLocal", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 404);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+ }
+ }
+
+ SUBCASE("store - 'local/remote' stores local and upstream (cbpackage)")
+ {
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamInst(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalInst(TestEnv);
+ const auto Bucket = "legacy"sv;
+
+ UpstreamCfg.Spawn(UpstreamInst);
+ LocalCfg.Spawn(LocalInst);
+
+ zen::IoHash Key;
+ zen::IoHash PayloadId;
+ zen::CbPackage Package = GeneratePackage(Key, PayloadId);
+ auto Buf = ToBuffer(Package);
+
+ // Store package locally and upstream
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}?Policy=Query,Store", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", UpstreamCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+ }
+
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}", LocalCfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 200);
+ }
+ }
+
+ SUBCASE("skip - 'data' returns cache record without attachments/empty payload")
+ {
+ ZenConfig Cfg = ZenConfig::New();
+ ZenServerInstance Instance(TestEnv);
+ const auto Bucket = "test"sv;
+
+ Cfg.Spawn(Instance);
+
+ zen::IoHash Key;
+ zen::IoHash PayloadId;
+ zen::CbPackage Package = GeneratePackage(Key, PayloadId);
+ auto Buf = ToBuffer(Package);
+
+ // Store package
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", Cfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)Buf.GetData(), Buf.GetSize()},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ // Get package
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cbpkg"}});
+ CHECK(IsHttpSuccessCode(Result.status_code));
+ IoBuffer Buffer(IoBuffer::Wrap, Result.text.c_str(), Result.text.size());
+ CbPackage ResponsePackage;
+ CHECK(ResponsePackage.TryLoad(Buffer));
+ CHECK(ResponsePackage.GetAttachments().size() == 0);
+ }
+
+ // Get record
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/x-ue-cb"}});
+ CHECK(IsHttpSuccessCode(Result.status_code));
+ IoBuffer Buffer(IoBuffer::Wrap, Result.text.c_str(), Result.text.size());
+ CbObject ResponseObject = zen::LoadCompactBinaryObject(Buffer);
+ CHECK((bool)ResponseObject);
+ }
+
+ // Get payload
+ {
+ cpr::Response Result =
+ cpr::Get(cpr::Url{fmt::format("{}/{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key, PayloadId)},
+ cpr::Header{{"Accept", "application/x-ue-comp"}});
+ CHECK(IsHttpSuccessCode(Result.status_code));
+ CHECK(Result.text.size() == 0);
+ }
+ }
+
+ SUBCASE("skip - 'data' returns empty binary value")
+ {
+ ZenConfig Cfg = ZenConfig::New();
+ ZenServerInstance Instance(TestEnv);
+ const auto Bucket = "test"sv;
+
+ Cfg.Spawn(Instance);
+
+ zen::IoHash Key;
+ auto BinaryValue = GenerateData(1024, Key);
+
+ // Store binary cache value
+ {
+ cpr::Response Result = cpr::Put(cpr::Url{fmt::format("{}/{}/{}", Cfg.BaseUri, Bucket, Key)},
+ cpr::Body{(const char*)BinaryValue.GetData(), BinaryValue.GetSize()},
+ cpr::Header{{"Content-Type", "application/octet-stream"}});
+ CHECK(Result.status_code == 201);
+ }
+
+ // Get package
+ {
+ cpr::Response Result = cpr::Get(cpr::Url{fmt::format("{}/{}/{}?Policy=Default,SkipData", Cfg.BaseUri, Bucket, Key)},
+ cpr::Header{{"Accept", "application/octet-stream"}});
+ CHECK(IsHttpSuccessCode(Result.status_code));
+ CHECK(Result.text.size() == 0);
+ }
+ }
+}
+
+TEST_CASE("zcache.rpc")
+{
+ using namespace std::literals;
+
+ auto AppendCacheRecord = [](cacherequests::PutCacheRecordsRequest& Request,
+ const zen::CacheKey& CacheKey,
+ size_t PayloadSize,
+ CachePolicy RecordPolicy) {
+ std::vector<uint8_t> Data;
+ Data.resize(PayloadSize);
+ uint32_t DataSeed = *reinterpret_cast<const uint32_t*>(&CacheKey.Hash.Hash[0]);
+ uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data());
+ for (size_t Idx = 0; Idx < PayloadSize / 2; ++Idx)
+ {
+ DataPtr[Idx] = static_cast<uint16_t>((Idx + DataSeed) % 0xffffu);
+ }
+ if (PayloadSize & 1)
+ {
+ Data[PayloadSize - 1] = static_cast<uint8_t>((PayloadSize - 1) & 0xff);
+ }
+ CompressedBuffer Value = zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size()));
+ Request.Requests.push_back({.Key = CacheKey, .Values = {{.Id = Oid::NewOid(), .Body = std::move(Value)}}, .Policy = RecordPolicy});
+ };
+
+ auto PutCacheRecords = [&AppendCacheRecord](std::string_view BaseUri,
+ std::string_view Namespace,
+ std::string_view Bucket,
+ size_t Num,
+ size_t PayloadSize = 1024,
+ size_t KeyOffset = 1) -> std::vector<CacheKey> {
+ std::vector<zen::CacheKey> OutKeys;
+
+ for (uint32_t Key = 1; Key <= Num; ++Key)
+ {
+ zen::IoHash KeyHash;
+ ((uint32_t*)(KeyHash.Hash))[0] = gsl::narrow<uint32_t>(KeyOffset + Key);
+ const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, KeyHash);
+
+ cacherequests::PutCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, .Namespace = std::string(Namespace)};
+ AppendCacheRecord(Request, CacheKey, PayloadSize, CachePolicy::Default);
+ OutKeys.push_back(CacheKey);
+
+ CbPackage Package;
+ CHECK(Request.Format(Package));
+
+ IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer();
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+
+ CHECK(Result.status_code == 200);
+ }
+
+ return OutKeys;
+ };
+
+ struct GetCacheRecordResult
+ {
+ zen::CbPackage Response;
+ cacherequests::GetCacheRecordsResult Result;
+ bool Success;
+ };
+
+ auto GetCacheRecords = [](std::string_view BaseUri,
+ std::string_view Namespace,
+ std::span<zen::CacheKey> Keys,
+ zen::CachePolicy Policy,
+ zen::RpcAcceptOptions AcceptOptions = zen::RpcAcceptOptions::kNone,
+ int Pid = 0) -> GetCacheRecordResult {
+ cacherequests::GetCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic,
+ .AcceptOptions = static_cast<uint16_t>(AcceptOptions),
+ .ProcessPid = Pid,
+ .DefaultPolicy = Policy,
+ .Namespace = std::string(Namespace)};
+ for (const CacheKey& Key : Keys)
+ {
+ Request.Requests.push_back({.Key = Key});
+ }
+
+ CbObjectWriter RequestWriter;
+ CHECK(Request.Format(RequestWriter));
+
+ BinaryWriter Body;
+ RequestWriter.Save(Body);
+
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+
+ GetCacheRecordResult OutResult;
+
+ if (Result.status_code == 200)
+ {
+ CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()));
+ if (!Response.IsNull())
+ {
+ OutResult.Response = std::move(Response);
+ CHECK(OutResult.Result.Parse(OutResult.Response));
+ OutResult.Success = true;
+ }
+ }
+
+ return OutResult;
+ };
+
+ SUBCASE("get cache records")
+ {
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const uint16_t PortNumber = 13337;
+ const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber);
+
+ ZenServerInstance Inst(TestEnv);
+ Inst.SetTestDir(TestDir);
+ Inst.SpawnServer(PortNumber);
+ Inst.WaitUntilReady();
+
+ CachePolicy Policy = CachePolicy::Default;
+ std::vector<zen::CacheKey> Keys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 128);
+ GetCacheRecordResult Result = GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, Policy);
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ const CacheKey& ExpectedKey = Keys[Index++];
+ CHECK(Record);
+ CHECK(Record->Key == ExpectedKey);
+ CHECK(Record->Values.size() == 1);
+
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ CHECK(Value.Body);
+ }
+ }
+ }
+
+ SUBCASE("get missing cache records")
+ {
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const uint16_t PortNumber = 13337;
+ const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber);
+
+ ZenServerInstance Inst(TestEnv);
+ Inst.SetTestDir(TestDir);
+ Inst.SpawnServer(PortNumber);
+ Inst.WaitUntilReady();
+
+ CachePolicy Policy = CachePolicy::Default;
+ std::vector<zen::CacheKey> ExistingKeys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 128);
+ std::vector<zen::CacheKey> Keys;
+
+ for (const zen::CacheKey& Key : ExistingKeys)
+ {
+ Keys.push_back(Key);
+ Keys.push_back(CacheKey::Create("missing"sv, IoHash::Zero));
+ }
+
+ GetCacheRecordResult Result = GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, Policy);
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ size_t KeyIndex = 0;
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ const bool Missing = Index++ % 2 != 0;
+
+ if (Missing)
+ {
+ CHECK(!Record);
+ }
+ else
+ {
+ const CacheKey& ExpectedKey = ExistingKeys[KeyIndex++];
+ CHECK(Record->Key == ExpectedKey);
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ CHECK(Value.Body);
+ }
+ }
+ }
+ }
+
+ SUBCASE("policy - 'QueryLocal' does not query upstream")
+ {
+ using namespace utils;
+
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamServer(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalServer(TestEnv);
+
+ SpawnServer(UpstreamServer, UpstreamCfg);
+ SpawnServer(LocalServer, LocalCfg);
+
+ std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "ue4.ddc"sv, "mastodon"sv, 4);
+
+ CachePolicy Policy = CachePolicy::QueryLocal;
+ GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, "ue4.ddc"sv, Keys, Policy);
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ CHECK(!Record);
+ }
+ }
+
+ SUBCASE("policy - 'QueryRemote' does query upstream")
+ {
+ using namespace utils;
+
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamServer(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalServer(TestEnv);
+
+ SpawnServer(UpstreamServer, UpstreamCfg);
+ SpawnServer(LocalServer, LocalCfg);
+
+ std::vector<zen::CacheKey> Keys = PutCacheRecords(UpstreamCfg.BaseUri, "ue4.ddc"sv, "mastodon"sv, 4);
+
+ CachePolicy Policy = (CachePolicy::QueryLocal | CachePolicy::QueryRemote);
+ GetCacheRecordResult Result = GetCacheRecords(LocalCfg.BaseUri, "ue4.ddc"sv, Keys, Policy);
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ CHECK(Record);
+ const CacheKey& ExpectedKey = Keys[Index++];
+ CHECK(Record->Key == ExpectedKey);
+ }
+ }
+
+ SUBCASE("RpcAcceptOptions")
+ {
+ using namespace utils;
+
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const uint16_t PortNumber = 13337;
+ const auto BaseUri = fmt::format("http://localhost:{}/z$", PortNumber);
+
+ ZenServerInstance Inst(TestEnv);
+ Inst.SetTestDir(TestDir);
+ Inst.SpawnServer(PortNumber);
+ Inst.WaitUntilReady();
+
+ std::vector<zen::CacheKey> SmallKeys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 4, 1024);
+ std::vector<zen::CacheKey> LargeKeys = PutCacheRecords(BaseUri, "ue4.ddc"sv, "mastodon"sv, 4, 1024 * 1024 * 16, SmallKeys.size());
+
+ std::vector<zen::CacheKey> Keys(SmallKeys.begin(), SmallKeys.end());
+ Keys.insert(Keys.end(), LargeKeys.begin(), LargeKeys.end());
+
+ {
+ GetCacheRecordResult Result = GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, CachePolicy::Default);
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ CHECK(Record);
+ const CacheKey& ExpectedKey = Keys[Index++];
+ CHECK(Record->Key == ExpectedKey);
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer();
+ IoBufferFileReference Ref;
+ bool IsFileRef = Body.GetFileReference(Ref);
+ CHECK(!IsFileRef);
+ }
+ }
+ }
+
+ // File path, but only for large files
+ {
+ GetCacheRecordResult Result =
+ GetCacheRecords(BaseUri, "ue4.ddc"sv, Keys, CachePolicy::Default, RpcAcceptOptions::kAllowLocalReferences);
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ CHECK(Record);
+ const CacheKey& ExpectedKey = Keys[Index++];
+ CHECK(Record->Key == ExpectedKey);
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer();
+ IoBufferFileReference Ref;
+ bool IsFileRef = Body.GetFileReference(Ref);
+ CHECK(IsFileRef == (Body.Size() > 1024));
+ }
+ }
+ }
+
+ // File path, for all files
+ {
+ GetCacheRecordResult Result =
+ GetCacheRecords(BaseUri,
+ "ue4.ddc"sv,
+ Keys,
+ CachePolicy::Default,
+ RpcAcceptOptions::kAllowLocalReferences | RpcAcceptOptions::kAllowPartialLocalReferences);
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ CHECK(Record);
+ const CacheKey& ExpectedKey = Keys[Index++];
+ CHECK(Record->Key == ExpectedKey);
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer();
+ IoBufferFileReference Ref;
+ bool IsFileRef = Body.GetFileReference(Ref);
+ CHECK(IsFileRef);
+ }
+ }
+ }
+
+ // File handle, but only for large files
+ {
+ GetCacheRecordResult Result = GetCacheRecords(BaseUri,
+ "ue4.ddc"sv,
+ Keys,
+ CachePolicy::Default,
+ RpcAcceptOptions::kAllowLocalReferences,
+ GetCurrentProcessId());
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ CHECK(Record);
+ const CacheKey& ExpectedKey = Keys[Index++];
+ CHECK(Record->Key == ExpectedKey);
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer();
+ IoBufferFileReference Ref;
+ bool IsFileRef = Body.GetFileReference(Ref);
+ CHECK(IsFileRef == (Body.Size() > 1024));
+ }
+ }
+ }
+
+ // File handle, for all files
+ {
+ GetCacheRecordResult Result =
+ GetCacheRecords(BaseUri,
+ "ue4.ddc"sv,
+ Keys,
+ CachePolicy::Default,
+ RpcAcceptOptions::kAllowLocalReferences | RpcAcceptOptions::kAllowPartialLocalReferences,
+ GetCurrentProcessId());
+
+ CHECK(Result.Result.Results.size() == Keys.size());
+
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ CHECK(Record);
+ const CacheKey& ExpectedKey = Keys[Index++];
+ CHECK(Record->Key == ExpectedKey);
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ const IoBuffer& Body = Value.Body.GetCompressed().Flatten().AsIoBuffer();
+ IoBufferFileReference Ref;
+ bool IsFileRef = Body.GetFileReference(Ref);
+ CHECK(IsFileRef);
+ }
+ }
+ }
+ }
+}
+
+TEST_CASE("zcache.failing.upstream")
+{
+ // This is an exploratory test that takes a long time to run, so lets skip it by default
+ if (true)
+ {
+ return;
+ }
+
+ using namespace std::literals;
+ using namespace utils;
+
+ const uint16_t Upstream1PortNumber = 13338;
+ ZenConfig Upstream1Cfg = ZenConfig::New(Upstream1PortNumber);
+ ZenServerInstance Upstream1Server(TestEnv);
+
+ const uint16_t Upstream2PortNumber = 13339;
+ ZenConfig Upstream2Cfg = ZenConfig::New(Upstream2PortNumber);
+ ZenServerInstance Upstream2Server(TestEnv);
+
+ std::vector<std::uint16_t> UpstreamPorts = {Upstream1PortNumber, Upstream2PortNumber};
+ ZenConfig LocalCfg = ZenConfig::NewWithThreadedUpstreams(UpstreamPorts, false);
+ LocalCfg.Args += (" --upstream-thread-count 2");
+ ZenServerInstance LocalServer(TestEnv);
+ const uint16_t LocalPortNumber = 13337;
+ const auto LocalUri = fmt::format("http://localhost:{}/z$", LocalPortNumber);
+ const auto Upstream1Uri = fmt::format("http://localhost:{}/z$", Upstream1PortNumber);
+ const auto Upstream2Uri = fmt::format("http://localhost:{}/z$", Upstream2PortNumber);
+
+ SpawnServer(Upstream1Server, Upstream1Cfg);
+ SpawnServer(Upstream2Server, Upstream2Cfg);
+ SpawnServer(LocalServer, LocalCfg);
+ bool Upstream1Running = true;
+ bool Upstream2Running = true;
+
+ using namespace std::literals;
+
+ auto AppendCacheRecord = [](cacherequests::PutCacheRecordsRequest& Request,
+ const zen::CacheKey& CacheKey,
+ size_t PayloadSize,
+ CachePolicy RecordPolicy) {
+ std::vector<uint32_t> Data;
+ Data.resize(PayloadSize / 4);
+ for (uint32_t Idx = 0; Idx < PayloadSize / 4; ++Idx)
+ {
+ Data[Idx] = (*reinterpret_cast<const uint32_t*>(&CacheKey.Hash.Hash[0])) + Idx;
+ }
+
+ CompressedBuffer Value = zen::CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size() * 4));
+ Request.Requests.push_back({.Key = CacheKey, .Values = {{.Id = Oid::NewOid(), .Body = std::move(Value)}}, .Policy = RecordPolicy});
+ };
+
+ auto PutCacheRecords = [&AppendCacheRecord](std::string_view BaseUri,
+ std::string_view Namespace,
+ std::string_view Bucket,
+ size_t Num,
+ size_t KeyOffset,
+ size_t PayloadSize = 8192) -> std::vector<CacheKey> {
+ std::vector<zen::CacheKey> OutKeys;
+
+ cacherequests::PutCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic, .Namespace = std::string(Namespace)};
+ for (size_t Key = 1; Key <= Num; ++Key)
+ {
+ zen::IoHash KeyHash;
+ ((size_t*)(KeyHash.Hash))[0] = KeyOffset + Key;
+ const zen::CacheKey CacheKey = zen::CacheKey::Create(Bucket, KeyHash);
+
+ AppendCacheRecord(Request, CacheKey, PayloadSize, CachePolicy::Default);
+ OutKeys.push_back(CacheKey);
+ }
+
+ CbPackage Package;
+ CHECK(Request.Format(Package));
+
+ IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer();
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+
+ if (Result.status_code != 200)
+ {
+ ZEN_DEBUG("PutCacheRecords failed with {}, reason '{}'", Result.status_code, Result.reason);
+ OutKeys.clear();
+ }
+
+ return OutKeys;
+ };
+
+ struct GetCacheRecordResult
+ {
+ zen::CbPackage Response;
+ cacherequests::GetCacheRecordsResult Result;
+ bool Success = false;
+ };
+
+ auto GetCacheRecords = [](std::string_view BaseUri,
+ std::string_view Namespace,
+ std::span<zen::CacheKey> Keys,
+ zen::CachePolicy Policy) -> GetCacheRecordResult {
+ cacherequests::GetCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic,
+ .DefaultPolicy = Policy,
+ .Namespace = std::string(Namespace)};
+ for (const CacheKey& Key : Keys)
+ {
+ Request.Requests.push_back({.Key = Key});
+ }
+
+ CbObjectWriter RequestWriter;
+ CHECK(Request.Format(RequestWriter));
+
+ BinaryWriter Body;
+ RequestWriter.Save(Body);
+
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+
+ GetCacheRecordResult OutResult;
+
+ if (Result.status_code == 200)
+ {
+ CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()));
+ if (!Response.IsNull())
+ {
+ OutResult.Response = std::move(Response);
+ CHECK(OutResult.Result.Parse(OutResult.Response));
+ OutResult.Success = true;
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("GetCacheRecords with {}, reason '{}'", Result.reason, Result.status_code);
+ }
+
+ return OutResult;
+ };
+
+ // Populate with some simple data
+
+ CachePolicy Policy = CachePolicy::Default;
+
+ const size_t ThreadCount = 128;
+ const size_t KeyMultiplier = 16384;
+ const size_t RecordsPerRequest = 64;
+ WorkerThreadPool Pool(ThreadCount);
+
+ std::atomic_size_t Completed = 0;
+
+ auto Keys = new std::vector<CacheKey>[ThreadCount * KeyMultiplier];
+ RwLock KeysLock;
+
+ for (size_t I = 0; I < ThreadCount * KeyMultiplier; I++)
+ {
+ size_t Iteration = I;
+ Pool.ScheduleWork([&] {
+ std::vector<CacheKey> NewKeys = PutCacheRecords(LocalUri, "ue4.ddc"sv, "mastodon"sv, RecordsPerRequest, I * RecordsPerRequest);
+ if (NewKeys.size() != RecordsPerRequest)
+ {
+ ZEN_DEBUG("PutCacheRecords iteration {} failed", Iteration);
+ Completed.fetch_add(1);
+ return;
+ }
+ {
+ RwLock::ExclusiveLockScope _(KeysLock);
+ Keys[Iteration].swap(NewKeys);
+ }
+ Completed.fetch_add(1);
+ });
+ }
+ bool UseUpstream1 = false;
+ while (Completed < ThreadCount * KeyMultiplier)
+ {
+ Sleep(8000);
+
+ if (UseUpstream1)
+ {
+ if (Upstream2Running)
+ {
+ Upstream2Server.EnableTermination();
+ Upstream2Server.Shutdown();
+ Sleep(100);
+ Upstream2Running = false;
+ }
+ if (!Upstream1Running)
+ {
+ SpawnServer(Upstream1Server, Upstream1Cfg);
+ Upstream1Running = true;
+ }
+ UseUpstream1 = !UseUpstream1;
+ }
+ else
+ {
+ if (Upstream1Running)
+ {
+ Upstream1Server.EnableTermination();
+ Upstream1Server.Shutdown();
+ Sleep(100);
+ Upstream1Running = false;
+ }
+ if (!Upstream2Running)
+ {
+ SpawnServer(Upstream2Server, Upstream2Cfg);
+ Upstream2Running = true;
+ }
+ UseUpstream1 = !UseUpstream1;
+ }
+ }
+
+ Completed = 0;
+ for (size_t I = 0; I < ThreadCount * KeyMultiplier; I++)
+ {
+ size_t Iteration = I;
+ std::vector<CacheKey>& LocalKeys = Keys[Iteration];
+ if (LocalKeys.empty())
+ {
+ Completed.fetch_add(1);
+ continue;
+ }
+ Pool.ScheduleWork([&] {
+ GetCacheRecordResult Result = GetCacheRecords(LocalUri, "ue4.ddc"sv, LocalKeys, Policy);
+
+ if (!Result.Success)
+ {
+ ZEN_DEBUG("GetCacheRecords iteration {} failed", Iteration);
+ Completed.fetch_add(1);
+ return;
+ }
+
+ if (Result.Result.Results.size() != LocalKeys.size())
+ {
+ ZEN_DEBUG("GetCacheRecords iteration {} empty records", Iteration);
+ Completed.fetch_add(1);
+ return;
+ }
+ for (size_t Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& Record : Result.Result.Results)
+ {
+ const CacheKey& ExpectedKey = LocalKeys[Index++];
+ if (!Record)
+ {
+ continue;
+ }
+ if (Record->Key != ExpectedKey)
+ {
+ continue;
+ }
+ if (Record->Values.size() != 1)
+ {
+ continue;
+ }
+
+ for (const cacherequests::GetCacheRecordResultValue& Value : Record->Values)
+ {
+ if (!Value.Body)
+ {
+ continue;
+ }
+ }
+ }
+ Completed.fetch_add(1);
+ });
+ }
+ while (Completed < ThreadCount * KeyMultiplier)
+ {
+ Sleep(10);
+ }
+}
+
+TEST_CASE("zcache.rpc.allpolicies")
+{
+ using namespace std::literals;
+ using namespace utils;
+
+ ZenConfig UpstreamCfg = ZenConfig::New(13338);
+ ZenServerInstance UpstreamServer(TestEnv);
+ ZenConfig LocalCfg = ZenConfig::NewWithUpstream(13338);
+ ZenServerInstance LocalServer(TestEnv);
+ const uint16_t LocalPortNumber = 13337;
+ const auto BaseUri = fmt::format("http://localhost:{}/z$", LocalPortNumber);
+
+ SpawnServer(UpstreamServer, UpstreamCfg);
+ SpawnServer(LocalServer, LocalCfg);
+
+ std::string_view TestVersion = "F72150A02AE34B57A9EC91D36BA1CE08"sv;
+ std::string_view TestBucket = "allpoliciestest"sv;
+ std::string_view TestNamespace = "ue4.ddc"sv;
+
+ // NumKeys = (2 Value vs Record)*(2 SkipData vs Default)*(2 ForceMiss vs Not)*(2 use local)
+ // *(2 use remote)*(2 UseValue Policy vs not)*(4 cases per type)
+ constexpr int NumKeys = 256;
+ constexpr int NumValues = 4;
+ Oid ValueIds[NumValues];
+ IoHash Hash;
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ ExtendableStringBuilder<16> ValueName;
+ ValueName << "ValueId_"sv << ValueIndex;
+ static_assert(sizeof(IoHash) >= sizeof(Oid));
+ ValueIds[ValueIndex] = Oid::FromMemory(IoHash::HashBuffer(ValueName.Data(), ValueName.Size() * sizeof(ValueName.Data()[0])).Hash);
+ }
+
+ struct KeyData;
+ struct UserData
+ {
+ UserData& Set(KeyData* InKeyData, int InValueIndex)
+ {
+ Data = InKeyData;
+ ValueIndex = InValueIndex;
+ return *this;
+ }
+ KeyData* Data = nullptr;
+ int ValueIndex = 0;
+ };
+ struct KeyData
+ {
+ CompressedBuffer BufferValues[NumValues];
+ uint64_t IntValues[NumValues];
+ UserData ValueUserData[NumValues];
+ bool ReceivedChunk[NumValues];
+ CacheKey Key;
+ UserData KeyUserData;
+ uint32_t KeyIndex = 0;
+ bool GetRequestsData = true;
+ bool UseValueAPI = false;
+ bool UseValuePolicy = false;
+ bool ForceMiss = false;
+ bool UseLocal = true;
+ bool UseRemote = true;
+ bool ShouldBeHit = true;
+ bool ReceivedPut = false;
+ bool ReceivedGet = false;
+ bool ReceivedPutValue = false;
+ bool ReceivedGetValue = false;
+ };
+ struct CachePutRequest
+ {
+ CacheKey Key;
+ CbObject Record;
+ CacheRecordPolicy Policy;
+ KeyData* Values;
+ UserData* Data;
+ };
+ struct CachePutValueRequest
+ {
+ CacheKey Key;
+ CompressedBuffer Value;
+ CachePolicy Policy;
+ UserData* Data;
+ };
+ struct CacheGetRequest
+ {
+ CacheKey Key;
+ CacheRecordPolicy Policy;
+ UserData* Data;
+ };
+ struct CacheGetValueRequest
+ {
+ CacheKey Key;
+ CachePolicy Policy;
+ UserData* Data;
+ };
+ struct CacheGetChunkRequest
+ {
+ CacheKey Key;
+ Oid ValueId;
+ uint64_t RawOffset;
+ uint64_t RawSize;
+ IoHash RawHash;
+ CachePolicy Policy;
+ UserData* Data;
+ };
+
+ KeyData KeyDatas[NumKeys];
+ std::vector<CachePutRequest> PutRequests;
+ std::vector<CachePutValueRequest> PutValueRequests;
+ std::vector<CacheGetRequest> GetRequests;
+ std::vector<CacheGetValueRequest> GetValueRequests;
+ std::vector<CacheGetChunkRequest> ChunkRequests;
+
+ for (uint32_t KeyIndex = 0; KeyIndex < NumKeys; ++KeyIndex)
+ {
+ IoHashStream KeyWriter;
+ KeyWriter.Append(TestVersion.data(), TestVersion.length() * sizeof(TestVersion.data()[0]));
+ KeyWriter.Append(&KeyIndex, sizeof(KeyIndex));
+ IoHash KeyHash = KeyWriter.GetHash();
+ KeyData& KeyData = KeyDatas[KeyIndex];
+
+ KeyData.Key = CacheKey::Create(TestBucket, KeyHash);
+ KeyData.KeyIndex = KeyIndex;
+ KeyData.GetRequestsData = (KeyIndex & (1 << 1)) == 0;
+ KeyData.UseValueAPI = (KeyIndex & (1 << 2)) != 0;
+ KeyData.UseValuePolicy = (KeyIndex & (1 << 3)) != 0;
+ KeyData.ForceMiss = (KeyIndex & (1 << 4)) == 0;
+ KeyData.UseLocal = (KeyIndex & (1 << 5)) == 0;
+ KeyData.UseRemote = (KeyIndex & (1 << 6)) == 0;
+ KeyData.ShouldBeHit = !KeyData.ForceMiss && (KeyData.UseLocal || KeyData.UseRemote);
+ CachePolicy SharedPolicy = KeyData.UseLocal ? CachePolicy::Local : CachePolicy::None;
+ SharedPolicy |= KeyData.UseRemote ? CachePolicy::Remote : CachePolicy::None;
+ CachePolicy PutPolicy = SharedPolicy;
+ CachePolicy GetPolicy = SharedPolicy;
+ GetPolicy |= !KeyData.GetRequestsData ? CachePolicy::SkipData : CachePolicy::None;
+ CacheKey& Key = KeyData.Key;
+
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ KeyData.IntValues[ValueIndex] = static_cast<uint64_t>(KeyIndex) | (static_cast<uint64_t>(ValueIndex) << 32);
+ KeyData.BufferValues[ValueIndex] =
+ CompressedBuffer::Compress(SharedBuffer::MakeView(&KeyData.IntValues[ValueIndex], sizeof(KeyData.IntValues[ValueIndex])));
+ KeyData.ReceivedChunk[ValueIndex] = false;
+ }
+
+ UserData& KeyUserData = KeyData.KeyUserData.Set(&KeyData, -1);
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ KeyData.ValueUserData[ValueIndex].Set(&KeyData, ValueIndex);
+ }
+ if (!KeyData.UseValueAPI)
+ {
+ CbObjectWriter Builder;
+ Builder.BeginObject("key"sv);
+ Builder << "Bucket"sv << Key.Bucket << "Hash"sv << Key.Hash;
+ Builder.EndObject();
+ Builder.BeginArray("Values"sv);
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ Builder.BeginObject();
+ Builder.AddObjectId("Id"sv, ValueIds[ValueIndex]);
+ Builder.AddBinaryAttachment("RawHash"sv, KeyData.BufferValues[ValueIndex].DecodeRawHash());
+ Builder.AddInteger("RawSize"sv, KeyData.BufferValues[ValueIndex].DecodeRawSize());
+ Builder.EndObject();
+ }
+ Builder.EndArray();
+
+ CacheRecordPolicy PutRecordPolicy;
+ CacheRecordPolicy GetRecordPolicy;
+ if (!KeyData.UseValuePolicy)
+ {
+ PutRecordPolicy = CacheRecordPolicy(PutPolicy);
+ GetRecordPolicy = CacheRecordPolicy(GetPolicy);
+ }
+ else
+ {
+ // Switch the SkipData field in the Record policy so that if the CacheStore ignores the ValuePolicies
+ // it will use the wrong value for SkipData and fail our tests.
+ CacheRecordPolicyBuilder PutBuilder(PutPolicy ^ CachePolicy::SkipData);
+ CacheRecordPolicyBuilder GetBuilder(GetPolicy ^ CachePolicy::SkipData);
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ PutBuilder.AddValuePolicy(ValueIds[ValueIndex], PutPolicy);
+ GetBuilder.AddValuePolicy(ValueIds[ValueIndex], GetPolicy);
+ }
+ PutRecordPolicy = PutBuilder.Build();
+ GetRecordPolicy = GetBuilder.Build();
+ }
+ if (!KeyData.ForceMiss)
+ {
+ PutRequests.push_back({Key, Builder.Save(), PutRecordPolicy, &KeyData, &KeyUserData});
+ }
+ GetRequests.push_back({Key, GetRecordPolicy, &KeyUserData});
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ UserData& ValueUserData = KeyData.ValueUserData[ValueIndex];
+ ChunkRequests.push_back({Key, ValueIds[ValueIndex], 0, UINT64_MAX, IoHash(), GetPolicy, &ValueUserData});
+ }
+ }
+ else
+ {
+ if (!KeyData.ForceMiss)
+ {
+ PutValueRequests.push_back({Key, KeyData.BufferValues[0], PutPolicy, &KeyUserData});
+ }
+ GetValueRequests.push_back({Key, GetPolicy, &KeyUserData});
+ ChunkRequests.push_back({Key, Oid::Zero, 0, UINT64_MAX, IoHash(), GetPolicy, &KeyUserData});
+ }
+ }
+
+ // PutCacheRecords
+ {
+ CachePolicy BatchDefaultPolicy = CachePolicy::Default;
+ cacherequests::PutCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic,
+ .DefaultPolicy = BatchDefaultPolicy,
+ .Namespace = std::string(TestNamespace)};
+ Request.Requests.reserve(PutRequests.size());
+ for (CachePutRequest& PutRequest : PutRequests)
+ {
+ cacherequests::PutCacheRecordRequest& RecordRequest = Request.Requests.emplace_back();
+ RecordRequest.Key = PutRequest.Key;
+ RecordRequest.Policy = PutRequest.Policy;
+ RecordRequest.Values.reserve(NumValues);
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ RecordRequest.Values.push_back({.Id = ValueIds[ValueIndex], .Body = PutRequest.Values->BufferValues[ValueIndex]});
+ }
+ PutRequest.Data->Data->ReceivedPut = true;
+ }
+
+ CbPackage Package;
+ CHECK(Request.Format(Package));
+ IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer();
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+ CHECK_MESSAGE(Result.status_code == 200, "PutCacheRecords unexpectedly failed.");
+ }
+
+ // PutCacheValues
+ {
+ CachePolicy BatchDefaultPolicy = CachePolicy::Default;
+
+ cacherequests::PutCacheValuesRequest Request = {.AcceptMagic = kCbPkgMagic,
+ .DefaultPolicy = BatchDefaultPolicy,
+ .Namespace = std::string(TestNamespace)};
+ Request.Requests.reserve(PutValueRequests.size());
+ for (CachePutValueRequest& PutRequest : PutValueRequests)
+ {
+ Request.Requests.push_back({.Key = PutRequest.Key, .Body = PutRequest.Value, .Policy = PutRequest.Policy});
+ PutRequest.Data->Data->ReceivedPutValue = true;
+ }
+
+ CbPackage Package;
+ CHECK(Request.Format(Package));
+
+ IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer();
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+ CHECK_MESSAGE(Result.status_code == 200, "PutCacheValues unexpectedly failed.");
+ }
+
+ for (KeyData& KeyData : KeyDatas)
+ {
+ if (!KeyData.ForceMiss)
+ {
+ if (!KeyData.UseValueAPI)
+ {
+ CHECK_MESSAGE(KeyData.ReceivedPut, WriteToString<32>("Key ", KeyData.KeyIndex, " was unexpectedly not put.").c_str());
+ }
+ else
+ {
+ CHECK_MESSAGE(KeyData.ReceivedPutValue,
+ WriteToString<32>("Key ", KeyData.KeyIndex, " was unexpectedly not put to ValueAPI.").c_str());
+ }
+ }
+ }
+
+ // GetCacheRecords
+ {
+ CachePolicy BatchDefaultPolicy = CachePolicy::Default;
+ cacherequests::GetCacheRecordsRequest Request = {.AcceptMagic = kCbPkgMagic,
+ .DefaultPolicy = BatchDefaultPolicy,
+ .Namespace = std::string(TestNamespace)};
+ Request.Requests.reserve(GetRequests.size());
+ for (CacheGetRequest& GetRequest : GetRequests)
+ {
+ Request.Requests.push_back({.Key = GetRequest.Key, .Policy = GetRequest.Policy});
+ }
+
+ CbPackage Package;
+ CHECK(Request.Format(Package));
+ IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer();
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+ CHECK_MESSAGE(Result.status_code == 200, "GetCacheRecords unexpectedly failed.");
+ CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()));
+ bool Loaded = !Response.IsNull();
+ CHECK_MESSAGE(Loaded, "GetCacheRecords response failed to load.");
+ cacherequests::GetCacheRecordsResult RequestResult;
+ CHECK(RequestResult.Parse(Response));
+ CHECK_MESSAGE(RequestResult.Results.size() == GetRequests.size(), "GetCacheRecords response count did not match request count.");
+ for (int Index = 0; const std::optional<cacherequests::GetCacheRecordResult>& RecordResult : RequestResult.Results)
+ {
+ bool Succeeded = RecordResult.has_value();
+ CacheGetRequest& GetRequest = GetRequests[Index++];
+ KeyData* KeyData = GetRequest.Data->Data;
+ KeyData->ReceivedGet = true;
+ WriteToString<32> Name("Get(", KeyData->KeyIndex, ")");
+ if (KeyData->ShouldBeHit)
+ {
+ CHECK_MESSAGE(Succeeded, WriteToString<32>(Name, " unexpectedly failed.").c_str());
+ }
+ else if (KeyData->ForceMiss)
+ {
+ CHECK_MESSAGE(!Succeeded, WriteToString<32>(Name, " unexpectedly succeeded.").c_str());
+ }
+ if (!KeyData->ForceMiss && Succeeded)
+ {
+ CHECK_MESSAGE(RecordResult->Values.size() == NumValues,
+ WriteToString<32>(Name, " number of values did not match.").c_str());
+ for (const cacherequests::GetCacheRecordResultValue& Value : RecordResult->Values)
+ {
+ int ExpectedValueIndex = 0;
+ for (; ExpectedValueIndex < NumValues; ++ExpectedValueIndex)
+ {
+ if (ValueIds[ExpectedValueIndex] == Value.Id)
+ {
+ break;
+ }
+ }
+ CHECK_MESSAGE(ExpectedValueIndex < NumValues, WriteToString<32>(Name, " could not find matching ValueId.").c_str());
+
+ WriteToString<32> ValueName("Get(", KeyData->KeyIndex, ",", ExpectedValueIndex, ")");
+
+ CompressedBuffer ExpectedValue = KeyData->BufferValues[ExpectedValueIndex];
+ CHECK_MESSAGE(Value.RawHash == ExpectedValue.DecodeRawHash(),
+ WriteToString<32>(ValueName, " RawHash did not match.").c_str());
+ CHECK_MESSAGE(Value.RawSize == ExpectedValue.DecodeRawSize(),
+ WriteToString<32>(ValueName, " RawSize did not match.").c_str());
+
+ if (KeyData->GetRequestsData)
+ {
+ SharedBuffer Buffer = Value.Body.Decompress();
+ CHECK_MESSAGE(Buffer.GetSize() == Value.RawSize,
+ WriteToString<32>(ValueName, " BufferSize did not match RawSize.").c_str());
+ uint64_t ActualIntValue = ((const uint64_t*)Buffer.GetData())[0];
+ uint64_t ExpectedIntValue = KeyData->IntValues[ExpectedValueIndex];
+ CHECK_MESSAGE(ActualIntValue == ExpectedIntValue, WriteToString<32>(ValueName, " had unexpected data.").c_str());
+ }
+ }
+ }
+ }
+ }
+
+ // GetCacheValues
+ {
+ CachePolicy BatchDefaultPolicy = CachePolicy::Default;
+
+ cacherequests::GetCacheValuesRequest GetCacheValuesRequest = {.AcceptMagic = kCbPkgMagic,
+ .DefaultPolicy = BatchDefaultPolicy,
+ .Namespace = std::string(TestNamespace)};
+ GetCacheValuesRequest.Requests.reserve(GetValueRequests.size());
+ for (CacheGetValueRequest& GetRequest : GetValueRequests)
+ {
+ GetCacheValuesRequest.Requests.push_back({.Key = GetRequest.Key, .Policy = GetRequest.Policy});
+ }
+
+ CbPackage Package;
+ CHECK(GetCacheValuesRequest.Format(Package));
+
+ IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer();
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+ CHECK_MESSAGE(Result.status_code == 200, "GetCacheValues unexpectedly failed.");
+ IoBuffer MessageBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size());
+ CbPackage Response = ParsePackageMessage(MessageBuffer);
+ bool Loaded = !Response.IsNull();
+ CHECK_MESSAGE(Loaded, "GetCacheValues response failed to load.");
+ cacherequests::GetCacheValuesResult GetCacheValuesResult;
+ CHECK(GetCacheValuesResult.Parse(Response));
+ for (int Index = 0; const cacherequests::CacheValueResult& ValueResult : GetCacheValuesResult.Results)
+ {
+ bool Succeeded = ValueResult.RawHash != IoHash::Zero;
+ CacheGetValueRequest& Request = GetValueRequests[Index++];
+ KeyData* KeyData = Request.Data->Data;
+ KeyData->ReceivedGetValue = true;
+ WriteToString<32> Name("GetValue("sv, KeyData->KeyIndex, ")"sv);
+
+ if (KeyData->ShouldBeHit)
+ {
+ CHECK_MESSAGE(Succeeded, WriteToString<32>(Name, " unexpectedly failed.").c_str());
+ }
+ else if (KeyData->ForceMiss)
+ {
+ CHECK_MESSAGE(!Succeeded, WriteToString<32>(Name, "unexpectedly succeeded.").c_str());
+ }
+ if (!KeyData->ForceMiss && Succeeded)
+ {
+ CompressedBuffer ExpectedValue = KeyData->BufferValues[0];
+ CHECK_MESSAGE(ValueResult.RawHash == ExpectedValue.DecodeRawHash(),
+ WriteToString<32>(Name, " RawHash did not match.").c_str());
+ CHECK_MESSAGE(ValueResult.RawSize == ExpectedValue.DecodeRawSize(),
+ WriteToString<32>(Name, " RawSize did not match.").c_str());
+
+ if (KeyData->GetRequestsData)
+ {
+ SharedBuffer Buffer = ValueResult.Body.Decompress();
+ CHECK_MESSAGE(Buffer.GetSize() == ValueResult.RawSize,
+ WriteToString<32>(Name, " BufferSize did not match RawSize.").c_str());
+ uint64_t ActualIntValue = ((const uint64_t*)Buffer.GetData())[0];
+ uint64_t ExpectedIntValue = KeyData->IntValues[0];
+ CHECK_MESSAGE(ActualIntValue == ExpectedIntValue, WriteToString<32>(Name, " had unexpected data.").c_str());
+ }
+ }
+ }
+ }
+
+ // GetCacheChunks
+ {
+ std::sort(ChunkRequests.begin(), ChunkRequests.end(), [](CacheGetChunkRequest& A, CacheGetChunkRequest& B) {
+ return A.Key.Hash < B.Key.Hash;
+ });
+ CachePolicy BatchDefaultPolicy = CachePolicy::Default;
+ cacherequests::GetCacheChunksRequest GetCacheChunksRequest = {.AcceptMagic = kCbPkgMagic,
+ .DefaultPolicy = BatchDefaultPolicy,
+ .Namespace = std::string(TestNamespace)};
+ GetCacheChunksRequest.Requests.reserve(ChunkRequests.size());
+ for (CacheGetChunkRequest& ChunkRequest : ChunkRequests)
+ {
+ GetCacheChunksRequest.Requests.push_back({.Key = ChunkRequest.Key,
+ .ValueId = ChunkRequest.ValueId,
+ .ChunkId = IoHash(),
+ .RawOffset = ChunkRequest.RawOffset,
+ .RawSize = ChunkRequest.RawSize,
+ .Policy = ChunkRequest.Policy});
+ }
+ CbPackage Package;
+ CHECK(GetCacheChunksRequest.Format(Package));
+
+ IoBuffer Body = FormatPackageMessageBuffer(Package).Flatten().AsIoBuffer();
+ cpr::Response Result = cpr::Post(cpr::Url{fmt::format("{}/$rpc", BaseUri)},
+ cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}},
+ cpr::Body{(const char*)Body.GetData(), Body.GetSize()});
+ CHECK_MESSAGE(Result.status_code == 200, "GetCacheChunks unexpectedly failed.");
+ CbPackage Response = ParsePackageMessage(zen::IoBuffer(zen::IoBuffer::Wrap, Result.text.data(), Result.text.size()));
+ bool Loaded = !Response.IsNull();
+ CHECK_MESSAGE(Loaded, "GetCacheChunks response failed to load.");
+ cacherequests::GetCacheChunksResult GetCacheChunksResult;
+ CHECK(GetCacheChunksResult.Parse(Response));
+ CHECK_MESSAGE(GetCacheChunksResult.Results.size() == ChunkRequests.size(),
+ "GetCacheChunks response count did not match request count.");
+
+ for (int Index = 0; const cacherequests::CacheValueResult& ValueResult : GetCacheChunksResult.Results)
+ {
+ bool Succeeded = ValueResult.RawHash != IoHash::Zero;
+
+ CacheGetChunkRequest& Request = ChunkRequests[Index++];
+ KeyData* KeyData = Request.Data->Data;
+ int ValueIndex = Request.Data->ValueIndex >= 0 ? Request.Data->ValueIndex : 0;
+ KeyData->ReceivedChunk[ValueIndex] = true;
+ WriteToString<32> Name("GetChunks("sv, KeyData->KeyIndex, ","sv, ValueIndex, ")"sv);
+
+ if (KeyData->ShouldBeHit)
+ {
+ CHECK_MESSAGE(Succeeded, WriteToString<256>(Name, " unexpectedly failed."sv).c_str());
+ }
+ else if (KeyData->ForceMiss)
+ {
+ CHECK_MESSAGE(!Succeeded, WriteToString<256>(Name, " unexpectedly succeeded."sv).c_str());
+ }
+ if (KeyData->ShouldBeHit && Succeeded)
+ {
+ CompressedBuffer ExpectedValue = KeyData->BufferValues[ValueIndex];
+ CHECK_MESSAGE(ValueResult.RawHash == ExpectedValue.DecodeRawHash(),
+ WriteToString<32>(Name, " had unexpected RawHash.").c_str());
+ CHECK_MESSAGE(ValueResult.RawSize == ExpectedValue.DecodeRawSize(),
+ WriteToString<32>(Name, " had unexpected RawSize.").c_str());
+
+ if (KeyData->GetRequestsData)
+ {
+ SharedBuffer Buffer = ValueResult.Body.Decompress();
+ CHECK_MESSAGE(Buffer.GetSize() == ValueResult.RawSize,
+ WriteToString<32>(Name, " BufferSize did not match RawSize.").c_str());
+ uint64_t ActualIntValue = ((const uint64_t*)Buffer.GetData())[0];
+ uint64_t ExpectedIntValue = KeyData->IntValues[ValueIndex];
+ CHECK_MESSAGE(ActualIntValue == ExpectedIntValue, WriteToString<32>(Name, " had unexpected data.").c_str());
+ }
+ }
+ }
+ }
+
+ for (KeyData& KeyData : KeyDatas)
+ {
+ if (!KeyData.UseValueAPI)
+ {
+ CHECK_MESSAGE(KeyData.ReceivedGet, WriteToString<32>("Get(", KeyData.KeyIndex, ") was unexpectedly not received.").c_str());
+ for (int ValueIndex = 0; ValueIndex < NumValues; ++ValueIndex)
+ {
+ CHECK_MESSAGE(
+ KeyData.ReceivedChunk[ValueIndex],
+ WriteToString<32>("GetChunks(", KeyData.KeyIndex, ",", ValueIndex, ") was unexpectedly not received.").c_str());
+ }
+ }
+ else
+ {
+ CHECK_MESSAGE(KeyData.ReceivedGetValue,
+ WriteToString<32>("GetValue(", KeyData.KeyIndex, ") was unexpectedly not received.").c_str());
+ CHECK_MESSAGE(KeyData.ReceivedChunk[0],
+ WriteToString<32>("GetChunks(", KeyData.KeyIndex, ") was unexpectedly not received.").c_str());
+ }
+ }
+}
+
+class ZenServerTestHelper
+{
+public:
+ ZenServerTestHelper(std::string_view HelperId, int ServerCount) : m_HelperId{HelperId}, m_ServerCount{ServerCount} {}
+ ~ZenServerTestHelper() {}
+
+ void SpawnServers(std::string_view AdditionalServerArgs = std::string_view())
+ {
+ SpawnServers([](ZenServerInstance&) {}, AdditionalServerArgs);
+ }
+
+ void SpawnServers(auto&& Callback, std::string_view AdditionalServerArgs)
+ {
+ ZEN_INFO("{}: spawning {} server instances", m_HelperId, m_ServerCount);
+
+ m_Instances.resize(m_ServerCount);
+
+ for (int i = 0; i < m_ServerCount; ++i)
+ {
+ auto& Instance = m_Instances[i];
+
+ Instance = std::make_unique<ZenServerInstance>(TestEnv);
+ Instance->SetTestDir(TestEnv.CreateNewTestDir());
+
+ Callback(*Instance);
+
+ Instance->SpawnServer(13337 + i, AdditionalServerArgs);
+ }
+
+ for (int i = 0; i < m_ServerCount; ++i)
+ {
+ auto& Instance = m_Instances[i];
+
+ Instance->WaitUntilReady();
+ }
+ }
+
+ ZenServerInstance& GetInstance(int Index) { return *m_Instances[Index]; }
+
+private:
+ std::string m_HelperId;
+ int m_ServerCount = 0;
+ std::vector<std::unique_ptr<ZenServerInstance> > m_Instances;
+};
+
+class IoDispatcher
+{
+public:
+ IoDispatcher(asio::io_context& IoCtx) : m_IoCtx(IoCtx) {}
+ ~IoDispatcher() { Stop(); }
+
+ void Run()
+ {
+ Stop();
+
+ m_Running = true;
+
+ m_IoThread = std::thread([this]() {
+ try
+ {
+ m_IoCtx.run();
+ }
+ catch (std::exception& Error)
+ {
+ m_Error = Error;
+ }
+ });
+ }
+
+ void Stop()
+ {
+ if (m_Running)
+ {
+ m_Running = false;
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ }
+ }
+
+ bool IsRunning() const { return m_Running; }
+
+ const std::exception& Error() { return m_Error; }
+
+private:
+ asio::io_context& m_IoCtx;
+ std::thread m_IoThread;
+ std::exception m_Error;
+ std::atomic_bool m_Running{false};
+};
+
+TEST_CASE("http.basics")
+{
+ using namespace std::literals;
+
+ ZenServerTestHelper Servers{"http.basics"sv, 1};
+ Servers.SpawnServers();
+
+ ZenServerInstance& Instance = Servers.GetInstance(0);
+ const std::string BaseUri = Instance.GetBaseUri();
+
+ {
+ cpr::Response r = cpr::Get(cpr::Url{fmt::format("{}/testing/hello", BaseUri)});
+ CHECK(IsHttpSuccessCode(r.status_code));
+ }
+
+ {
+ cpr::Response r = cpr::Post(cpr::Url{fmt::format("{}/testing/hello", BaseUri)});
+ CHECK_EQ(r.status_code, 404);
+ }
+
+ {
+ cpr::Response r = cpr::Post(cpr::Url{fmt::format("{}/testing/echo", BaseUri)}, cpr::Body{"yoyoyoyo"});
+ CHECK_EQ(r.status_code, 200);
+ CHECK_EQ(r.text, "yoyoyoyo");
+ }
+}
+
+TEST_CASE("http.package")
+{
+ using namespace std::literals;
+
+ ZenServerTestHelper Servers{"http.package"sv, 1};
+ Servers.SpawnServers();
+
+ ZenServerInstance& Instance = Servers.GetInstance(0);
+ const std::string BaseUri = Instance.GetBaseUri();
+
+ static const uint8_t Data1[] = {0, 1, 2, 3};
+ static const uint8_t Data2[] = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+
+ zen::CompressedBuffer AttachmentData1 = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone({Data1, 4}),
+ zen::OodleCompressor::NotSet,
+ zen::OodleCompressionLevel::None);
+ zen::CbAttachment Attach1{AttachmentData1, AttachmentData1.DecodeRawHash()};
+ zen::CompressedBuffer AttachmentData2 = zen::CompressedBuffer::Compress(zen::SharedBuffer::Clone({Data2, 8}),
+ zen::OodleCompressor::NotSet,
+ zen::OodleCompressionLevel::None);
+ zen::CbAttachment Attach2{AttachmentData2, AttachmentData2.DecodeRawHash()};
+
+ zen::CbObjectWriter Writer;
+
+ Writer.AddAttachment("attach1", Attach1);
+ Writer.AddAttachment("attach2", Attach2);
+
+ zen::CbObject CoreObject = Writer.Save();
+
+ zen::CbPackage TestPackage;
+ TestPackage.SetObject(CoreObject);
+ TestPackage.AddAttachment(Attach1);
+ TestPackage.AddAttachment(Attach2);
+
+ zen::HttpClient TestClient(BaseUri);
+ zen::HttpClient::Response Response = TestClient.TransactPackage("/testing/package"sv, TestPackage);
+
+ zen::CbPackage ResponsePackage = ParsePackageMessage(Response.ResponsePayload);
+
+ CHECK_EQ(ResponsePackage, TestPackage);
+}
+
+TEST_CASE("websocket.basic")
+{
+ if (true)
+ {
+ return;
+ }
+
+ std::filesystem::path TestDir = TestEnv.CreateNewTestDir();
+ const uint16_t PortNumber = 13337;
+ const auto MaxWaitTime = std::chrono::seconds(5);
+
+ ZenServerInstance Inst(TestEnv);
+ Inst.SetTestDir(TestDir);
+ Inst.SpawnServer(PortNumber, "--websocket-port=8848"sv);
+ Inst.WaitUntilReady();
+
+ asio::io_context IoCtx;
+ IoDispatcher IoDispatcher(IoCtx);
+ auto WebSocket = WebSocketClient::Create(IoCtx);
+
+ auto ConnectFuture = WebSocket->Connect({.Host = "127.0.0.1", .Port = 8848, .Endpoint = "/zen"});
+ IoDispatcher.Run();
+
+ ConnectFuture.wait_for(MaxWaitTime);
+ CHECK(ConnectFuture.get());
+
+ for (size_t Idx = 0; Idx < 10; Idx++)
+ {
+ CbObjectWriter Request;
+ Request << "Method"sv
+ << "SayHello"sv;
+
+ WebSocketMessage RequestMsg;
+ RequestMsg.SetMessageType(WebSocketMessageType::kRequest);
+ RequestMsg.SetBody(Request.Save());
+
+ auto ResponseFuture = WebSocket->SendRequest(std::move(RequestMsg));
+ ResponseFuture.wait_for(MaxWaitTime);
+
+ CbObject Response = ResponseFuture.get().Body().GetObject();
+ std::string_view Message = Response["Result"].AsString();
+
+ CHECK(Message == "Hello Friend!!"sv);
+ }
+
+ WebSocket->Disconnect();
+
+ IoCtx.stop();
+ IoDispatcher.Stop();
+}
+
+std::string
+OidAsString(const Oid& Id)
+{
+ StringBuilder<25> OidStringBuilder;
+ Id.ToString(OidStringBuilder);
+ return OidStringBuilder.ToString();
+}
+
+CbPackage
+CreateOplogPackage(const Oid& Id, const std::span<const std::pair<Oid, CompressedBuffer> >& Attachments)
+{
+ CbPackage Package;
+ CbObjectWriter Object;
+ Object << "key"sv << OidAsString(Id);
+ if (!Attachments.empty())
+ {
+ Object.BeginArray("bulkdata");
+ for (const auto& Attachment : Attachments)
+ {
+ CbAttachment Attach(Attachment.second, Attachment.second.DecodeRawHash());
+ Object.BeginObject();
+ Object << "id"sv << Attachment.first;
+ Object << "type"sv
+ << "Standard"sv;
+ Object << "data"sv << Attach;
+ Object.EndObject();
+
+ Package.AddAttachment(Attach);
+ ZEN_DEBUG("Added attachment {}", Attach.GetHash());
+ }
+ Object.EndArray();
+ }
+ Package.SetObject(Object.Save());
+ return Package;
+};
+
+std::vector<std::pair<Oid, CompressedBuffer> >
+CreateAttachments(const std::span<const size_t>& Sizes)
+{
+ std::vector<std::pair<Oid, CompressedBuffer> > Result;
+ Result.reserve(Sizes.size());
+ for (size_t Size : Sizes)
+ {
+ std::vector<uint8_t> Data;
+ Data.resize(Size);
+ uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data());
+ for (size_t Idx = 0; Idx < Size / 2; ++Idx)
+ {
+ DataPtr[Idx] = static_cast<uint16_t>(Idx % 0xffffu);
+ }
+ if (Size & 1)
+ {
+ Data[Size - 1] = static_cast<uint8_t>((Size - 1) & 0xff);
+ }
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size()));
+ Result.emplace_back(std::pair<Oid, CompressedBuffer>(Oid::NewOid(), Compressed));
+ }
+ return Result;
+}
+
+cpr::Body
+AsBody(const IoBuffer& Payload)
+{
+ return cpr::Body{(const char*)Payload.GetData(), Payload.Size()};
+};
+
+enum CbWriterMeta
+{
+ BeginObject,
+ EndObject,
+ BeginArray,
+ EndArray
+};
+
+inline CbWriter&
+operator<<(CbWriter& Writer, CbWriterMeta Meta)
+{
+ switch (Meta)
+ {
+ case BeginObject:
+ Writer.BeginObject();
+ break;
+ case EndObject:
+ Writer.EndObject();
+ break;
+ case BeginArray:
+ Writer.BeginArray();
+ break;
+ case EndArray:
+ Writer.EndArray();
+ break;
+ default:
+ ZEN_ASSERT(false);
+ }
+ return Writer;
+}
+
+TEST_CASE("project.remote")
+{
+ using namespace std::literals;
+
+ ZenServerTestHelper Servers("remote", 3);
+ Servers.SpawnServers("--debug");
+
+ std::vector<Oid> OpIds;
+ OpIds.reserve(24);
+ for (size_t I = 0; I < 24; ++I)
+ {
+ OpIds.emplace_back(Oid::NewOid());
+ }
+
+ std::unordered_map<Oid, std::vector<std::pair<Oid, CompressedBuffer> >, Oid::Hasher> Attachments;
+ {
+ std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269,
+ 2257, 3685, 3489, 7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759,
+ 1916, 8210, 2235, 4024, 1582, 5251, 491, 5464, 4607, 8135, 3767, 4045,
+ 4415, 5007, 8876, 6761, 3359, 8526, 4097, 4855, 8225});
+ auto It = AttachmentSizes.begin();
+ Attachments[OpIds[0]] = {};
+ Attachments[OpIds[1]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ Attachments[OpIds[2]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++});
+ Attachments[OpIds[3]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ Attachments[OpIds[4]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++});
+ Attachments[OpIds[5]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++});
+ Attachments[OpIds[6]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ Attachments[OpIds[7]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++});
+ Attachments[OpIds[8]] = CreateAttachments(std::initializer_list<size_t>{});
+ Attachments[OpIds[9]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++});
+ Attachments[OpIds[10]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ Attachments[OpIds[11]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++});
+ Attachments[OpIds[12]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++, *It++});
+ Attachments[OpIds[13]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ Attachments[OpIds[14]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++});
+ Attachments[OpIds[15]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++});
+ Attachments[OpIds[16]] = CreateAttachments(std::initializer_list<size_t>{});
+ Attachments[OpIds[17]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++});
+ Attachments[OpIds[18]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++});
+ Attachments[OpIds[19]] = CreateAttachments(std::initializer_list<size_t>{});
+ Attachments[OpIds[20]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ Attachments[OpIds[21]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ Attachments[OpIds[22]] = CreateAttachments(std::initializer_list<size_t>{*It++, *It++, *It++});
+ Attachments[OpIds[23]] = CreateAttachments(std::initializer_list<size_t>{*It++});
+ ZEN_ASSERT(It == AttachmentSizes.end());
+ }
+
+ auto AddOp = [](const CbObject& Op, std::unordered_map<Oid, uint32_t, Oid::Hasher>& Ops) {
+ XXH3_128Stream KeyHasher;
+ Op["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); });
+ XXH3_128 KeyHash = KeyHasher.GetHash();
+ Oid Id;
+ memcpy(Id.OidBits, &KeyHash, sizeof Id.OidBits);
+ IoBuffer Buffer = Op.GetBuffer().AsIoBuffer();
+ const uint32_t OpCoreHash = uint32_t(XXH3_64bits(Buffer.GetData(), Buffer.GetSize()) & 0xffffFFFF);
+ Ops.insert({Id, OpCoreHash});
+ };
+
+ auto MakeProject = [](cpr::Session& Session, std::string_view UrlBase, std::string_view ProjectName) {
+ CbObjectWriter Project;
+ Project.AddString("id"sv, ProjectName);
+ Project.AddString("root"sv, ""sv);
+ Project.AddString("engine"sv, ""sv);
+ Project.AddString("project"sv, ""sv);
+ Project.AddString("projectfile"sv, ""sv);
+ IoBuffer ProjectPayload = Project.Save().GetBuffer().AsIoBuffer();
+ std::string ProjectRequest = fmt::format("{}/prj/{}", UrlBase, ProjectName);
+ Session.SetUrl({ProjectRequest});
+ Session.SetBody(cpr::Body{(const char*)ProjectPayload.GetData(), ProjectPayload.GetSize()});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ };
+
+ auto MakeOplog = [](cpr::Session& Session, std::string_view UrlBase, std::string_view ProjectName, std::string_view OplogName) {
+ std::string CreateOplogRequest = fmt::format("{}/prj/{}/oplog/{}", UrlBase, ProjectName, OplogName);
+ Session.SetUrl({CreateOplogRequest});
+ Session.SetBody(cpr::Body{});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ };
+
+ auto MakeOp = [](cpr::Session& Session,
+ std::string_view UrlBase,
+ std::string_view ProjectName,
+ std::string_view OplogName,
+ const CbPackage& OpPackage) {
+ std::string CreateOpRequest = fmt::format("{}/prj/{}/oplog/{}/new", UrlBase, ProjectName, OplogName);
+ Session.SetUrl({CreateOpRequest});
+ zen::BinaryWriter MemOut;
+ legacy::SaveCbPackage(OpPackage, MemOut);
+ Session.SetBody(cpr::Body{(const char*)MemOut.Data(), MemOut.Size()});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ };
+
+ cpr::Session Session;
+ MakeProject(Session, Servers.GetInstance(0).GetBaseUri(), "proj0");
+ MakeOplog(Session, Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0");
+
+ std::unordered_map<Oid, uint32_t, Oid::Hasher> SourceOps;
+ for (const Oid& OpId : OpIds)
+ {
+ CbPackage OpPackage = CreateOplogPackage(OpId, Attachments[OpId]);
+ CHECK(OpPackage.GetAttachments().size() == Attachments[OpId].size());
+ AddOp(OpPackage.GetObject(), SourceOps);
+ MakeOp(Session, Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0", OpPackage);
+ }
+
+ std::vector<IoHash> AttachmentHashes;
+ AttachmentHashes.reserve(Attachments.size());
+ for (const auto& AttachmentOplog : Attachments)
+ {
+ for (const auto& Attachment : AttachmentOplog.second)
+ {
+ AttachmentHashes.emplace_back(Attachment.second.DecodeRawHash());
+ }
+ }
+
+ auto MakeCbObjectPayload = [](std::function<void(CbObjectWriter & Writer)> Write) -> IoBuffer {
+ CbObjectWriter Writer;
+ Write(Writer);
+ IoBuffer Result = Writer.Save().GetBuffer().AsIoBuffer();
+ Result.MakeOwned();
+ return Result;
+ };
+
+ auto ValidateAttachments = [&MakeCbObjectPayload, &AttachmentHashes, &Servers, &Session](int ServerIndex,
+ std::string_view Project,
+ std::string_view Oplog) {
+ std::string GetChunksRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(ServerIndex).GetBaseUri(), Project, Oplog);
+ Session.SetUrl({GetChunksRequest});
+ IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "getchunks"sv;
+ Writer << "chunks"sv << BeginArray;
+ for (const IoHash& Chunk : AttachmentHashes)
+ {
+ Writer << Chunk;
+ }
+ Writer << EndArray; // chunks
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ CbPackage ResponsePackage = ParsePackageMessage(IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()));
+ CHECK(ResponsePackage.GetAttachments().size() == AttachmentHashes.size());
+ };
+
+ auto ValidateOplog = [&SourceOps, &AddOp, &Servers, &Session](int ServerIndex, std::string_view Project, std::string_view Oplog) {
+ std::unordered_map<Oid, uint32_t, Oid::Hasher> TargetOps;
+ std::vector<CbObject> ResultingOplog;
+
+ std::string GetOpsRequest =
+ fmt::format("{}/prj/{}/oplog/{}/entries", Servers.GetInstance(ServerIndex).GetBaseUri(), Project, Oplog);
+ Session.SetUrl({GetOpsRequest});
+ cpr::Response Response = Session.Get();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+
+ IoBuffer Payload(IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ CbObject OplogResonse = LoadCompactBinaryObject(Payload);
+ CbArrayView EntriesArray = OplogResonse["entries"sv].AsArrayView();
+
+ for (CbFieldView OpEntry : EntriesArray)
+ {
+ CbObjectView Core = OpEntry.AsObjectView();
+ BinaryWriter Writer;
+ Core.CopyTo(Writer);
+ MemoryView OpView = Writer.GetView();
+ IoBuffer OpBuffer(IoBuffer::Wrap, OpView.GetData(), OpView.GetSize());
+ CbObject Op(SharedBuffer(OpBuffer), CbFieldType::HasFieldType);
+ AddOp(Op, TargetOps);
+ }
+ CHECK(SourceOps == TargetOps);
+ };
+
+ SUBCASE("File")
+ {
+ ScopedTemporaryDirectory TempDir;
+ {
+ std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0");
+ Session.SetUrl({SaveOplogRequest});
+
+ IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "export"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "maxblocksize"sv << 3072u;
+ Writer << "maxchunkembedsize"sv << 1296u;
+ Writer << "force"sv << false;
+ Writer << "file"sv << BeginObject;
+ {
+ Writer << "path"sv << path;
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ {
+ MakeProject(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy");
+ MakeOplog(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ std::string LoadOplogRequest =
+ fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ Session.SetUrl({LoadOplogRequest});
+
+ IoBuffer Payload = MakeCbObjectPayload([&AttachmentHashes, path = TempDir.Path().string()](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "import"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "force"sv << false;
+ Writer << "file"sv << BeginObject;
+ {
+ Writer << "path"sv << path;
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ ValidateAttachments(1, "proj0_copy", "oplog0_copy");
+ ValidateOplog(1, "proj0_copy", "oplog0_copy");
+ }
+
+ SUBCASE("File disable blocks")
+ {
+ ScopedTemporaryDirectory TempDir;
+ {
+ std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0");
+ Session.SetUrl({SaveOplogRequest});
+
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "export"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "maxblocksize"sv << 3072u;
+ Writer << "maxchunkembedsize"sv << 1296u;
+ Writer << "force"sv << false;
+ Writer << "file"sv << BeginObject;
+ {
+ Writer << "path"sv << TempDir.Path().string();
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
+ Writer << "disableblocks"sv << true;
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ {
+ MakeProject(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy");
+ MakeOplog(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ std::string LoadOplogRequest =
+ fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ Session.SetUrl({LoadOplogRequest});
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "import"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "force"sv << false;
+ Writer << "file"sv << BeginObject;
+ {
+ Writer << "path"sv << TempDir.Path().string();
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ ValidateAttachments(1, "proj0_copy", "oplog0_copy");
+ ValidateOplog(1, "proj0_copy", "oplog0_copy");
+ }
+
+ SUBCASE("File force temp blocks")
+ {
+ ScopedTemporaryDirectory TempDir;
+ {
+ std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(0).GetBaseUri(), "proj0", "oplog0");
+ Session.SetUrl({SaveOplogRequest});
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "export"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "maxblocksize"sv << 3072u;
+ Writer << "maxchunkembedsize"sv << 1296u;
+ Writer << "force"sv << false;
+ Writer << "file"sv << BeginObject;
+ {
+ Writer << "path"sv << TempDir.Path().string();
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
+ Writer << "enabletempblocks"sv << true;
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ {
+ MakeProject(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy");
+ MakeOplog(Session, Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ std::string LoadOplogRequest =
+ fmt::format("{}/prj/{}/oplog/{}/rpc", Servers.GetInstance(1).GetBaseUri(), "proj0_copy", "oplog0_copy");
+ Session.SetUrl({LoadOplogRequest});
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "import"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "force"sv << false;
+ Writer << "file"sv << BeginObject;
+ {
+ Writer << "path"sv << TempDir.Path().string();
+ Writer << "name"sv
+ << "proj0_oplog0"sv;
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ ValidateAttachments(1, "proj0_copy", "oplog0_copy");
+ ValidateOplog(1, "proj0_copy", "oplog0_copy");
+ }
+
+ SUBCASE("Zen")
+ {
+ ScopedTemporaryDirectory TempDir;
+ {
+ std::string ExportSourceUri = Servers.GetInstance(0).GetBaseUri();
+ std::string ExportTargetUri = Servers.GetInstance(1).GetBaseUri();
+ MakeProject(Session, ExportTargetUri, "proj0_copy");
+ MakeOplog(Session, ExportTargetUri, "proj0_copy", "oplog0_copy");
+
+ std::string SaveOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", ExportSourceUri, "proj0", "oplog0");
+ Session.SetUrl({SaveOplogRequest});
+
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "export"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "maxblocksize"sv << 3072u;
+ Writer << "maxchunkembedsize"sv << 1296u;
+ Writer << "force"sv << false;
+ Writer << "zen"sv << BeginObject;
+ {
+ Writer << "url"sv << ExportTargetUri.substr(7);
+ Writer << "project"
+ << "proj0_copy";
+ Writer << "oplog"
+ << "oplog0_copy";
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ ValidateAttachments(1, "proj0_copy", "oplog0_copy");
+ ValidateOplog(1, "proj0_copy", "oplog0_copy");
+
+ {
+ std::string ImportSourceUri = Servers.GetInstance(1).GetBaseUri();
+ std::string ImportTargetUri = Servers.GetInstance(2).GetBaseUri();
+ MakeProject(Session, ImportTargetUri, "proj1");
+ MakeOplog(Session, ImportTargetUri, "proj1", "oplog1");
+ std::string LoadOplogRequest = fmt::format("{}/prj/{}/oplog/{}/rpc", ImportTargetUri, "proj1", "oplog1");
+ Session.SetUrl({LoadOplogRequest});
+
+ IoBuffer Payload = MakeCbObjectPayload([&](CbObjectWriter& Writer) {
+ Writer << "method"sv
+ << "import"sv;
+ Writer << "params" << BeginObject;
+ {
+ Writer << "force"sv << false;
+ Writer << "zen"sv << BeginObject;
+ {
+ Writer << "url"sv << ImportSourceUri.substr(7);
+ Writer << "project"
+ << "proj0_copy";
+ Writer << "oplog"
+ << "oplog0_copy";
+ }
+ Writer << EndObject; // "file"
+ }
+ Writer << EndObject; // "params"
+ });
+ Session.SetBody(AsBody(Payload));
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}});
+ cpr::Response Response = Session.Post();
+ CHECK(IsHttpSuccessCode(Response.status_code));
+ }
+ ValidateAttachments(2, "proj1", "oplog1");
+ ValidateOplog(2, "proj1", "oplog1");
+ }
+}
+
+# if 0
+TEST_CASE("lifetime.owner")
+{
+ // This test is designed to verify that the hand-over of sponsor processes is handled
+ // correctly for the case when a second or third process is launched on the same port
+ //
+ // Due to the nature of it, it cannot be
+
+ const uint16_t PortNumber = 23456;
+
+ ZenServerInstance Zen1(TestEnv);
+ std::filesystem::path TestDir1 = TestEnv.CreateNewTestDir();
+ Zen1.SetTestDir(TestDir1);
+ Zen1.SpawnServer(PortNumber);
+ Zen1.WaitUntilReady();
+ Zen1.Detach();
+
+ ZenServerInstance Zen2(TestEnv);
+ std::filesystem::path TestDir2 = TestEnv.CreateNewTestDir();
+ Zen2.SetTestDir(TestDir2);
+ Zen2.SpawnServer(PortNumber);
+ Zen2.WaitUntilReady();
+ Zen2.Detach();
+}
+
+TEST_CASE("lifetime.owner.2")
+{
+ // This test is designed to verify that the hand-over of sponsor processes is handled
+ // correctly for the case when a second or third process is launched on the same port
+ //
+ // Due to the nature of it, it cannot be
+
+ const uint16_t PortNumber = 13456;
+
+ std::filesystem::path TestDir1 = TestEnv.CreateNewTestDir();
+ std::filesystem::path TestDir2 = TestEnv.CreateNewTestDir();
+
+ ZenServerInstance Zen1(TestEnv);
+ Zen1.SetTestDir(TestDir1);
+ Zen1.SpawnServer(PortNumber);
+ Zen1.WaitUntilReady();
+
+ ZenServerInstance Zen2(TestEnv);
+ Zen2.SetTestDir(TestDir2);
+ Zen2.SetOwnerPid(Zen1.GetPid());
+ Zen2.SpawnServer(PortNumber + 1);
+ Zen2.Detach();
+
+ ZenServerInstance Zen3(TestEnv);
+ Zen3.SetTestDir(TestDir2);
+ Zen3.SetOwnerPid(Zen1.GetPid());
+ Zen3.SpawnServer(PortNumber + 1);
+ Zen3.Detach();
+
+ ZenServerInstance Zen4(TestEnv);
+ Zen4.SetTestDir(TestDir2);
+ Zen4.SetOwnerPid(Zen1.GetPid());
+ Zen4.SpawnServer(PortNumber + 1);
+ Zen4.Detach();
+}
+# endif
+
+} // namespace zen::tests
+#else
+int
+main()
+{
+}
+#endif
diff --git a/src/zenserver/admin/admin.cpp b/src/zenserver/admin/admin.cpp
new file mode 100644
index 000000000..7aa1b48d1
--- /dev/null
+++ b/src/zenserver/admin/admin.cpp
@@ -0,0 +1,101 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "admin.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+#include <zenstore/gc.h>
+
+#include <chrono>
+
+namespace zen {
+
+HttpAdminService::HttpAdminService(GcScheduler& Scheduler) : m_GcScheduler(Scheduler)
+{
+ using namespace std::literals;
+
+ m_Router.RegisterRoute(
+ "health",
+ [](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "gc",
+ [this](HttpRouterRequest& Req) {
+ const GcSchedulerStatus Status = m_GcScheduler.Status();
+
+ CbObjectWriter Response;
+ Response << "Status"sv << (GcSchedulerStatus::kIdle == Status ? "Idle"sv : "Running"sv);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "gc",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ GcScheduler::TriggerParams GcParams;
+
+ if (auto Param = Params.GetValue("smallobjects"); Param.empty() == false)
+ {
+ GcParams.CollectSmallObjects = Param == "true"sv;
+ }
+
+ if (auto Param = Params.GetValue("maxcacheduration"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<uint64_t>(Param))
+ {
+ GcParams.MaxCacheDuration = std::chrono::seconds(Value.value());
+ }
+ }
+
+ if (auto Param = Params.GetValue("disksizesoftlimit"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<uint64_t>(Param))
+ {
+ GcParams.DiskSizeSoftLimit = Value.value();
+ }
+ }
+
+ const bool Started = m_GcScheduler.Trigger(GcParams);
+
+ CbObjectWriter Response;
+ Response << "Status"sv << (Started ? "Started"sv : "Running"sv);
+ HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save());
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "",
+ [](HttpRouterRequest& Req) {
+ CbObject Payload = Req.ServerRequest().ReadPayloadObject();
+
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kPost);
+}
+
+HttpAdminService::~HttpAdminService()
+{
+}
+
+const char*
+HttpAdminService::BaseUri() const
+{
+ return "/admin/";
+}
+
+void
+HttpAdminService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+} // namespace zen
diff --git a/src/zenserver/admin/admin.h b/src/zenserver/admin/admin.h
new file mode 100644
index 000000000..9463ffbb3
--- /dev/null
+++ b/src/zenserver/admin/admin.h
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+class GcScheduler;
+
+class HttpAdminService : public zen::HttpService
+{
+public:
+ HttpAdminService(GcScheduler& Scheduler);
+ ~HttpAdminService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ HttpRequestRouter m_Router;
+ GcScheduler& m_GcScheduler;
+};
+
+} // namespace zen
diff --git a/src/zenserver/auth/authmgr.cpp b/src/zenserver/auth/authmgr.cpp
new file mode 100644
index 000000000..4cd6b3362
--- /dev/null
+++ b/src/zenserver/auth/authmgr.cpp
@@ -0,0 +1,506 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/authmgr.h>
+#include <auth/oidc.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/crypto.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+
+#include <condition_variable>
+#include <memory>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+#include <fmt/format.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace details {
+ IoBuffer ReadEncryptedFile(std::filesystem::path Path,
+ const AesKey256Bit& Key,
+ const AesIV128Bit& IV,
+ std::optional<std::string>& Reason)
+ {
+ FileContents Result = ReadFile(Path);
+
+ if (Result.ErrorCode)
+ {
+ return IoBuffer();
+ }
+
+ IoBuffer EncryptedBuffer = Result.Flatten();
+
+ if (EncryptedBuffer.GetSize() == 0)
+ {
+ return IoBuffer();
+ }
+
+ std::vector<uint8_t> DecryptionBuffer;
+ DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize);
+
+ MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason);
+
+ if (DecryptedView.IsEmpty())
+ {
+ return IoBuffer();
+ }
+
+ return IoBufferBuilder::MakeCloneFromMemory(DecryptedView);
+ }
+
+ void WriteEncryptedFile(std::filesystem::path Path,
+ IoBuffer FileData,
+ const AesKey256Bit& Key,
+ const AesIV128Bit& IV,
+ std::optional<std::string>& Reason)
+ {
+ if (FileData.GetSize() == 0)
+ {
+ return;
+ }
+
+ std::vector<uint8_t> EncryptionBuffer;
+ EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize);
+
+ MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason);
+
+ if (EncryptedView.IsEmpty())
+ {
+ return;
+ }
+
+ WriteFile(Path, IoBuffer(IoBuffer::Wrap, EncryptedView.GetData(), EncryptedView.GetSize()));
+ }
+} // namespace details
+
+class AuthMgrImpl final : public AuthMgr
+{
+ using Clock = std::chrono::system_clock;
+ using TimePoint = Clock::time_point;
+ using Seconds = std::chrono::seconds;
+
+public:
+ AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth"))
+ {
+ LoadState();
+
+ m_BackgroundThread.Interval = Config.UpdateInterval;
+ m_BackgroundThread.Thread = std::thread(&AuthMgrImpl::BackgroundThreadEntry, this);
+ }
+
+ virtual ~AuthMgrImpl() { Shutdown(); }
+
+ virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
+ {
+ if (OpenIdProviderExist(Params.Name))
+ {
+ ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name);
+ return;
+ }
+
+ if (Params.Name.empty())
+ {
+ ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
+ return;
+ }
+
+ std::unique_ptr<OidcClient> Client =
+ std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});
+
+ if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
+ {
+ ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
+ return;
+ }
+
+ std::string NewProviderName = std::string(Params.Name);
+
+ OpenIdProvider* NewProvider = nullptr;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.contains(NewProviderName))
+ {
+ return;
+ }
+
+ auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>());
+ NewProvider = InsertResult.first->second.get();
+ }
+
+ NewProvider->Name = std::string(Params.Name);
+ NewProvider->Url = std::string(Params.Url);
+ NewProvider->ClientId = std::string(Params.ClientId);
+ NewProvider->HttpClient = std::move(Client);
+
+ ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
+ }
+
+ virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final
+ {
+ if (Params.ProviderName.empty())
+ {
+ ZEN_WARN("trying add OpenID token with invalid provider name");
+ return false;
+ }
+
+ if (Params.RefreshToken.empty())
+ {
+ ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'");
+ return false;
+ }
+
+ auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);
+
+ if (RefreshResult.Ok == false)
+ {
+ ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
+ return false;
+ }
+
+ bool IsNew = false;
+
+ {
+ auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
+ .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
+
+ std::unique_lock _(m_TokenMutex);
+
+ const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token));
+
+ IsNew = InsertResult.second;
+ }
+
+ if (IsNew)
+ {
+ ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName);
+ }
+ else
+ {
+ ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName);
+ }
+
+ return true;
+ }
+
+ virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end())
+ {
+ const OpenIdToken& Token = It->second;
+
+ return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ }
+
+ return {};
+ }
+
+private:
+ bool OpenIdProviderExist(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ return m_OpenIdProviders.contains(std::string(ProviderName));
+ }
+
+ OidcClient& GetOpenIdClient(std::string_view ProviderName)
+ {
+ std::unique_lock _(m_ProviderMutex);
+ return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
+ }
+
+ OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
+ {
+ if (OpenIdProviderExist(ProviderName) == false)
+ {
+ return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
+ }
+
+ OidcClient& Client = GetOpenIdClient(ProviderName);
+
+ return Client.RefreshToken(RefreshToken);
+ }
+
+ void Shutdown()
+ {
+ BackgroundThread::Stop(m_BackgroundThread);
+ SaveState();
+ }
+
+ void LoadState()
+ {
+ try
+ {
+ std::optional<std::string> Reason;
+
+ IoBuffer Buffer =
+ details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason);
+
+ if (!Buffer)
+ {
+ if (Reason)
+ {
+ ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value());
+ }
+
+ return;
+ }
+
+ const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);
+
+ if (ValidationError != CbValidateError::None)
+ {
+ ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
+ return;
+ }
+
+ if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
+ {
+ for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
+ {
+ CbObjectView ProviderObj = ProviderView.AsObjectView();
+
+ std::string_view ProviderName = ProviderObj["Name"].AsString();
+ std::string_view Url = ProviderObj["Url"].AsString();
+ std::string_view ClientId = ProviderObj["ClientId"].AsString();
+
+ AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
+ }
+
+ for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
+ {
+ CbObjectView TokenObj = TokenView.AsObjectView();
+
+ std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
+ std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();
+
+ const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken});
+
+ if (!Ok)
+ {
+ ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
+ }
+ }
+ }
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what());
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+ m_OpenIdProviders.clear();
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+ m_OpenIdTokens.clear();
+ }
+ }
+ }
+
+ void SaveState()
+ {
+ try
+ {
+ CbObjectWriter AuthState;
+
+ {
+ std::unique_lock _(m_ProviderMutex);
+
+ if (m_OpenIdProviders.size() > 0)
+ {
+ AuthState.BeginArray("OpenIdProviders");
+ for (const auto& Kv : m_OpenIdProviders)
+ {
+ AuthState.BeginObject();
+ AuthState << "Name"sv << Kv.second->Name;
+ AuthState << "Url"sv << Kv.second->Url;
+ AuthState << "ClientId"sv << Kv.second->ClientId;
+ AuthState.EndObject();
+ }
+ AuthState.EndArray();
+ }
+ }
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ AuthState.BeginArray("OpenIdTokens");
+ if (m_OpenIdTokens.size() > 0)
+ {
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ AuthState.BeginObject();
+ AuthState << "ProviderName"sv << Kv.first;
+ AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
+ AuthState.EndObject();
+ }
+ }
+ AuthState.EndArray();
+ }
+
+ std::filesystem::create_directories(m_Config.RootDirectory);
+
+ std::optional<std::string> Reason;
+
+ details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv,
+ AuthState.Save().GetBuffer().AsIoBuffer(),
+ m_Config.EncryptionKey,
+ m_Config.EncryptionIV,
+ Reason);
+
+ if (Reason)
+ {
+ ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value());
+ }
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("serialize state FAILED, reason '{}'", Err.what());
+ }
+ }
+
+ void BackgroundThreadEntry()
+ {
+ for (;;)
+ {
+ std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread);
+
+ if (m_BackgroundThread.Running.load() == false)
+ {
+ break;
+ }
+
+ if (SignalStatus != std::cv_status::timeout)
+ {
+ continue;
+ }
+
+ {
+ // Refresh Open ID token(s)
+
+ std::vector<OpenIdTokenMap::value_type> ExpiredTokens;
+
+ {
+ std::unique_lock _(m_TokenMutex);
+
+ for (const auto& Kv : m_OpenIdTokens)
+ {
+ const Seconds ExpiresIn = std::chrono::duration_cast<Seconds>(Kv.second.ExpireTime - Clock::now());
+ const bool Expired = ExpiresIn < Seconds(m_BackgroundThread.Interval * 2);
+
+ if (Expired)
+ {
+ ExpiredTokens.push_back(Kv);
+ }
+ }
+ }
+
+ ZEN_DEBUG("refreshing '{}' OpenID token(s)", ExpiredTokens.size());
+
+ for (const auto& Kv : ExpiredTokens)
+ {
+ OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken);
+
+ if (RefreshResult.Ok)
+ {
+ ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first);
+
+ auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
+ .RefreshToken = RefreshResult.RefreshToken,
+ .AccessToken = fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
+ .ExpireTime = Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};
+
+ {
+ std::unique_lock _(m_TokenMutex);
+ m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token));
+ }
+ }
+ else
+ {
+ ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason);
+ }
+ }
+ }
+ }
+ }
+
+ struct BackgroundThread
+ {
+ std::chrono::seconds Interval{10};
+ std::mutex Mutex;
+ std::condition_variable Signal;
+ std::atomic_bool Running{true};
+ std::thread Thread;
+
+ static void Stop(BackgroundThread& State)
+ {
+ if (State.Running.load())
+ {
+ State.Running.store(false);
+ State.Signal.notify_one();
+ }
+
+ if (State.Thread.joinable())
+ {
+ State.Thread.join();
+ }
+ }
+
+ static std::cv_status WaitForSignal(BackgroundThread& State)
+ {
+ std::unique_lock Lock(State.Mutex);
+ return State.Signal.wait_for(Lock, State.Interval);
+ }
+ };
+
+ struct OpenIdProvider
+ {
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+ std::unique_ptr<OidcClient> HttpClient;
+ };
+
+ struct OpenIdToken
+ {
+ std::string IdentityToken;
+ std::string RefreshToken;
+ std::string AccessToken;
+ TimePoint ExpireTime{};
+ };
+
+ using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
+ using OpenIdTokenMap = std::unordered_map<std::string, OpenIdToken>;
+
+ spdlog::logger& Log() { return m_Log; }
+
+ AuthConfig m_Config;
+ spdlog::logger& m_Log;
+ BackgroundThread m_BackgroundThread;
+ OpenIdProviderMap m_OpenIdProviders;
+ OpenIdTokenMap m_OpenIdTokens;
+ std::mutex m_ProviderMutex;
+ std::shared_mutex m_TokenMutex;
+};
+
+std::unique_ptr<AuthMgr>
+AuthMgr::Create(const AuthConfig& Config)
+{
+ return std::make_unique<AuthMgrImpl>(Config);
+}
+
+} // namespace zen
diff --git a/src/zenserver/auth/authmgr.h b/src/zenserver/auth/authmgr.h
new file mode 100644
index 000000000..054588ab9
--- /dev/null
+++ b/src/zenserver/auth/authmgr.h
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/crypto.h>
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+
+#include <chrono>
+#include <filesystem>
+#include <memory>
+
+namespace zen {
+
+struct AuthConfig
+{
+ std::filesystem::path RootDirectory;
+ std::chrono::seconds UpdateInterval{30};
+ AesKey256Bit EncryptionKey;
+ AesIV128Bit EncryptionIV;
+};
+
+class AuthMgr
+{
+public:
+ virtual ~AuthMgr() = default;
+
+ struct AddOpenIdProviderParams
+ {
+ std::string_view Name;
+ std::string_view Url;
+ std::string_view ClientId;
+ };
+
+ virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) = 0;
+
+ struct AddOpenIdTokenParams
+ {
+ std::string_view ProviderName;
+ std::string_view RefreshToken;
+ };
+
+ virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) = 0;
+
+ struct OpenIdAccessToken
+ {
+ std::string AccessToken;
+ std::chrono::system_clock::time_point ExpireTime{};
+ };
+
+ virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) = 0;
+
+ static std::unique_ptr<AuthMgr> Create(const AuthConfig& Config);
+};
+
+} // namespace zen
diff --git a/src/zenserver/auth/authservice.cpp b/src/zenserver/auth/authservice.cpp
new file mode 100644
index 000000000..1cc679540
--- /dev/null
+++ b/src/zenserver/auth/authservice.cpp
@@ -0,0 +1,91 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/authservice.h>
+
+#include <auth/authmgr.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpAuthService::HttpAuthService(AuthMgr& AuthMgr) : m_AuthMgr(AuthMgr)
+{
+ m_Router.RegisterRoute(
+ "oidc/refreshtoken",
+ [this](HttpRouterRequest& RouterRequest) {
+ HttpServerRequest& ServerRequest = RouterRequest.ServerRequest();
+
+ const HttpContentType ContentType = ServerRequest.RequestContentType();
+
+ if ((ContentType == HttpContentType::kUnknownContentType || ContentType == HttpContentType::kJSON) == false)
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ const IoBuffer Body = ServerRequest.ReadPayload();
+
+ std::string JsonText(reinterpret_cast<const char*>(Body.GetData()), Body.GetSize());
+ std::string JsonError;
+ json11::Json TokenInfo = json11::Json::parse(JsonText, JsonError);
+
+ if (!JsonError.empty())
+ {
+ CbObjectWriter Response;
+ Response << "Result"sv << false;
+ Response << "Error"sv << JsonError;
+
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save());
+ }
+
+ const std::string RefreshToken = TokenInfo["RefreshToken"].string_value();
+ std::string ProviderName = TokenInfo["ProviderName"].string_value();
+
+ if (ProviderName.empty())
+ {
+ ProviderName = "Default"sv;
+ }
+
+ const bool Ok =
+ m_AuthMgr.AddOpenIdToken(AuthMgr::AddOpenIdTokenParams{.ProviderName = ProviderName, .RefreshToken = RefreshToken});
+
+ if (Ok)
+ {
+ ServerRequest.WriteResponse(Ok ? HttpResponseCode::OK : HttpResponseCode::BadRequest);
+ }
+ else
+ {
+ CbObjectWriter Response;
+ Response << "Result"sv << false;
+ Response << "Error"sv
+ << "Invalid token"sv;
+
+ ServerRequest.WriteResponse(HttpResponseCode::BadRequest, Response.Save());
+ }
+ },
+ HttpVerb::kPost);
+}
+
+HttpAuthService::~HttpAuthService()
+{
+}
+
+const char*
+HttpAuthService::BaseUri() const
+{
+ return "/auth/";
+}
+
+void
+HttpAuthService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+} // namespace zen
diff --git a/src/zenserver/auth/authservice.h b/src/zenserver/auth/authservice.h
new file mode 100644
index 000000000..64b86e21f
--- /dev/null
+++ b/src/zenserver/auth/authservice.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+class AuthMgr;
+
+class HttpAuthService final : public zen::HttpService
+{
+public:
+ HttpAuthService(AuthMgr& AuthMgr);
+ virtual ~HttpAuthService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ AuthMgr& m_AuthMgr;
+ HttpRequestRouter m_Router;
+};
+
+} // namespace zen
diff --git a/src/zenserver/auth/oidc.cpp b/src/zenserver/auth/oidc.cpp
new file mode 100644
index 000000000..d2265c22f
--- /dev/null
+++ b/src/zenserver/auth/oidc.cpp
@@ -0,0 +1,127 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <auth/oidc.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace details {
+
+ using StringArray = std::vector<std::string>;
+
+ StringArray ToStringArray(const json11::Json JsonArray)
+ {
+ StringArray Result;
+
+ const auto& Items = JsonArray.array_items();
+
+ for (const auto& Item : Items)
+ {
+ Result.push_back(Item.string_value());
+ }
+
+ return Result;
+ }
+
+} // namespace details
+
+using namespace std::literals;
+
+OidcClient::OidcClient(const OidcClient::Options& Options)
+{
+ m_BaseUrl = std::string(Options.BaseUrl);
+ m_ClientId = std::string(Options.ClientId);
+}
+
+OidcClient::InitResult
+OidcClient::Initialize()
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_BaseUrl << "/.well-known/openid-configuration"sv;
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+
+ cpr::Response Response = Session.Get();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = std::move(Response.reason)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ m_Config = {.Issuer = Json["issuer"].string_value(),
+ .AuthorizationEndpoint = Json["authorization_endpoint"].string_value(),
+ .TokenEndpoint = Json["token_endpoint"].string_value(),
+ .UserInfoEndpoint = Json["userinfo_endpoint"].string_value(),
+ .RegistrationEndpoint = Json["registration_endpoint"].string_value(),
+ .JwksUri = Json["jwks_uri"].string_value(),
+ .SupportedResponseTypes = details::ToStringArray(Json["response_types_supported"]),
+ .SupportedResponseModes = details::ToStringArray(Json["response_modes_supported"]),
+ .SupportedGrantTypes = details::ToStringArray(Json["grant_types_supported"]),
+ .SupportedScopes = details::ToStringArray(Json["scopes_supported"]),
+ .SupportedTokenEndpointAuthMethods = details::ToStringArray(Json["token_endpoint_auth_methods_supported"]),
+ .SupportedClaims = details::ToStringArray(Json["claims_supported"])};
+
+ return {.Ok = true};
+}
+
+OidcClient::RefreshTokenResult
+OidcClient::RefreshToken(std::string_view RefreshToken)
+{
+ const std::string Body = fmt::format("grant_type=refresh_token&refresh_token={}&client_id={}", RefreshToken, m_ClientId);
+
+ cpr::Session Session;
+
+ Session.SetOption(cpr::Url{m_Config.TokenEndpoint.c_str()});
+ Session.SetOption(cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}});
+ Session.SetBody(cpr::Body{Body.data(), Body.size()});
+
+ cpr::Response Response = Session.Post();
+
+ if (Response.error)
+ {
+ return {.Reason = std::move(Response.error.message)};
+ }
+
+ if (Response.status_code != 200)
+ {
+ return {.Reason = fmt::format("{} ({})", Response.reason, Response.text)};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {.Reason = std::move(JsonError)};
+ }
+
+ return {.TokenType = Json["token_type"].string_value(),
+ .AccessToken = Json["access_token"].string_value(),
+ .RefreshToken = Json["refresh_token"].string_value(),
+ .IdentityToken = Json["id_token"].string_value(),
+ .Scope = Json["scope"].string_value(),
+ .ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value()),
+ .Ok = true};
+}
+
+} // namespace zen
diff --git a/src/zenserver/auth/oidc.h b/src/zenserver/auth/oidc.h
new file mode 100644
index 000000000..f43ae3cd7
--- /dev/null
+++ b/src/zenserver/auth/oidc.h
@@ -0,0 +1,76 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/string.h>
+
+#include <vector>
+
+namespace zen {
+
+class OidcClient
+{
+public:
+ struct Options
+ {
+ std::string_view BaseUrl;
+ std::string_view ClientId;
+ };
+
+ OidcClient(const Options& Options);
+ ~OidcClient() = default;
+
+ OidcClient(const OidcClient&) = delete;
+ OidcClient& operator=(const OidcClient&) = delete;
+
+ struct Result
+ {
+ std::string Reason;
+ bool Ok = false;
+ };
+
+ using InitResult = Result;
+
+ InitResult Initialize();
+
+ struct RefreshTokenResult
+ {
+ std::string TokenType;
+ std::string AccessToken;
+ std::string RefreshToken;
+ std::string IdentityToken;
+ std::string Scope;
+ std::string Reason;
+ int64_t ExpiresInSeconds{};
+ bool Ok = false;
+ };
+
+ RefreshTokenResult RefreshToken(std::string_view RefreshToken);
+
+private:
+ using StringArray = std::vector<std::string>;
+
+ struct OpenIdConfiguration
+ {
+ std::string Issuer;
+ std::string AuthorizationEndpoint;
+ std::string TokenEndpoint;
+ std::string UserInfoEndpoint;
+ std::string RegistrationEndpoint;
+ std::string EndSessionEndpoint;
+ std::string DeviceAuthorizationEndpoint;
+ std::string JwksUri;
+ StringArray SupportedResponseTypes;
+ StringArray SupportedResponseModes;
+ StringArray SupportedGrantTypes;
+ StringArray SupportedScopes;
+ StringArray SupportedTokenEndpointAuthMethods;
+ StringArray SupportedClaims;
+ };
+
+ std::string m_BaseUrl;
+ std::string m_ClientId;
+ OpenIdConfiguration m_Config;
+};
+
+} // namespace zen
diff --git a/src/zenserver/cache/cachetracking.cpp b/src/zenserver/cache/cachetracking.cpp
new file mode 100644
index 000000000..9119e3122
--- /dev/null
+++ b/src/zenserver/cache/cachetracking.cpp
@@ -0,0 +1,376 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "cachetracking.h"
+
+#if ZEN_USE_CACHE_TRACKER
+
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinaryvalue.h>
+# include <zencore/endian.h>
+# include <zencore/filesystem.h>
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/string.h>
+
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+# pragma comment(lib, "Rpcrt4.lib") // RocksDB made me do this
+# include <fmt/format.h>
+# include <rocksdb/db.h>
+# include <tsl/robin_map.h>
+# include <tsl/robin_set.h>
+# include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace rocksdb = ROCKSDB_NAMESPACE;
+
+static constinit auto Epoch = std::chrono::time_point<std::chrono::system_clock>{};
+
+static uint64_t
+GetCurrentCacheTimeStamp()
+{
+ auto Duration = std::chrono::system_clock::now() - Epoch;
+ uint64_t Millis = std::chrono::duration_cast<std::chrono::milliseconds>(Duration).count();
+
+ return Millis;
+}
+
+struct CacheAccessSnapshot
+{
+public:
+ void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey)
+ {
+ BucketTracker* Tracker = GetBucket(std::string(BucketSegment));
+
+ Tracker->Track(HashKey);
+ }
+
+ bool SerializeSnapshot(CbObjectWriter& Cbo)
+ {
+ bool Serialized = false;
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (const auto& Kv : m_BucketMapping)
+ {
+ if (m_Buckets[Kv.second]->Size())
+ {
+ Cbo.BeginArray(Kv.first);
+ m_Buckets[Kv.second]->SerializeSnapshotAndClear(Cbo);
+ Cbo.EndArray();
+ Serialized = true;
+ }
+ }
+
+ return Serialized;
+ }
+
+private:
+ struct BucketTracker
+ {
+ mutable RwLock Lock;
+ tsl::robin_set<IoHash> AccessedKeys;
+
+ void Track(const IoHash& HashKey)
+ {
+ if (RwLock::SharedLockScope _(Lock); AccessedKeys.contains(HashKey))
+ {
+ return;
+ }
+
+ RwLock::ExclusiveLockScope _(Lock);
+
+ AccessedKeys.insert(HashKey);
+ }
+
+ void SerializeSnapshotAndClear(CbObjectWriter& Cbo)
+ {
+ RwLock::ExclusiveLockScope _(Lock);
+
+ for (const IoHash& Hash : AccessedKeys)
+ {
+ Cbo.AddHash(Hash);
+ }
+
+ AccessedKeys.clear();
+ }
+
+ size_t Size() const
+ {
+ RwLock::SharedLockScope _(Lock);
+ return AccessedKeys.size();
+ }
+ };
+
+ BucketTracker* GetBucket(const std::string& BucketName)
+ {
+ RwLock::SharedLockScope _(m_Lock);
+
+ if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end())
+ {
+ _.ReleaseNow();
+
+ return AddNewBucket(BucketName);
+ }
+ else
+ {
+ return m_Buckets[It->second].get();
+ }
+ }
+
+ BucketTracker* AddNewBucket(const std::string& BucketName)
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (auto It = m_BucketMapping.find(BucketName); It == m_BucketMapping.end())
+ {
+ const uint32_t BucketIndex = gsl::narrow<uint32_t>(m_Buckets.size());
+ m_Buckets.emplace_back(std::make_unique<BucketTracker>());
+ m_BucketMapping[BucketName] = BucketIndex;
+
+ return m_Buckets[BucketIndex].get();
+ }
+ else
+ {
+ return m_Buckets[It->second].get();
+ }
+ }
+
+ RwLock m_Lock;
+ std::vector<std::unique_ptr<BucketTracker>> m_Buckets;
+ tsl::robin_map<std::string, uint32_t> m_BucketMapping;
+};
+
+struct ZenCacheTracker::Impl
+{
+ Impl(std::filesystem::path StateDirectory)
+ {
+ std::filesystem::path StatsDbPath{StateDirectory / ".zdb"};
+
+ std::string RocksdbPath = StatsDbPath.string();
+
+ ZEN_DEBUG("opening tracker db at '{}'", RocksdbPath);
+
+ rocksdb::DB* Db = nullptr;
+ rocksdb::DBOptions Options;
+ Options.create_if_missing = true;
+
+ std::vector<std::string> ExistingColumnFamilies;
+ rocksdb::Status Status = rocksdb::DB::ListColumnFamilies(Options, RocksdbPath, &ExistingColumnFamilies);
+
+ std::vector<rocksdb::ColumnFamilyDescriptor> ColumnDescriptors;
+
+ if (Status.IsPathNotFound())
+ {
+ ColumnDescriptors.emplace_back(rocksdb::ColumnFamilyDescriptor{rocksdb::kDefaultColumnFamilyName, {}});
+ }
+ else if (Status.ok())
+ {
+ for (const std::string& Column : ExistingColumnFamilies)
+ {
+ rocksdb::ColumnFamilyDescriptor ColumnFamily;
+ ColumnFamily.name = Column;
+ ColumnDescriptors.push_back(ColumnFamily);
+ }
+ }
+ else
+ {
+ throw std::runtime_error(fmt::format("column family iteration failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str());
+ }
+
+ Status = rocksdb::DB::Open(Options, RocksdbPath, ColumnDescriptors, &m_RocksDbColumnHandles, &Db);
+
+ if (!Status.ok())
+ {
+ throw std::runtime_error(fmt::format("database open failed for '{}': '{}'", RocksdbPath, Status.getState()).c_str());
+ }
+
+ m_RocksDb.reset(Db);
+ }
+
+ ~Impl()
+ {
+ for (auto* Column : m_RocksDbColumnHandles)
+ {
+ delete Column;
+ }
+
+ m_RocksDbColumnHandles.clear();
+ }
+
+ struct KeyStruct
+ {
+ uint64_t TimestampLittleEndian;
+ };
+
+ void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey) { m_CurrentSnapshot.TrackAccess(BucketSegment, HashKey); }
+
+ void SaveSnapshot()
+ {
+ CbObjectWriter Cbo;
+
+ if (m_CurrentSnapshot.SerializeSnapshot(Cbo))
+ {
+ IoBuffer SnapshotBuffer = Cbo.Save().GetBuffer().AsIoBuffer();
+
+ const KeyStruct Key{.TimestampLittleEndian = ToNetworkOrder(GetCurrentCacheTimeStamp())};
+ rocksdb::Slice KeySlice{(const char*)&Key, sizeof Key};
+ rocksdb::Slice ValueSlice{(char*)SnapshotBuffer.Data(), SnapshotBuffer.Size()};
+
+ rocksdb::WriteOptions Wo;
+ m_RocksDb->Put(Wo, KeySlice, ValueSlice);
+ }
+ }
+
+ void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback)
+ {
+ rocksdb::ManagedSnapshot Snap(m_RocksDb.get());
+
+ rocksdb::ReadOptions Ro;
+ Ro.snapshot = Snap.snapshot();
+
+ std::unique_ptr<rocksdb::Iterator> It{m_RocksDb->NewIterator(Ro)};
+
+ const KeyStruct ZeroKey{.TimestampLittleEndian = 0};
+ rocksdb::Slice ZeroKeySlice{(const char*)&ZeroKey, sizeof ZeroKey};
+
+ It->Seek(ZeroKeySlice);
+
+ while (It->Valid())
+ {
+ rocksdb::Slice KeySlice = It->key();
+ rocksdb::Slice ValueSlice = It->value();
+
+ if (KeySlice.size() == sizeof(KeyStruct))
+ {
+ IoBuffer ValueBuffer(IoBuffer::Wrap, ValueSlice.data(), ValueSlice.size());
+
+ CbObject Value = LoadCompactBinaryObject(ValueBuffer);
+
+ uint64_t Key = FromNetworkOrder(*reinterpret_cast<const uint64_t*>(KeySlice.data()));
+
+ Callback(Key, Value);
+ }
+
+ It->Next();
+ }
+ }
+
+ std::unique_ptr<rocksdb::DB> m_RocksDb;
+ std::vector<rocksdb::ColumnFamilyHandle*> m_RocksDbColumnHandles;
+ CacheAccessSnapshot m_CurrentSnapshot;
+};
+
+ZenCacheTracker::ZenCacheTracker(std::filesystem::path StateDirectory) : m_Impl(new Impl(StateDirectory))
+{
+}
+
+ZenCacheTracker::~ZenCacheTracker()
+{
+ delete m_Impl;
+}
+
+void
+ZenCacheTracker::TrackAccess(std::string_view BucketSegment, const IoHash& HashKey)
+{
+ m_Impl->TrackAccess(BucketSegment, HashKey);
+}
+
+void
+ZenCacheTracker::SaveSnapshot()
+{
+ m_Impl->SaveSnapshot();
+}
+
+void
+ZenCacheTracker::IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback)
+{
+ m_Impl->IterateSnapshots(std::move(Callback));
+}
+
+# if ZEN_WITH_TESTS
+
+TEST_CASE("z$.tracker")
+{
+ using namespace std::literals;
+
+ const uint64_t t0 = GetCurrentCacheTimeStamp();
+
+ ScopedTemporaryDirectory TempDir;
+
+ ZenCacheTracker Zcs(TempDir.Path());
+
+ tsl::robin_set<IoHash> KeyHashes;
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i);
+
+ KeyHashes.insert(KeyHash);
+
+ Zcs.TrackAccess("foo"sv, KeyHash);
+ }
+
+ for (int i = 0; i < 10000; ++i)
+ {
+ IoHash KeyHash = IoHash::HashBuffer(&i, sizeof i);
+
+ Zcs.TrackAccess("foo"sv, KeyHash);
+ }
+
+ Zcs.SaveSnapshot();
+
+ for (int n = 0; n < 10; ++n)
+ {
+ for (int i = 0; i < 1000; ++i)
+ {
+ const int Index = i + n * 1000;
+ IoHash KeyHash = IoHash::HashBuffer(&Index, sizeof Index);
+
+ Zcs.TrackAccess("foo"sv, KeyHash);
+ }
+
+ Zcs.SaveSnapshot();
+ }
+
+ Zcs.SaveSnapshot();
+
+ const uint64_t t1 = GetCurrentCacheTimeStamp();
+
+ int SnapshotCount = 0;
+
+ Zcs.IterateSnapshots([&](uint64_t TimeStamp, CbObject Snapshot) {
+ CHECK(TimeStamp >= t0);
+ CHECK(TimeStamp <= t1);
+
+ for (auto& Field : Snapshot)
+ {
+ CHECK_EQ(Field.GetName(), "foo"sv);
+
+ const CbArray& Array = Field.AsArray();
+
+ for (const auto& Element : Array)
+ {
+ CHECK(KeyHashes.contains(Element.GetValue().AsHash()));
+ }
+ }
+
+ ++SnapshotCount;
+ });
+
+ CHECK_EQ(SnapshotCount, 11);
+}
+
+# endif
+
+void
+cachetracker_forcelink()
+{
+}
+
+} // namespace zen
+
+#endif // ZEN_USE_CACHE_TRACKER
diff --git a/src/zenserver/cache/cachetracking.h b/src/zenserver/cache/cachetracking.h
new file mode 100644
index 000000000..fdfe1a4c7
--- /dev/null
+++ b/src/zenserver/cache/cachetracking.h
@@ -0,0 +1,41 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+
+#include <stdint.h>
+#include <filesystem>
+#include <functional>
+
+namespace zen {
+
+#define ZEN_USE_CACHE_TRACKER 0
+#if ZEN_USE_CACHE_TRACKER
+
+class CbObject;
+
+/**
+ */
+
+class ZenCacheTracker
+{
+public:
+ ZenCacheTracker(std::filesystem::path StateDirectory);
+ ~ZenCacheTracker();
+
+ void TrackAccess(std::string_view BucketSegment, const IoHash& HashKey);
+ void SaveSnapshot();
+ void IterateSnapshots(std::function<void(uint64_t TimeStamp, CbObject Snapshot)>&& Callback);
+
+private:
+ struct Impl;
+
+ Impl* m_Impl = nullptr;
+};
+
+void cachetracker_forcelink();
+
+#endif // ZEN_USE_CACHE_TRACKER
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcache.cpp b/src/zenserver/cache/structuredcache.cpp
new file mode 100644
index 000000000..90e905bf6
--- /dev/null
+++ b/src/zenserver/cache/structuredcache.cpp
@@ -0,0 +1,3159 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "structuredcache.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/compress.h>
+#include <zencore/enumflags.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zencore/workthreadpool.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/httpshared.h>
+#include <zenutil/cache/cache.h>
+#include <zenutil/cache/rpcrecording.h>
+
+#include "monitoring/httpstats.h"
+#include "structuredcachestore.h"
+#include "upstream/jupiter.h"
+#include "upstream/upstreamcache.h"
+#include "upstream/zen.h"
+#include "zenstore/cidstore.h"
+#include "zenstore/scrubcontext.h"
+
+#include <algorithm>
+#include <atomic>
+#include <filesystem>
+#include <queue>
+#include <thread>
+
+#include <cpr/cpr.h>
+#include <gsl/gsl-lite.hpp>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+#endif
+
+namespace zen {
+
+using namespace std::literals;
+
+//////////////////////////////////////////////////////////////////////////
+
+CachePolicy
+ParseCachePolicy(const HttpServerRequest::QueryParams& QueryParams)
+{
+ std::string_view PolicyText = QueryParams.GetValue("Policy"sv);
+ return !PolicyText.empty() ? zen::ParseCachePolicy(PolicyText) : CachePolicy::Default;
+}
+
+CacheRecordPolicy
+LoadCacheRecordPolicy(CbObjectView Object, CachePolicy DefaultPolicy = CachePolicy::Default)
+{
+ OptionalCacheRecordPolicy Policy = CacheRecordPolicy::Load(Object);
+ return Policy ? std::move(Policy).Get() : CacheRecordPolicy(DefaultPolicy);
+}
+
+struct AttachmentCount
+{
+ uint32_t New = 0;
+ uint32_t Valid = 0;
+ uint32_t Invalid = 0;
+ uint32_t Total = 0;
+};
+
+struct PutRequestData
+{
+ std::string Namespace;
+ CacheKey Key;
+ CbObjectView RecordObject;
+ CacheRecordPolicy Policy;
+};
+
+namespace {
+ static constinit std::string_view HttpZCacheRPCPrefix = "$rpc"sv;
+ static constinit std::string_view HttpZCacheUtilStartRecording = "exec$/start-recording"sv;
+ static constinit std::string_view HttpZCacheUtilStopRecording = "exec$/stop-recording"sv;
+ static constinit std::string_view HttpZCacheUtilReplayRecording = "exec$/replay-recording"sv;
+ static constinit std::string_view HttpZCacheDetailsPrefix = "details$"sv;
+
+ struct HttpRequestData
+ {
+ std::optional<std::string> Namespace;
+ std::optional<std::string> Bucket;
+ std::optional<IoHash> HashKey;
+ std::optional<IoHash> ValueContentId;
+ };
+
+ constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+
+ std::optional<std::string> GetValidNamespaceName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Namespace is invalid, empty namespace is not allowed");
+ return {};
+ }
+
+ if (Name.length() > 64)
+ {
+ ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name);
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet))
+ {
+ ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<std::string> GetValidBucketName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed");
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet))
+ {
+ ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<IoHash> GetValidIoHash(std::string_view Hash)
+ {
+ if (Hash.length() != IoHash::StringLength)
+ {
+ return {};
+ }
+
+ IoHash KeyHash;
+ if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash))
+ {
+ return {};
+ }
+ return KeyHash;
+ }
+
+ bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data)
+ {
+ std::vector<std::string_view> Tokens;
+ uint32_t TokenCount = ForEachStrTok(Key, '/', [&](const std::string_view& Token) {
+ Tokens.push_back(Token);
+ return true;
+ });
+
+ switch (TokenCount)
+ {
+ case 0:
+ return true;
+ case 1:
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ return Data.Namespace.has_value();
+ case 2:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ Data.Namespace = ZenCacheStore::DefaultNamespace;
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ case 3:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key/valueid request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ Data.ValueContentId = GetValidIoHash(Tokens[2]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ Data.Namespace = ZenCacheStore::DefaultNamespace;
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey)
+ {
+ return false;
+ }
+ return true;
+ }
+ case 4:
+ {
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey.has_value())
+ {
+ return false;
+ }
+
+ Data.ValueContentId = GetValidIoHash(Tokens[3]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ default:
+ return false;
+ }
+ }
+
+ std::optional<std::string> GetRpcRequestNamespace(const CbObjectView Params)
+ {
+ CbFieldView NamespaceField = Params["Namespace"sv];
+ if (!NamespaceField)
+ {
+ return std::string(ZenCacheStore::DefaultNamespace);
+ }
+
+ if (NamespaceField.HasError())
+ {
+ return {};
+ }
+ if (!NamespaceField.IsString())
+ {
+ return {};
+ }
+ return GetValidNamespaceName(NamespaceField.AsString());
+ }
+
+ bool GetRpcRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key)
+ {
+ CbFieldView BucketField = KeyView["Bucket"sv];
+ if (BucketField.HasError())
+ {
+ return false;
+ }
+ if (!BucketField.IsString())
+ {
+ return false;
+ }
+ std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString());
+ if (!Bucket.has_value())
+ {
+ return false;
+ }
+ CbFieldView HashField = KeyView["Hash"sv];
+ if (HashField.HasError())
+ {
+ return false;
+ }
+ if (!HashField.IsHash())
+ {
+ return false;
+ }
+ IoHash Hash = HashField.AsHash();
+ Key = CacheKey::Create(*Bucket, Hash);
+ return true;
+ }
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpStructuredCacheService::HttpStructuredCacheService(ZenCacheStore& InCacheStore,
+ CidStore& InCidStore,
+ HttpStatsService& StatsService,
+ HttpStatusService& StatusService,
+ UpstreamCache& UpstreamCache)
+: m_Log(logging::Get("cache"))
+, m_CacheStore(InCacheStore)
+, m_StatsService(StatsService)
+, m_StatusService(StatusService)
+, m_CidStore(InCidStore)
+, m_UpstreamCache(UpstreamCache)
+{
+ m_StatsService.RegisterHandler("z$", *this);
+ m_StatusService.RegisterHandler("z$", *this);
+}
+
+HttpStructuredCacheService::~HttpStructuredCacheService()
+{
+ ZEN_INFO("closing structured cache");
+ m_RequestRecorder.reset();
+
+ m_StatsService.UnregisterHandler("z$", *this);
+ m_StatusService.UnregisterHandler("z$", *this);
+}
+
+const char*
+HttpStructuredCacheService::BaseUri() const
+{
+ return "/z$/";
+}
+
+void
+HttpStructuredCacheService::Flush()
+{
+ m_CacheStore.Flush();
+}
+
+void
+HttpStructuredCacheService::Scrub(ScrubContext& Ctx)
+{
+ if (m_LastScrubTime == Ctx.ScrubTimestamp())
+ {
+ return;
+ }
+
+ m_LastScrubTime = Ctx.ScrubTimestamp();
+
+ m_CidStore.Scrub(Ctx);
+ m_CacheStore.Scrub(Ctx);
+}
+
+void
+HttpStructuredCacheService::HandleDetailsRequest(HttpServerRequest& Request)
+{
+ std::string_view Key = Request.RelativeUri();
+ std::vector<std::string> Tokens;
+ uint32_t TokenCount = ForEachStrTok(Key, '/', [&Tokens](std::string_view Token) {
+ Tokens.push_back(std::string(Token));
+ return true;
+ });
+ std::string FilterNamespace;
+ std::string FilterBucket;
+ std::string FilterValue;
+ switch (TokenCount)
+ {
+ case 1:
+ break;
+ case 2:
+ {
+ FilterNamespace = Tokens[1];
+ if (FilterNamespace.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ }
+ break;
+ case 3:
+ {
+ FilterNamespace = Tokens[1];
+ if (FilterNamespace.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ FilterBucket = Tokens[2];
+ if (FilterBucket.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ }
+ break;
+ case 4:
+ {
+ FilterNamespace = Tokens[1];
+ if (FilterNamespace.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ FilterBucket = Tokens[2];
+ if (FilterBucket.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ FilterValue = Tokens[3];
+ if (FilterValue.empty())
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+ }
+ break;
+ default:
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+
+ HttpServerRequest::QueryParams Params = Request.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ std::chrono::seconds NowSeconds = std::chrono::duration_cast<std::chrono::seconds>(GcClock::Now().time_since_epoch());
+ CacheValueDetails ValueDetails = m_CacheStore.GetValueDetails(FilterNamespace, FilterBucket, FilterValue);
+
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ if (AttachmentDetails)
+ {
+ CSVWriter << "Namespace, Bucket, Key, Cid, Size";
+ }
+ else if (Details)
+ {
+ CSVWriter << "Namespace, Bucket, Key, Size, RawSize, RawHash, ContentType, Age, AttachmentsCount, AttachmentsSize";
+ }
+ else
+ {
+ CSVWriter << "Namespace, Bucket, Key";
+ }
+ for (const auto& NamespaceIt : ValueDetails.Namespaces)
+ {
+ const std::string& Namespace = NamespaceIt.first;
+ for (const auto& BucketIt : NamespaceIt.second.Buckets)
+ {
+ const std::string& Bucket = BucketIt.first;
+ for (const auto& ValueIt : BucketIt.second.Values)
+ {
+ if (AttachmentDetails)
+ {
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ CSVWriter << "\r\n"
+ << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << Hash.ToHexString()
+ << ", " << gsl::narrow<uint64_t>(Payload.GetSize());
+ }
+ }
+ else if (Details)
+ {
+ std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>(
+ GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch());
+ CSVWriter << "\r\n"
+ << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString() << ", " << ValueIt.second.Size << ","
+ << ValueIt.second.RawSize << "," << ValueIt.second.RawHash.ToHexString() << ", "
+ << ToString(ValueIt.second.ContentType) << ", " << (NowSeconds.count() - LastAccessedSeconds.count())
+ << ", " << gsl::narrow<uint64_t>(ValueIt.second.Attachments.size());
+ size_t AttachmentsSize = 0;
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ AttachmentsSize += Payload.GetSize();
+ }
+ CSVWriter << ", " << gsl::narrow<uint64_t>(AttachmentsSize);
+ }
+ else
+ {
+ CSVWriter << "\r\n" << Namespace << "," << Bucket << "," << ValueIt.first.ToHexString();
+ }
+ }
+ }
+ }
+ return Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("namespaces");
+ {
+ for (const auto& NamespaceIt : ValueDetails.Namespaces)
+ {
+ const std::string& Namespace = NamespaceIt.first;
+ Cbo.BeginObject();
+ {
+ Cbo.AddString("name", Namespace);
+ Cbo.BeginArray("buckets");
+ {
+ for (const auto& BucketIt : NamespaceIt.second.Buckets)
+ {
+ const std::string& Bucket = BucketIt.first;
+ Cbo.BeginObject();
+ {
+ Cbo.AddString("name", Bucket);
+ Cbo.BeginArray("values");
+ {
+ for (const auto& ValueIt : BucketIt.second.Values)
+ {
+ std::chrono::seconds LastAccessedSeconds = std::chrono::duration_cast<std::chrono::seconds>(
+ GcClock::TimePointFromTick(ValueIt.second.LastAccess).time_since_epoch());
+ Cbo.BeginObject();
+ {
+ Cbo.AddHash("key", ValueIt.first);
+ if (Details)
+ {
+ Cbo.AddInteger("size", ValueIt.second.Size);
+ if (ValueIt.second.Size > 0 && ValueIt.second.RawSize != 0 &&
+ ValueIt.second.RawSize != ValueIt.second.Size)
+ {
+ Cbo.AddInteger("rawsize", ValueIt.second.RawSize);
+ Cbo.AddHash("rawhash", ValueIt.second.RawHash);
+ }
+ Cbo.AddString("contenttype", ToString(ValueIt.second.ContentType));
+ Cbo.AddInteger("age", NowSeconds.count() - LastAccessedSeconds.count());
+ if (ValueIt.second.Attachments.size() > 0)
+ {
+ if (AttachmentDetails)
+ {
+ Cbo.BeginArray("attachments");
+ {
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ Cbo.BeginObject();
+ Cbo.AddHash("cid", Hash);
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ Cbo.AddInteger("size", gsl::narrow<uint64_t>(Payload.GetSize()));
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ }
+ else
+ {
+ Cbo.AddInteger("attachmentcount",
+ gsl::narrow<uint64_t>(ValueIt.second.Attachments.size()));
+ size_t AttachmentsSize = 0;
+ for (const IoHash& Hash : ValueIt.second.Attachments)
+ {
+ IoBuffer Payload = m_CidStore.FindChunkByCid(Hash);
+ AttachmentsSize += Payload.GetSize();
+ }
+ Cbo.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize));
+ }
+ }
+ }
+ }
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ }
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ }
+ Cbo.EndObject();
+ }
+ }
+ Cbo.EndArray();
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+}
+
+void
+HttpStructuredCacheService::HandleRequest(HttpServerRequest& Request)
+{
+ metrics::OperationTiming::Scope $(m_HttpRequests);
+
+ std::string_view Key = Request.RelativeUri();
+ if (Key == HttpZCacheRPCPrefix)
+ {
+ return HandleRpcRequest(Request);
+ }
+
+ if (Key == HttpZCacheUtilStartRecording)
+ {
+ m_RequestRecorder.reset();
+ HttpServerRequest::QueryParams Params = Request.GetQueryParams();
+ std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path")));
+ m_RequestRecorder = cache::MakeDiskRequestRecorder(RecordPath);
+ Request.WriteResponse(HttpResponseCode::OK);
+ return;
+ }
+ if (Key == HttpZCacheUtilStopRecording)
+ {
+ m_RequestRecorder.reset();
+ Request.WriteResponse(HttpResponseCode::OK);
+ return;
+ }
+ if (Key == HttpZCacheUtilReplayRecording)
+ {
+ m_RequestRecorder.reset();
+ HttpServerRequest::QueryParams Params = Request.GetQueryParams();
+ std::string RecordPath = cpr::util::urlDecode(std::string(Params.GetValue("path")));
+ uint32_t ThreadCount = std::thread::hardware_concurrency();
+ if (auto Param = Params.GetValue("thread_count"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<uint64_t>(Param))
+ {
+ ThreadCount = gsl::narrow<uint32_t>(Value.value());
+ }
+ }
+ std::unique_ptr<cache::IRpcRequestReplayer> Replayer(cache::MakeDiskRequestReplayer(RecordPath, false));
+ ReplayRequestRecorder(*Replayer, ThreadCount < 1 ? 1 : ThreadCount);
+ Request.WriteResponse(HttpResponseCode::OK);
+ return;
+ }
+ if (Key.starts_with(HttpZCacheDetailsPrefix))
+ {
+ HandleDetailsRequest(Request);
+ return;
+ }
+
+ HttpRequestData RequestData;
+ if (!HttpRequestParseRelativeUri(Key, RequestData))
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest); // invalid URL
+ }
+
+ if (RequestData.ValueContentId.has_value())
+ {
+ ZEN_ASSERT(RequestData.Namespace.has_value());
+ ZEN_ASSERT(RequestData.Bucket.has_value());
+ ZEN_ASSERT(RequestData.HashKey.has_value());
+ CacheRef Ref = {.Namespace = RequestData.Namespace.value(),
+ .BucketSegment = RequestData.Bucket.value(),
+ .HashKey = RequestData.HashKey.value(),
+ .ValueContentId = RequestData.ValueContentId.value()};
+ return HandleCacheChunkRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams()));
+ }
+
+ if (RequestData.HashKey.has_value())
+ {
+ ZEN_ASSERT(RequestData.Namespace.has_value());
+ ZEN_ASSERT(RequestData.Bucket.has_value());
+ CacheRef Ref = {.Namespace = RequestData.Namespace.value(),
+ .BucketSegment = RequestData.Bucket.value(),
+ .HashKey = RequestData.HashKey.value(),
+ .ValueContentId = IoHash::Zero};
+ return HandleCacheRecordRequest(Request, Ref, ParseCachePolicy(Request.GetQueryParams()));
+ }
+
+ if (RequestData.Bucket.has_value())
+ {
+ ZEN_ASSERT(RequestData.Namespace.has_value());
+ return HandleCacheBucketRequest(Request, RequestData.Namespace.value(), RequestData.Bucket.value());
+ }
+
+ if (RequestData.Namespace.has_value())
+ {
+ return HandleCacheNamespaceRequest(Request, RequestData.Namespace.value());
+ }
+ return HandleCacheRequest(Request);
+}
+
+void
+HttpStructuredCacheService::HandleCacheRequest(HttpServerRequest& Request)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ ZenCacheStore::Info Info = m_CacheStore.GetInfo();
+
+ CbObjectWriter ResponseWriter;
+
+ ResponseWriter.BeginObject("Configuration");
+ {
+ ExtendableStringBuilder<128> BasePathString;
+ BasePathString << Info.Config.BasePath.u8string();
+ ResponseWriter.AddString("BasePath"sv, BasePathString.ToView());
+ ResponseWriter.AddBool("AllowAutomaticCreationOfNamespaces", Info.Config.AllowAutomaticCreationOfNamespaces);
+ }
+ ResponseWriter.EndObject();
+
+ std::sort(begin(Info.NamespaceNames), end(Info.NamespaceNames), [](std::string_view L, std::string_view R) {
+ return L.compare(R) < 0;
+ });
+ ResponseWriter.BeginArray("Namespaces");
+ for (const std::string& NamespaceName : Info.NamespaceNames)
+ {
+ ResponseWriter.AddString(NamespaceName);
+ }
+ ResponseWriter.EndArray();
+ ResponseWriter.BeginObject("StorageSize");
+ {
+ ResponseWriter.AddInteger("DiskSize", Info.StorageSize.DiskSize);
+ ResponseWriter.AddInteger("MemorySize", Info.StorageSize.MemorySize);
+ }
+
+ ResponseWriter.EndObject();
+
+ ResponseWriter.AddInteger("DiskEntryCount", Info.DiskEntryCount);
+ ResponseWriter.AddInteger("MemoryEntryCount", Info.MemoryEntryCount);
+
+ return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ }
+ break;
+ }
+}
+
+void
+HttpStructuredCacheService::HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view NamespaceName)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ std::optional<ZenCacheNamespace::Info> Info = m_CacheStore.GetNamespaceInfo(NamespaceName);
+ if (!Info.has_value())
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter ResponseWriter;
+
+ ResponseWriter.BeginObject("Configuration");
+ {
+ ExtendableStringBuilder<128> BasePathString;
+ BasePathString << Info->Config.RootDir.u8string();
+ ResponseWriter.AddString("RootDir"sv, BasePathString.ToView());
+ ResponseWriter.AddInteger("DiskLayerThreshold"sv, Info->Config.DiskLayerThreshold);
+ }
+ ResponseWriter.EndObject();
+
+ std::sort(begin(Info->BucketNames), end(Info->BucketNames), [](std::string_view L, std::string_view R) {
+ return L.compare(R) < 0;
+ });
+
+ ResponseWriter.BeginArray("Buckets"sv);
+ for (const std::string& BucketName : Info->BucketNames)
+ {
+ ResponseWriter.AddString(BucketName);
+ }
+ ResponseWriter.EndArray();
+
+ ResponseWriter.BeginObject("StorageSize"sv);
+ {
+ ResponseWriter.AddInteger("DiskSize"sv, Info->DiskLayerInfo.TotalSize);
+ ResponseWriter.AddInteger("MemorySize"sv, Info->MemoryLayerInfo.TotalSize);
+ }
+ ResponseWriter.EndObject();
+
+ ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount);
+ ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount);
+
+ return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ // Drop namespace
+ {
+ if (m_CacheStore.DropNamespace(NamespaceName))
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+}
+
+void
+HttpStructuredCacheService::HandleCacheBucketRequest(HttpServerRequest& Request,
+ std::string_view NamespaceName,
+ std::string_view BucketName)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ std::optional<ZenCacheNamespace::BucketInfo> Info = m_CacheStore.GetBucketInfo(NamespaceName, BucketName);
+ if (!Info.has_value())
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter ResponseWriter;
+
+ ResponseWriter.BeginObject("StorageSize");
+ {
+ ResponseWriter.AddInteger("DiskSize", Info->DiskLayerInfo.TotalSize);
+ ResponseWriter.AddInteger("MemorySize", Info->MemoryLayerInfo.TotalSize);
+ }
+ ResponseWriter.EndObject();
+
+ ResponseWriter.AddInteger("DiskEntryCount", Info->DiskLayerInfo.EntryCount);
+ ResponseWriter.AddInteger("MemoryEntryCount", Info->MemoryLayerInfo.EntryCount);
+
+ return Request.WriteResponse(HttpResponseCode::OK, ResponseWriter.Save());
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ // Drop bucket
+ {
+ if (m_CacheStore.DropBucket(NamespaceName, BucketName))
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+}
+
+void
+HttpStructuredCacheService::HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ HandleGetCacheRecord(Request, Ref, PolicyFromUrl);
+ }
+ break;
+
+ case HttpVerb::kPut:
+ HandlePutCacheRecord(Request, Ref, PolicyFromUrl);
+ break;
+ default:
+ break;
+ }
+}
+
+void
+HttpStructuredCacheService::HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+ const ZenContentType AcceptType = Request.AcceptContentType();
+ const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData);
+ const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord);
+
+ bool Success = false;
+ ZenCacheValue ClientResultValue;
+ if (!EnumHasAnyFlags(PolicyFromUrl, CachePolicy::Query))
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+
+ Stopwatch Timer;
+
+ if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal) &&
+ m_CacheStore.Get(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue))
+ {
+ Success = true;
+ ZenContentType ContentType = ClientResultValue.Value.GetContentType();
+
+ if (AcceptType == ZenContentType::kCbPackage)
+ {
+ if (ContentType == ZenContentType::kCbObject)
+ {
+ CbPackage Package;
+ uint32_t MissingCount = 0;
+
+ CbObjectView CacheRecord(ClientResultValue.Value.Data());
+ CacheRecord.IterateAttachments([this, &MissingCount, &Package, SkipData](CbFieldView AttachmentHash) {
+ if (SkipData)
+ {
+ if (!m_CidStore.ContainsChunk(AttachmentHash.AsHash()))
+ {
+ MissingCount++;
+ }
+ }
+ else
+ {
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(AttachmentHash.AsHash()))
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk));
+ Package.AddAttachment(CbAttachment(Compressed, AttachmentHash.AsHash()));
+ }
+ else
+ {
+ MissingCount++;
+ }
+ }
+ });
+
+ Success = MissingCount == 0 || PartialRecord;
+
+ if (Success)
+ {
+ Package.SetObject(LoadCompactBinaryObject(ClientResultValue.Value));
+
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+
+ ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ ClientResultValue.Value.SetContentType(HttpContentType::kCbPackage);
+ }
+ }
+ else
+ {
+ Success = false;
+ }
+ }
+ else if (AcceptType != ClientResultValue.Value.GetContentType() && AcceptType != ZenContentType::kUnknownContentType &&
+ AcceptType != ZenContentType::kBinary)
+ {
+ Success = false;
+ }
+ }
+
+ if (Success)
+ {
+ ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (LOCAL) in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ NiceBytes(ClientResultValue.Value.Size()),
+ ToString(ClientResultValue.Value.GetContentType()),
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+ m_CacheStats.HitCount++;
+ if (SkipData && AcceptType != ZenContentType::kCbPackage && AcceptType != ZenContentType::kCbObject)
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ // kCbPackage handled SkipData when constructing the ClientResultValue, kcbObject ignores SkipData
+ return Request.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value);
+ }
+ }
+ else if (!EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryRemote))
+ {
+ ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(AcceptType),
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.MissCount++;
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ // Issue upstream query asynchronously in order to keep requests flowing without
+ // hogging I/O servicing threads with blocking work
+
+ uint64_t LocalElapsedTimeUs = Timer.GetElapsedTimeUs();
+
+ Request.WriteResponseAsync([this, AcceptType, PolicyFromUrl, Ref, LocalElapsedTimeUs](HttpServerRequest& AsyncRequest) {
+ Stopwatch Timer;
+ bool Success = false;
+ const bool PartialRecord = EnumHasAllFlags(PolicyFromUrl, CachePolicy::PartialRecord);
+ const bool QueryLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::QueryLocal);
+ const bool StoreLocal = EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreLocal);
+ const bool SkipData = EnumHasAllFlags(PolicyFromUrl, CachePolicy::SkipData);
+ ZenCacheValue ClientResultValue;
+
+ metrics::OperationTiming::Scope $(m_UpstreamGetRequestTiming);
+
+ if (GetUpstreamCacheSingleResult UpstreamResult =
+ m_UpstreamCache.GetCacheRecord(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, AcceptType);
+ UpstreamResult.Status.Success)
+ {
+ Success = true;
+
+ ClientResultValue.Value = UpstreamResult.Value;
+ ClientResultValue.Value.SetContentType(AcceptType);
+
+ if (AcceptType == ZenContentType::kBinary || AcceptType == ZenContentType::kCbObject)
+ {
+ if (AcceptType == ZenContentType::kCbObject)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(UpstreamResult.Value, CbValidateMode::All);
+ if (ValidationResult != CbValidateError::None)
+ {
+ Success = false;
+ ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid compact binary object from upstream",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(AcceptType));
+ }
+
+ // We do not do anything to the returned object for SkipData, only package attachments are cut when skipping data
+ }
+
+ if (Success && StoreLocal)
+ {
+ m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, ClientResultValue);
+ }
+ }
+ else if (AcceptType == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+ if (Package.TryLoad(ClientResultValue.Value))
+ {
+ CbObject CacheRecord = Package.GetObject();
+ AttachmentCount Count;
+ size_t NumAttachments = Package.GetAttachments().size();
+ std::vector<const CbAttachment*> AttachmentsToStoreLocally;
+ AttachmentsToStoreLocally.reserve(NumAttachments);
+
+ CacheRecord.IterateAttachments(
+ [this, &Package, &Ref, &AttachmentsToStoreLocally, &Count, QueryLocal, StoreLocal, SkipData](CbFieldView HashView) {
+ IoHash Hash = HashView.AsHash();
+ if (const CbAttachment* Attachment = Package.FindAttachment(Hash))
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ if (StoreLocal)
+ {
+ AttachmentsToStoreLocally.emplace_back(Attachment);
+ }
+ Count.Valid++;
+ }
+ else
+ {
+ ZEN_WARN("Uncompressed value '{}' from upstream cache record '{}/{}'",
+ Hash,
+ Ref.BucketSegment,
+ Ref.HashKey);
+ Count.Invalid++;
+ }
+ }
+ else if (QueryLocal)
+ {
+ if (SkipData)
+ {
+ if (m_CidStore.ContainsChunk(Hash))
+ {
+ Count.Valid++;
+ }
+ }
+ else if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Hash))
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk));
+ if (Compressed)
+ {
+ Package.AddAttachment(CbAttachment(Compressed, Hash));
+ Count.Valid++;
+ }
+ else
+ {
+ ZEN_WARN("Uncompressed value '{}' stored in local cache '{}/{}'",
+ Hash,
+ Ref.BucketSegment,
+ Ref.HashKey);
+ Count.Invalid++;
+ }
+ }
+ }
+ Count.Total++;
+ });
+
+ if ((Count.Valid == Count.Total) || PartialRecord)
+ {
+ ZenCacheValue CacheValue;
+ CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer();
+ CacheValue.Value.SetContentType(ZenContentType::kCbObject);
+
+ if (StoreLocal)
+ {
+ m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue);
+ }
+
+ for (const CbAttachment* Attachment : AttachmentsToStoreLocally)
+ {
+ CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+ CidStore::InsertResult InsertResult =
+ m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+ if (InsertResult.New)
+ {
+ Count.New++;
+ }
+ }
+
+ BinaryWriter MemStream;
+ if (SkipData)
+ {
+ // Save a package containing only the object.
+ CbPackage(Package.GetObject()).Save(MemStream);
+ }
+ else
+ {
+ Package.Save(MemStream);
+ }
+
+ ClientResultValue.Value = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ ClientResultValue.Value.SetContentType(ZenContentType::kCbPackage);
+ }
+ else
+ {
+ Success = false;
+ ZEN_WARN("Get - '{}/{}' '{}' FAILED, attachments missing in upstream package",
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(AcceptType));
+ }
+ }
+ else
+ {
+ Success = false;
+ ZEN_WARN("Get - '{}/{}/{}' '{}' FAILED, invalid upstream package",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(AcceptType));
+ }
+ }
+ }
+
+ if (Success)
+ {
+ ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {} '{}' (UPSTREAM) in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ NiceBytes(ClientResultValue.Value.Size()),
+ ToString(ClientResultValue.Value.GetContentType()),
+ NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000));
+
+ m_CacheStats.HitCount++;
+ m_CacheStats.UpstreamHitCount++;
+
+ if (SkipData && AcceptType == ZenContentType::kBinary)
+ {
+ AsyncRequest.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ // Other methods modify ClientResultValue to a version that has skipped the data but keeps the Object and optionally
+ // metadata.
+ AsyncRequest.WriteResponse(HttpResponseCode::OK, ClientResultValue.Value.GetContentType(), ClientResultValue.Value);
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}' '{}' in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(AcceptType),
+ NiceLatencyNs((LocalElapsedTimeUs + Timer.GetElapsedTimeUs()) * 1000));
+ m_CacheStats.MissCount++;
+ AsyncRequest.WriteResponse(HttpResponseCode::NotFound);
+ }
+ });
+}
+
+void
+HttpStructuredCacheService::HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+ IoBuffer Body = Request.ReadPayload();
+
+ if (!Body || Body.Size() == 0)
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ const HttpContentType ContentType = Request.RequestContentType();
+
+ Body.SetContentType(ContentType);
+
+ Stopwatch Timer;
+
+ if (ContentType == HttpContentType::kBinary || ContentType == HttpContentType::kCompressedBinary)
+ {
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = Body.GetSize();
+ if (ContentType == HttpContentType::kCompressedBinary)
+ {
+ if (!CompressedBuffer::ValidateCompressedHeader(Body, RawHash, RawSize))
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Payload is not a valid compressed binary"sv);
+ }
+ }
+ else
+ {
+ RawHash = IoHash::HashBuffer(SharedBuffer(Body));
+ }
+ m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body, .RawSize = RawSize, .RawHash = RawHash});
+
+ if (EnumHasAllFlags(PolicyFromUrl, CachePolicy::StoreRemote))
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ContentType, .Namespace = Ref.Namespace, .Key = {Ref.BucketSegment, Ref.HashKey}});
+ }
+
+ ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ NiceBytes(Body.Size()),
+ ToString(ContentType),
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ Request.WriteResponse(HttpResponseCode::Created);
+ }
+ else if (ContentType == HttpContentType::kCbObject)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(MemoryView(Body.GetData(), Body.GetSize()), CbValidateMode::All);
+
+ if (ValidationResult != CbValidateError::None)
+ {
+ ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid compact binary",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(ContentType));
+ return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Compact binary validation failed"sv);
+ }
+
+ Body.SetContentType(ZenContentType::kCbObject);
+ m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, {.Value = Body});
+
+ CbObjectView CacheRecord(Body.Data());
+ std::vector<IoHash> ValidAttachments;
+ int32_t TotalCount = 0;
+
+ CacheRecord.IterateAttachments([this, &TotalCount, &ValidAttachments](CbFieldView AttachmentHash) {
+ const IoHash Hash = AttachmentHash.AsHash();
+ if (m_CidStore.ContainsChunk(Hash))
+ {
+ ValidAttachments.emplace_back(Hash);
+ }
+ TotalCount++;
+ });
+
+ ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}' attachments '{}/{}' (valid/total) in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ NiceBytes(Body.Size()),
+ ToString(ContentType),
+ TotalCount,
+ ValidAttachments.size(),
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+ const bool IsPartialRecord = TotalCount != static_cast<int32_t>(ValidAttachments.size());
+
+ CachePolicy Policy = PolicyFromUrl;
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord)
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbObject,
+ .Namespace = Ref.Namespace,
+ .Key = {Ref.BucketSegment, Ref.HashKey},
+ .ValueContentIds = std::move(ValidAttachments)});
+ }
+
+ Request.WriteResponse(HttpResponseCode::Created);
+ }
+ else if (ContentType == HttpContentType::kCbPackage)
+ {
+ CbPackage Package;
+
+ if (!Package.TryLoad(Body))
+ {
+ ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, invalid package",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(ContentType));
+ return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package"sv);
+ }
+ CachePolicy Policy = PolicyFromUrl;
+
+ CbObject CacheRecord = Package.GetObject();
+
+ AttachmentCount Count;
+ size_t NumAttachments = Package.GetAttachments().size();
+ std::vector<IoHash> ValidAttachments;
+ std::vector<const CbAttachment*> AttachmentsToStoreLocally;
+ ValidAttachments.reserve(NumAttachments);
+ AttachmentsToStoreLocally.reserve(NumAttachments);
+
+ CacheRecord.IterateAttachments([this, &Ref, &Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count](CbFieldView HashView) {
+ const IoHash Hash = HashView.AsHash();
+ if (const CbAttachment* Attachment = Package.FindAttachment(Hash))
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ AttachmentsToStoreLocally.emplace_back(Attachment);
+ ValidAttachments.emplace_back(Hash);
+ Count.Valid++;
+ }
+ else
+ {
+ ZEN_WARN("PUTCACHERECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ ToString(HttpContentType::kCbPackage),
+ Hash);
+ Count.Invalid++;
+ }
+ }
+ else if (m_CidStore.ContainsChunk(Hash))
+ {
+ ValidAttachments.emplace_back(Hash);
+ Count.Valid++;
+ }
+ Count.Total++;
+ });
+
+ if (Count.Invalid > 0)
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid attachment(s)"sv);
+ }
+
+ ZenCacheValue CacheValue;
+ CacheValue.Value = CacheRecord.GetBuffer().AsIoBuffer();
+ CacheValue.Value.SetContentType(ZenContentType::kCbObject);
+ m_CacheStore.Put(Ref.Namespace, Ref.BucketSegment, Ref.HashKey, CacheValue);
+
+ for (const CbAttachment* Attachment : AttachmentsToStoreLocally)
+ {
+ CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+ CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+ if (InsertResult.New)
+ {
+ Count.New++;
+ }
+ }
+
+ ZEN_DEBUG("PUTCACHERECORD - '{}/{}/{}' {} '{}', attachments '{}/{}/{}' (new/valid/total) in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ NiceBytes(Body.GetSize()),
+ ToString(ContentType),
+ Count.New,
+ Count.Valid,
+ Count.Total,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+ const bool IsPartialRecord = Count.Valid != Count.Total;
+
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote) && !IsPartialRecord)
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage,
+ .Namespace = Ref.Namespace,
+ .Key = {Ref.BucketSegment, Ref.HashKey},
+ .ValueContentIds = std::move(ValidAttachments)});
+ }
+
+ Request.WriteResponse(HttpResponseCode::Created);
+ }
+ else
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Content-Type invalid"sv);
+ }
+}
+
+void
+HttpStructuredCacheService::HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ HandleGetCacheChunk(Request, Ref, PolicyFromUrl);
+ break;
+ case HttpVerb::kPut:
+ HandlePutCacheChunk(Request, Ref, PolicyFromUrl);
+ break;
+ default:
+ break;
+ }
+}
+
+void
+HttpStructuredCacheService::HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+ Stopwatch Timer;
+
+ IoBuffer Value = m_CidStore.FindChunkByCid(Ref.ValueContentId);
+ const UpstreamEndpointInfo* Source = nullptr;
+ CachePolicy Policy = PolicyFromUrl;
+ {
+ const bool QueryUpstream = !Value && EnumHasAllFlags(Policy, CachePolicy::QueryRemote);
+
+ if (QueryUpstream)
+ {
+ if (GetUpstreamCacheSingleResult UpstreamResult =
+ m_UpstreamCache.GetCacheChunk(Ref.Namespace, {Ref.BucketSegment, Ref.HashKey}, Ref.ValueContentId);
+ UpstreamResult.Status.Success)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(UpstreamResult.Value, RawHash, RawSize))
+ {
+ if (RawHash == Ref.ValueContentId)
+ {
+ m_CidStore.AddChunk(UpstreamResult.Value, RawHash);
+ Source = UpstreamResult.Source;
+ }
+ else
+ {
+ ZEN_WARN("got missmatching upstream cache value");
+ }
+ }
+ else
+ {
+ ZEN_WARN("got uncompressed upstream cache value");
+ }
+ }
+ }
+ }
+
+ if (!Value)
+ {
+ ZEN_DEBUG("GETCACHECHUNK MISS - '{}/{}/{}/{}' '{}' in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ Ref.ValueContentId,
+ ToString(Request.AcceptContentType()),
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.MissCount++;
+ return Request.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ZEN_DEBUG("GETCACHECHUNK HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ Ref.ValueContentId,
+ NiceBytes(Value.Size()),
+ ToString(Value.GetContentType()),
+ Source ? Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+ m_CacheStats.HitCount++;
+ if (Source)
+ {
+ m_CacheStats.UpstreamHitCount++;
+ }
+
+ if (EnumHasAllFlags(Policy, CachePolicy::SkipData))
+ {
+ Request.WriteResponse(HttpResponseCode::OK);
+ }
+ else
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value);
+ }
+}
+
+void
+HttpStructuredCacheService::HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl)
+{
+ // Note: Individual cacherecord values are not propagated upstream until a valid cache record has been stored
+ ZEN_UNUSED(PolicyFromUrl);
+
+ Stopwatch Timer;
+
+ IoBuffer Body = Request.ReadPayload();
+
+ if (!Body || Body.Size() == 0)
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ Body.SetContentType(Request.RequestContentType());
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(Body, RawHash, RawSize))
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Attachments must be compressed"sv);
+ }
+
+ if (RawHash != Ref.ValueContentId)
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "ValueContentId does not match attachment hash"sv);
+ }
+
+ CidStore::InsertResult Result = m_CidStore.AddChunk(Body, RawHash);
+
+ ZEN_DEBUG("PUTCACHECHUNK - '{}/{}/{}/{}' {} '{}' ({}) in {}",
+ Ref.Namespace,
+ Ref.BucketSegment,
+ Ref.HashKey,
+ Ref.ValueContentId,
+ NiceBytes(Body.Size()),
+ ToString(Body.GetContentType()),
+ Result.New ? "NEW" : "OLD",
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+ const HttpResponseCode ResponseCode = Result.New ? HttpResponseCode::Created : HttpResponseCode::OK;
+
+ Request.WriteResponse(ResponseCode);
+}
+
+CbPackage
+HttpStructuredCacheService::HandleRpcRequest(const ZenContentType ContentType,
+ IoBuffer&& Body,
+ uint32_t& OutAcceptMagic,
+ RpcAcceptOptions& OutAcceptFlags,
+ int& OutTargetProcessId)
+{
+ CbPackage Package;
+ CbObjectView Object;
+ CbObject ObjectBuffer;
+ if (ContentType == ZenContentType::kCbObject)
+ {
+ ObjectBuffer = LoadCompactBinaryObject(std::move(Body));
+ Object = ObjectBuffer;
+ }
+ else
+ {
+ Package = ParsePackageMessage(Body);
+ Object = Package.GetObject();
+ }
+ OutAcceptMagic = Object["Accept"sv].AsUInt32();
+ OutAcceptFlags = static_cast<RpcAcceptOptions>(Object["AcceptFlags"sv].AsUInt16(0u));
+ OutTargetProcessId = Object["Pid"sv].AsInt32(0);
+
+ const std::string_view Method = Object["Method"sv].AsString();
+
+ if (Method == "PutCacheRecords"sv)
+ {
+ return HandleRpcPutCacheRecords(Package);
+ }
+ else if (Method == "GetCacheRecords"sv)
+ {
+ return HandleRpcGetCacheRecords(Object);
+ }
+ else if (Method == "PutCacheValues"sv)
+ {
+ return HandleRpcPutCacheValues(Package);
+ }
+ else if (Method == "GetCacheValues"sv)
+ {
+ return HandleRpcGetCacheValues(Object);
+ }
+ else if (Method == "GetCacheChunks"sv)
+ {
+ return HandleRpcGetCacheChunks(Object);
+ }
+ return CbPackage{};
+}
+
+void
+HttpStructuredCacheService::ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount)
+{
+ WorkerThreadPool WorkerPool(ThreadCount);
+ uint64_t RequestCount = Replayer.GetRequestCount();
+ Stopwatch Timer;
+ auto _ = MakeGuard([&]() { ZEN_INFO("Replayed {} requests in {}", RequestCount, NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000)); });
+ Latch JobLatch(RequestCount);
+ ZEN_INFO("Replaying {} requests", RequestCount);
+ for (uint64_t RequestIndex = 0; RequestIndex < RequestCount; ++RequestIndex)
+ {
+ WorkerPool.ScheduleWork([this, &JobLatch, &Replayer, RequestIndex]() {
+ IoBuffer Body;
+ std::pair<ZenContentType, ZenContentType> ContentType = Replayer.GetRequest(RequestIndex, Body);
+ if (Body)
+ {
+ uint32_t AcceptMagic = 0;
+ RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone;
+ int TargetPid = 0;
+ CbPackage RpcResult = HandleRpcRequest(ContentType.first, std::move(Body), AcceptMagic, AcceptFlags, TargetPid);
+ if (AcceptMagic == kCbPkgMagic)
+ {
+ FormatFlags Flags = FormatFlags::kDefault;
+ if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences))
+ {
+ Flags |= FormatFlags::kAllowLocalReferences;
+ if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences))
+ {
+ Flags |= FormatFlags::kDenyPartialLocalReferences;
+ }
+ }
+ CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetPid);
+ ZEN_ASSERT(RpcResponseBuffer.GetSize() > 0);
+ }
+ else
+ {
+ BinaryWriter MemStream;
+ RpcResult.Save(MemStream);
+ IoBuffer RpcResponseBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize());
+ ZEN_ASSERT(RpcResponseBuffer.Size() > 0);
+ }
+ }
+ JobLatch.CountDown();
+ });
+ }
+ while (!JobLatch.Wait(10000))
+ {
+ ZEN_INFO("Replayed {} of {} requests, elapsed {}",
+ RequestCount - JobLatch.Remaining(),
+ RequestCount,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ }
+}
+
+void
+HttpStructuredCacheService::HandleRpcRequest(HttpServerRequest& Request)
+{
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kPost:
+ {
+ const HttpContentType ContentType = Request.RequestContentType();
+ const HttpContentType AcceptType = Request.AcceptContentType();
+
+ if ((ContentType != HttpContentType::kCbObject && ContentType != HttpContentType::kCbPackage) ||
+ AcceptType != HttpContentType::kCbPackage)
+ {
+ return Request.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ Request.WriteResponseAsync(
+ [this, Body = Request.ReadPayload(), ContentType, AcceptType](HttpServerRequest& AsyncRequest) mutable {
+ std::uint64_t RequestIndex =
+ m_RequestRecorder ? m_RequestRecorder->RecordRequest(ContentType, AcceptType, Body) : ~0ull;
+ uint32_t AcceptMagic = 0;
+ RpcAcceptOptions AcceptFlags = RpcAcceptOptions::kNone;
+ int TargetProcessId = 0;
+ CbPackage RpcResult = HandleRpcRequest(ContentType, std::move(Body), AcceptMagic, AcceptFlags, TargetProcessId);
+ if (RpcResult.IsNull())
+ {
+ AsyncRequest.WriteResponse(HttpResponseCode::BadRequest);
+ return;
+ }
+ if (AcceptMagic == kCbPkgMagic)
+ {
+ FormatFlags Flags = FormatFlags::kDefault;
+ if (EnumHasAllFlags(AcceptFlags, RpcAcceptOptions::kAllowLocalReferences))
+ {
+ Flags |= FormatFlags::kAllowLocalReferences;
+ if (!EnumHasAnyFlags(AcceptFlags, RpcAcceptOptions::kAllowPartialLocalReferences))
+ {
+ Flags |= FormatFlags::kDenyPartialLocalReferences;
+ }
+ }
+ CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(RpcResult, Flags, TargetProcessId);
+ if (RequestIndex != ~0ull)
+ {
+ ZEN_ASSERT(m_RequestRecorder);
+ m_RequestRecorder->RecordResponse(RequestIndex, HttpContentType::kCbPackage, RpcResponseBuffer);
+ }
+ AsyncRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer);
+ }
+ else
+ {
+ BinaryWriter MemStream;
+ RpcResult.Save(MemStream);
+
+ if (RequestIndex != ~0ull)
+ {
+ ZEN_ASSERT(m_RequestRecorder);
+ m_RequestRecorder->RecordResponse(RequestIndex,
+ HttpContentType::kCbPackage,
+ IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+ }
+ AsyncRequest.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kCbPackage,
+ IoBuffer(IoBuffer::Wrap, MemStream.GetData(), MemStream.GetSize()));
+ }
+ });
+ }
+ break;
+ default:
+ Request.WriteResponse(HttpResponseCode::BadRequest);
+ break;
+ }
+}
+
+CbPackage
+HttpStructuredCacheService::HandleRpcPutCacheRecords(const CbPackage& BatchRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcPutCacheRecords");
+ CbObjectView BatchObject = BatchRequest.GetObject();
+ ZEN_ASSERT(BatchObject["Method"sv].AsString() == "PutCacheRecords"sv);
+
+ CbObjectView Params = BatchObject["Params"sv].AsObjectView();
+ CachePolicy DefaultPolicy;
+
+ std::string_view PolicyText = Params["DefaultPolicy"].AsString();
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+ DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::vector<bool> Results;
+ for (CbFieldView RequestField : Params["Requests"sv])
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView RecordObject = RequestObject["Record"sv].AsObjectView();
+ CbObjectView KeyView = RecordObject["Key"sv].AsObjectView();
+
+ CacheKey Key;
+ if (!GetRpcRequestCacheKey(KeyView, Key))
+ {
+ return CbPackage{};
+ }
+ CacheRecordPolicy Policy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+ PutRequestData PutRequest{*Namespace, std::move(Key), RecordObject, std::move(Policy)};
+
+ PutResult Result = PutCacheRecord(PutRequest, &BatchRequest);
+
+ if (Result == PutResult::Invalid)
+ {
+ return CbPackage{};
+ }
+ Results.push_back(Result == PutResult::Success);
+ }
+ if (Results.empty())
+ {
+ return CbPackage{};
+ }
+
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (bool Value : Results)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ CbPackage RpcResponse;
+ RpcResponse.SetObject(ResponseObject.Save());
+ return RpcResponse;
+}
+
+HttpStructuredCacheService::PutResult
+HttpStructuredCacheService::PutCacheRecord(PutRequestData& Request, const CbPackage* Package)
+{
+ CbObjectView Record = Request.RecordObject;
+ uint64_t RecordObjectSize = Record.GetSize();
+ uint64_t TransferredSize = RecordObjectSize;
+
+ AttachmentCount Count;
+ size_t NumAttachments = Package->GetAttachments().size();
+ std::vector<IoHash> ValidAttachments;
+ std::vector<const CbAttachment*> AttachmentsToStoreLocally;
+ ValidAttachments.reserve(NumAttachments);
+ AttachmentsToStoreLocally.reserve(NumAttachments);
+
+ Stopwatch Timer;
+
+ Request.RecordObject.IterateAttachments(
+ [this, &Request, Package, &AttachmentsToStoreLocally, &ValidAttachments, &Count, &TransferredSize](CbFieldView HashView) {
+ const IoHash ValueHash = HashView.AsHash();
+ if (const CbAttachment* Attachment = Package ? Package->FindAttachment(ValueHash) : nullptr)
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ AttachmentsToStoreLocally.emplace_back(Attachment);
+ ValidAttachments.emplace_back(ValueHash);
+ Count.Valid++;
+ }
+ else
+ {
+ ZEN_WARN("PUTCACEHRECORD - '{}/{}/{}' '{}' FAILED, attachment '{}' is not compressed",
+ Request.Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ ToString(HttpContentType::kCbPackage),
+ ValueHash);
+ Count.Invalid++;
+ }
+ }
+ else if (m_CidStore.ContainsChunk(ValueHash))
+ {
+ ValidAttachments.emplace_back(ValueHash);
+ Count.Valid++;
+ }
+ Count.Total++;
+ });
+
+ if (Count.Invalid > 0)
+ {
+ return PutResult::Invalid;
+ }
+
+ ZenCacheValue CacheValue;
+ CacheValue.Value = IoBuffer(Record.GetSize());
+ Record.CopyTo(MutableMemoryView(CacheValue.Value.MutableData(), CacheValue.Value.GetSize()));
+ CacheValue.Value.SetContentType(ZenContentType::kCbObject);
+ m_CacheStore.Put(Request.Namespace, Request.Key.Bucket, Request.Key.Hash, CacheValue);
+
+ for (const CbAttachment* Attachment : AttachmentsToStoreLocally)
+ {
+ CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+ CidStore::InsertResult InsertResult = m_CidStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+ if (InsertResult.New)
+ {
+ Count.New++;
+ }
+ TransferredSize += Chunk.GetCompressedSize();
+ }
+
+ ZEN_DEBUG("PUTCACEHRECORD - '{}/{}/{}' {}, attachments '{}/{}/{}' (new/valid/total) in {}",
+ Request.Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ NiceBytes(TransferredSize),
+ Count.New,
+ Count.Valid,
+ Count.Total,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+
+ const bool IsPartialRecord = Count.Valid != Count.Total;
+
+ if (EnumHasAllFlags(Request.Policy.GetRecordPolicy(), CachePolicy::StoreRemote) && !IsPartialRecord)
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCbPackage,
+ .Namespace = Request.Namespace,
+ .Key = Request.Key,
+ .ValueContentIds = std::move(ValidAttachments)});
+ }
+ return PutResult::Success;
+}
+
+CbPackage
+HttpStructuredCacheService::HandleRpcGetCacheRecords(CbObjectView RpcRequest)
+{
+ ZEN_TRACE_CPU("Z$::RpcGetCacheRecords");
+
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheRecords"sv);
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+
+ struct ValueRequestData
+ {
+ Oid ValueId;
+ IoHash ContentId;
+ CompressedBuffer Payload;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool ReadFromUpstream = false;
+ };
+ struct RecordRequestData
+ {
+ CacheKeyRequest Upstream;
+ CbObjectView RecordObject;
+ IoBuffer RecordCacheValue;
+ CacheRecordPolicy DownstreamPolicy;
+ std::vector<ValueRequestData> Values;
+ bool Complete = false;
+ const UpstreamEndpointInfo* Source = nullptr;
+ uint64_t ElapsedTimeUs;
+ };
+
+ std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+ std::vector<RecordRequestData> Requests;
+ std::vector<size_t> UpstreamIndexes;
+ CbArrayView RequestsArray = Params["Requests"sv].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+
+ auto ParseValues = [](RecordRequestData& Request) {
+ CbArrayView ValuesArray = Request.RecordObject["Values"sv].AsArrayView();
+ Request.Values.reserve(ValuesArray.Num());
+ for (CbFieldView ValueField : ValuesArray)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ CbFieldView RawHashField = ValueObject["RawHash"sv];
+ IoHash RawHash = RawHashField.AsBinaryAttachment();
+ if (ValueId && !RawHashField.HasError())
+ {
+ Request.Values.push_back({ValueId, RawHash});
+ Request.Values.back().DownstreamPolicy = Request.DownstreamPolicy.GetValuePolicy(ValueId);
+ }
+ }
+ };
+
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ Stopwatch Timer;
+ RecordRequestData& Request = Requests.emplace_back();
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+ CacheKey& Key = Request.Upstream.Key;
+ if (!GetRpcRequestCacheKey(KeyObject, Key))
+ {
+ return CbPackage{};
+ }
+
+ Request.DownstreamPolicy = LoadCacheRecordPolicy(RequestObject["Policy"sv].AsObjectView(), DefaultPolicy);
+ const CacheRecordPolicy& Policy = Request.DownstreamPolicy;
+
+ ZenCacheValue CacheValue;
+ bool NeedUpstreamAttachment = false;
+ bool FoundLocalInvalid = false;
+ ZenCacheValue RecordCacheValue;
+
+ if (EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryLocal) &&
+ m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, RecordCacheValue))
+ {
+ Request.RecordCacheValue = std::move(RecordCacheValue.Value);
+ if (Request.RecordCacheValue.GetContentType() != ZenContentType::kCbObject)
+ {
+ FoundLocalInvalid = true;
+ }
+ else
+ {
+ Request.RecordObject = CbObjectView(Request.RecordCacheValue.GetData());
+ ParseValues(Request);
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryLocal))
+ {
+ // A value that is requested without the Query flag (such as None/Disable) counts as existing, because we
+ // didn't ask for it and thus the record is complete in its absence.
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ Value.Exists = true;
+ }
+ else
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ Request.Complete = false;
+ }
+ }
+ else if (EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ if (m_CidStore.ContainsChunk(Value.ContentId))
+ {
+ Value.Exists = true;
+ }
+ else
+ {
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ }
+ Request.Complete = false;
+ }
+ }
+ else
+ {
+ if (IoBuffer Chunk = m_CidStore.FindChunkByCid(Value.ContentId))
+ {
+ ZEN_ASSERT(Chunk.GetSize() > 0);
+ Value.Payload = CompressedBuffer::FromCompressedNoValidate(std::move(Chunk));
+ Value.Exists = true;
+ }
+ else
+ {
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ NeedUpstreamAttachment = true;
+ Value.ReadFromUpstream = true;
+ }
+ Request.Complete = false;
+ }
+ }
+ }
+ }
+ }
+ if (!Request.Complete)
+ {
+ bool NeedUpstreamRecord =
+ !Request.RecordObject && !FoundLocalInvalid && EnumHasAllFlags(Policy.GetRecordPolicy(), CachePolicy::QueryRemote);
+ if (NeedUpstreamRecord || NeedUpstreamAttachment)
+ {
+ UpstreamIndexes.push_back(Requests.size() - 1);
+ }
+ }
+ Request.ElapsedTimeUs = Timer.GetElapsedTimeUs();
+ }
+ if (Requests.empty())
+ {
+ return CbPackage{};
+ }
+
+ if (!UpstreamIndexes.empty())
+ {
+ std::vector<CacheKeyRequest*> UpstreamRequests;
+ UpstreamRequests.reserve(UpstreamIndexes.size());
+ for (size_t Index : UpstreamIndexes)
+ {
+ RecordRequestData& Request = Requests[Index];
+ UpstreamRequests.push_back(&Request.Upstream);
+
+ if (Request.Values.size())
+ {
+ // We will be returning the local object and know all the value Ids that exist in it
+ // Convert all their Downstream Values to upstream values, and add SkipData to any ones that we already have.
+ CachePolicy UpstreamBasePolicy = ConvertToUpstream(Request.DownstreamPolicy.GetBasePolicy()) | CachePolicy::SkipMeta;
+ CacheRecordPolicyBuilder Builder(UpstreamBasePolicy);
+ for (ValueRequestData& Value : Request.Values)
+ {
+ CachePolicy UpstreamPolicy = ConvertToUpstream(Value.DownstreamPolicy);
+ UpstreamPolicy |= !Value.ReadFromUpstream ? CachePolicy::SkipData : CachePolicy::None;
+ Builder.AddValuePolicy(Value.ValueId, UpstreamPolicy);
+ }
+ Request.Upstream.Policy = Builder.Build();
+ }
+ else
+ {
+ // We don't know which Values exist in the Record; ask the upstrem for all values that the client wants,
+ // and convert the CacheRecordPolicy to an upstream policy
+ Request.Upstream.Policy = Request.DownstreamPolicy.ConvertToUpstream();
+ }
+ }
+
+ const auto OnCacheRecordGetComplete = [this, Namespace, &ParseValues](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
+ {
+ return;
+ }
+
+ RecordRequestData& Request =
+ *reinterpret_cast<RecordRequestData*>(reinterpret_cast<char*>(&Params.Request) - offsetof(RecordRequestData, Upstream));
+ Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ const CacheKey& Key = Request.Upstream.Key;
+ Stopwatch Timer;
+ auto TimeGuard = MakeGuard([&Timer, &Request]() { Request.ElapsedTimeUs += Timer.GetElapsedTimeUs(); });
+ if (!Request.RecordObject)
+ {
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ Request.RecordCacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ Request.RecordCacheValue.SetContentType(ZenContentType::kCbObject);
+ Request.RecordObject = ObjectBuffer;
+ if (EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::StoreLocal))
+ {
+ m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = {Request.RecordCacheValue}});
+ }
+ ParseValues(Request);
+ Request.Source = Params.Source;
+ }
+
+ Request.Complete = true;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ if (Value.Exists)
+ {
+ continue;
+ }
+ CachePolicy ValuePolicy = Value.DownstreamPolicy;
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::QueryRemote))
+ {
+ Request.Complete = false;
+ continue;
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData) || EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ if (const CbAttachment* Attachment = Params.Package.FindAttachment(Value.ContentId))
+ {
+ if (CompressedBuffer Compressed = Attachment->AsCompressedBinary())
+ {
+ Request.Source = Params.Source;
+ Value.Exists = true;
+ if (EnumHasAllFlags(ValuePolicy, CachePolicy::StoreLocal))
+ {
+ m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), Attachment->GetHash());
+ }
+ if (!EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Value.Payload = Compressed;
+ }
+ }
+ else
+ {
+ ZEN_DEBUG("Uncompressed value '{}' from upstream cache record '{}/{}/{}'",
+ Value.ContentId,
+ *Namespace,
+ Key.Bucket,
+ Key.Hash);
+ }
+ }
+ if (!Value.Exists && !EnumHasAllFlags(ValuePolicy, CachePolicy::SkipData))
+ {
+ Request.Complete = false;
+ }
+ // Request.Complete does not need to be set to false for upstream SkipData attachments.
+ // In the PartialRecord==false case, the upstream will have failed the entire record if any SkipData attachment
+ // didn't exist and we will not get here. In the PartialRecord==true case, we do not need to inform the client of
+ // any missing SkipData attachments.
+ }
+ Request.ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+ };
+
+ m_UpstreamCache.GetCacheRecords(*Namespace, UpstreamRequests, std::move(OnCacheRecordGetComplete));
+ }
+
+ CbPackage ResponsePackage;
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginArray("Result"sv);
+ for (RecordRequestData& Request : Requests)
+ {
+ const CacheKey& Key = Request.Upstream.Key;
+ if (Request.Complete ||
+ (Request.RecordObject && EnumHasAllFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::PartialRecord)))
+ {
+ ResponseObject << Request.RecordObject;
+ for (ValueRequestData& Value : Request.Values)
+ {
+ if (!EnumHasAllFlags(Value.DownstreamPolicy, CachePolicy::SkipData) && Value.Payload)
+ {
+ ResponsePackage.AddAttachment(CbAttachment(Value.Payload, Value.ContentId));
+ }
+ }
+
+ ZEN_DEBUG("GETCACHERECORD HIT - '{}/{}/{}' {}{} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(Request.RecordCacheValue.Size()),
+ Request.Complete ? ""sv : " (PARTIAL)"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.HitCount++;
+ m_CacheStats.UpstreamHitCount += Request.Source ? 1 : 0;
+ }
+ else
+ {
+ ResponseObject.AddNull();
+
+ if (!EnumHasAnyFlags(Request.DownstreamPolicy.GetRecordPolicy(), CachePolicy::Query))
+ {
+ // If they requested no query, do not record this as a miss
+ ZEN_DEBUG("GETCACHERECORD DISABLEDQUERY - '{}/{}/{}' in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHERECORD MISS - '{}/{}/{}'{} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ Request.RecordObject ? ""sv : " (PARTIAL)"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.MissCount++;
+ }
+ }
+ }
+ ResponseObject.EndArray();
+ ResponsePackage.SetObject(ResponseObject.Save());
+ return ResponsePackage;
+}
+
+CbPackage
+HttpStructuredCacheService::HandleRpcPutCacheValues(const CbPackage& BatchRequest)
+{
+ CbObjectView BatchObject = BatchRequest.GetObject();
+ CbObjectView Params = BatchObject["Params"sv].AsObjectView();
+
+ std::string_view PolicyText = Params["DefaultPolicy"].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+ std::vector<bool> Results;
+ for (CbFieldView RequestField : Params["Requests"sv])
+ {
+ Stopwatch Timer;
+
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyView = RequestObject["Key"sv].AsObjectView();
+
+ CacheKey Key;
+ if (!GetRpcRequestCacheKey(KeyView, Key))
+ {
+ return CbPackage{};
+ }
+
+ PolicyText = RequestObject["Policy"sv].AsString();
+ CachePolicy Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+ IoHash RawHash = RequestObject["RawHash"sv].AsBinaryAttachment();
+ uint64_t RawSize = RequestObject["RawSize"sv].AsUInt64();
+ bool Succeeded = false;
+ uint64_t TransferredSize = 0;
+
+ if (const CbAttachment* Attachment = BatchRequest.FindAttachment(RawHash))
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ CompressedBuffer Chunk = Attachment->AsCompressedBinary();
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ // TODO: Implement upstream puts of CacheValues with StoreLocal == false.
+ // Currently ProcessCacheRecord requires that the value exist in the local cache to put it upstream.
+ Policy |= CachePolicy::StoreLocal;
+ }
+
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
+ {
+ IoBuffer Value = Chunk.GetCompressed().Flatten().AsIoBuffer();
+ Value.SetContentType(ZenContentType::kCompressedBinary);
+ if (RawSize == 0)
+ {
+ RawSize = Chunk.DecodeRawSize();
+ }
+ m_CacheStore.Put(*Namespace, Key.Bucket, Key.Hash, {.Value = Value, .RawSize = RawSize, .RawHash = RawHash});
+ TransferredSize = Chunk.GetCompressedSize();
+ }
+ Succeeded = true;
+ }
+ else
+ {
+ ZEN_WARN("PUTCACHEVALUES - '{}/{}/{}/{}' FAILED, value is not compressed", *Namespace, Key.Bucket, Key.Hash, RawHash);
+ return CbPackage{};
+ }
+ }
+ else if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue ExistingValue;
+ if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, ExistingValue) &&
+ IsCompressedBinary(ExistingValue.Value.GetContentType()))
+ {
+ Succeeded = true;
+ }
+ }
+ // We do not search the Upstream. No data in a put means the caller is probing for whether they need to do a heavy put.
+ // If it doesn't exist locally they should do the heavy put rather than having us fetch it from upstream.
+
+ if (Succeeded && EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ m_UpstreamCache.EnqueueUpstream({.Type = ZenContentType::kCompressedBinary, .Namespace = *Namespace, .Key = Key});
+ }
+ Results.push_back(Succeeded);
+ ZEN_DEBUG("PUTCACHEVALUES - '{}/{}/{}' {}, '{}' in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(TransferredSize),
+ Succeeded ? "Added"sv : "Invalid",
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ }
+ if (Results.empty())
+ {
+ return CbPackage{};
+ }
+
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (bool Value : Results)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ CbPackage RpcResponse;
+ RpcResponse.SetObject(ResponseObject.Save());
+
+ return RpcResponse;
+}
+
+CbPackage
+HttpStructuredCacheService::HandleRpcGetCacheValues(CbObjectView RpcRequest)
+{
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheValues"sv);
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ std::string_view PolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : CachePolicy::Default;
+ std::optional<std::string> Namespace = GetRpcRequestNamespace(Params);
+ if (!Namespace)
+ {
+ return CbPackage{};
+ }
+
+ struct RequestData
+ {
+ CacheKey Key;
+ CachePolicy Policy;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+ CompressedBuffer Result;
+ };
+ std::vector<RequestData> Requests;
+
+ std::vector<size_t> RemoteRequestIndexes;
+
+ for (CbFieldView RequestField : Params["Requests"sv])
+ {
+ Stopwatch Timer;
+
+ RequestData& Request = Requests.emplace_back();
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+ if (!GetRpcRequestCacheKey(KeyObject, Request.Key))
+ {
+ return CbPackage{};
+ }
+
+ PolicyText = RequestObject["Policy"sv].AsString();
+ Request.Policy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+
+ CacheKey& Key = Request.Key;
+ CachePolicy Policy = Request.Policy;
+
+ ZenCacheValue CacheValue;
+ if (EnumHasAllFlags(Policy, CachePolicy::QueryLocal))
+ {
+ if (m_CacheStore.Get(*Namespace, Key.Bucket, Key.Hash, CacheValue) && IsCompressedBinary(CacheValue.Value.GetContentType()))
+ {
+ Request.RawHash = CacheValue.RawHash;
+ Request.RawSize = CacheValue.RawSize;
+ Request.Result = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value));
+ }
+ }
+ if (Request.Result)
+ {
+ ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ NiceBytes(Request.Result.GetCompressed().GetSize()),
+ "LOCAL"sv,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.HitCount++;
+ }
+ else if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+ {
+ RemoteRequestIndexes.push_back(Requests.size() - 1);
+ }
+ else if (!EnumHasAnyFlags(Policy, CachePolicy::Query))
+ {
+ // If they requested no query, do not record this as a miss
+ ZEN_DEBUG("GETCACHEVALUES DISABLEDQUERY - '{}/{}/{}'", *Namespace, Key.Bucket, Key.Hash);
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}",
+ *Namespace,
+ Key.Bucket,
+ Key.Hash,
+ "LOCAL"sv,
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.MissCount++;
+ }
+ }
+
+ if (!RemoteRequestIndexes.empty())
+ {
+ std::vector<CacheValueRequest> RequestedRecordsData;
+ std::vector<CacheValueRequest*> CacheValueRequests;
+ RequestedRecordsData.reserve(RemoteRequestIndexes.size());
+ CacheValueRequests.reserve(RemoteRequestIndexes.size());
+ for (size_t Index : RemoteRequestIndexes)
+ {
+ RequestData& Request = Requests[Index];
+ RequestedRecordsData.push_back({.Key = {Request.Key.Bucket, Request.Key.Hash}, .Policy = ConvertToUpstream(Request.Policy)});
+ CacheValueRequests.push_back(&RequestedRecordsData.back());
+ }
+ Stopwatch Timer;
+ m_UpstreamCache.GetCacheValues(
+ *Namespace,
+ CacheValueRequests,
+ [this, Namespace, &RequestedRecordsData, &Requests, &RemoteRequestIndexes, &Timer](CacheValueGetCompleteParams&& Params) {
+ CacheValueRequest& ChunkRequest = Params.Request;
+ if (Params.RawHash != IoHash::Zero)
+ {
+ size_t RequestOffset = std::distance(RequestedRecordsData.data(), &ChunkRequest);
+ size_t RequestIndex = RemoteRequestIndexes[RequestOffset];
+ RequestData& Request = Requests[RequestIndex];
+ Request.RawHash = Params.RawHash;
+ Request.RawSize = Params.RawSize;
+ const bool HasData = IsCompressedBinary(Params.Value.GetContentType());
+ const bool SkipData = EnumHasAllFlags(Request.Policy, CachePolicy::SkipData);
+ const bool StoreData = EnumHasAllFlags(Request.Policy, CachePolicy::StoreLocal);
+ const bool IsHit = SkipData || HasData;
+ if (IsHit)
+ {
+ if (HasData && !SkipData)
+ {
+ Request.Result = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value));
+ }
+
+ if (HasData && StoreData)
+ {
+ m_CacheStore.Put(*Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash,
+ ZenCacheValue{.Value = Params.Value, .RawSize = Request.RawSize, .RawHash = Request.RawHash});
+ }
+
+ ZEN_DEBUG("GETCACHEVALUES HIT - '{}/{}/{}' {} ({}) in {}",
+ *Namespace,
+ ChunkRequest.Key.Bucket,
+ ChunkRequest.Key.Hash,
+ NiceBytes(Request.Result.GetCompressed().GetSize()),
+ Params.Source ? Params.Source->Url : "UPSTREAM",
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.HitCount++;
+ m_CacheStats.UpstreamHitCount++;
+ return;
+ }
+ }
+ ZEN_DEBUG("GETCACHEVALUES MISS - '{}/{}/{}' ({}) in {}",
+ *Namespace,
+ ChunkRequest.Key.Bucket,
+ ChunkRequest.Key.Hash,
+ Params.Source ? Params.Source->Url : "UPSTREAM",
+ NiceLatencyNs(Timer.GetElapsedTimeUs() * 1000));
+ m_CacheStats.MissCount++;
+ });
+ }
+
+ if (Requests.empty())
+ {
+ return CbPackage{};
+ }
+
+ CbPackage RpcResponse;
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result"sv);
+ for (const RequestData& Request : Requests)
+ {
+ ResponseObject.BeginObject();
+ {
+ const CompressedBuffer& Result = Request.Result;
+ if (Result)
+ {
+ ResponseObject.AddHash("RawHash"sv, Request.RawHash);
+ if (!EnumHasAllFlags(Request.Policy, CachePolicy::SkipData))
+ {
+ RpcResponse.AddAttachment(CbAttachment(Result, Request.RawHash));
+ }
+ else
+ {
+ ResponseObject.AddInteger("RawSize"sv, Request.RawSize);
+ }
+ }
+ else if (Request.RawHash != IoHash::Zero)
+ {
+ ResponseObject.AddHash("RawHash"sv, Request.RawHash);
+ ResponseObject.AddInteger("RawSize"sv, Request.RawSize);
+ }
+ }
+ ResponseObject.EndObject();
+ }
+ ResponseObject.EndArray();
+
+ RpcResponse.SetObject(ResponseObject.Save());
+ return RpcResponse;
+}
+
+namespace cache::detail {
+
+ struct RecordValue
+ {
+ Oid ValueId;
+ IoHash ContentId;
+ uint64_t RawSize;
+ };
+ struct RecordBody
+ {
+ IoBuffer CacheValue;
+ std::vector<RecordValue> Values;
+ const UpstreamEndpointInfo* Source = nullptr;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool HasRequest = false;
+ bool ValuesRead = false;
+ };
+ struct ChunkRequest
+ {
+ CacheChunkRequest* Key = nullptr;
+ RecordBody* Record = nullptr;
+ CompressedBuffer Value;
+ const UpstreamEndpointInfo* Source = nullptr;
+ uint64_t RawSize = 0;
+ uint64_t RequestedSize = 0;
+ uint64_t RequestedOffset = 0;
+ CachePolicy DownstreamPolicy;
+ bool Exists = false;
+ bool RawSizeKnown = false;
+ bool IsRecordRequest = false;
+ uint64_t ElapsedTimeUs = 0;
+ };
+
+} // namespace cache::detail
+
+CbPackage
+HttpStructuredCacheService::HandleRpcGetCacheChunks(CbObjectView RpcRequest)
+{
+ using namespace cache::detail;
+
+ std::string Namespace;
+ std::vector<CacheKeyRequest> RecordKeys; // Data about a Record necessary to identify it to the upstream
+ std::vector<RecordBody> Records; // Scratch-space data about a Record when fulfilling RecordRequests
+ std::vector<CacheChunkRequest> RequestKeys; // Data about a ChunkRequest necessary to identify it to the upstream
+ std::vector<ChunkRequest> Requests; // Intermediate and result data about a ChunkRequest
+ std::vector<ChunkRequest*> RecordRequests; // The ChunkRequests that are requesting a subvalue from a Record Key
+ std::vector<ChunkRequest*> ValueRequests; // The ChunkRequests that are requesting a Value Key
+ std::vector<CacheChunkRequest*> UpstreamChunks; // ChunkRequests that we need to send to the upstream
+
+ // Parse requests from the CompactBinary body of the RpcRequest and divide it into RecordRequests and ValueRequests
+ if (!ParseGetCacheChunksRequest(Namespace, RecordKeys, Records, RequestKeys, Requests, RecordRequests, ValueRequests, RpcRequest))
+ {
+ return CbPackage{};
+ }
+
+ // For each Record request, load the Record if necessary to find the Chunk's ContentId, load its Payloads if we
+ // have it locally, and otherwise append a request for the payload to UpstreamChunks
+ GetLocalCacheRecords(Namespace, RecordKeys, Records, RecordRequests, UpstreamChunks);
+
+ // For each Value request, load the Value if we have it locally and otherwise append a request for the payload to UpstreamChunks
+ GetLocalCacheValues(Namespace, ValueRequests, UpstreamChunks);
+
+ // Call GetCacheChunks on the upstream for any payloads we do not have locally
+ GetUpstreamCacheChunks(Namespace, UpstreamChunks, RequestKeys, Requests);
+
+ // Send the payload and descriptive data about each chunk to the client
+ return WriteGetCacheChunksResponse(Namespace, Requests);
+}
+
+bool
+HttpStructuredCacheService::ParseGetCacheChunksRequest(std::string& Namespace,
+ std::vector<CacheKeyRequest>& RecordKeys,
+ std::vector<cache::detail::RecordBody>& Records,
+ std::vector<CacheChunkRequest>& RequestKeys,
+ std::vector<cache::detail::ChunkRequest>& Requests,
+ std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+ std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+ CbObjectView RpcRequest)
+{
+ using namespace cache::detail;
+
+ ZEN_ASSERT(RpcRequest["Method"sv].AsString() == "GetCacheChunks"sv);
+
+ CbObjectView Params = RpcRequest["Params"sv].AsObjectView();
+ std::string_view DefaultPolicyText = Params["DefaultPolicy"sv].AsString();
+ CachePolicy DefaultPolicy = !DefaultPolicyText.empty() ? ParseCachePolicy(DefaultPolicyText) : CachePolicy::Default;
+
+ std::optional<std::string> NamespaceText = GetRpcRequestNamespace(Params);
+ if (!NamespaceText)
+ {
+ ZEN_WARN("GetCacheChunks: Invalid namespace in ChunkRequest.");
+ return false;
+ }
+ Namespace = *NamespaceText;
+
+ CbArrayView ChunkRequestsArray = Params["ChunkRequests"sv].AsArrayView();
+ size_t NumRequests = static_cast<size_t>(ChunkRequestsArray.Num());
+
+ // Note that these reservations allow us to take pointers to the elements while populating them. If the reservation is removed,
+ // we will need to change the pointers to indexes to handle reallocations.
+ RecordKeys.reserve(NumRequests);
+ Records.reserve(NumRequests);
+ RequestKeys.reserve(NumRequests);
+ Requests.reserve(NumRequests);
+ RecordRequests.reserve(NumRequests);
+ ValueRequests.reserve(NumRequests);
+
+ CacheKeyRequest* PreviousRecordKey = nullptr;
+ RecordBody* PreviousRecord = nullptr;
+
+ for (CbFieldView RequestView : ChunkRequestsArray)
+ {
+ CbObjectView RequestObject = RequestView.AsObjectView();
+ CacheChunkRequest& RequestKey = RequestKeys.emplace_back();
+ ChunkRequest& Request = Requests.emplace_back();
+ CbObjectView KeyObject = RequestObject["Key"sv].AsObjectView();
+
+ Request.Key = &RequestKey;
+ if (!GetRpcRequestCacheKey(KeyObject, Request.Key->Key))
+ {
+ ZEN_WARN("GetCacheChunks: Invalid key in ChunkRequest.");
+ return false;
+ }
+
+ RequestKey.ChunkId = RequestObject["ChunkId"sv].AsHash();
+ RequestKey.ValueId = RequestObject["ValueId"sv].AsObjectId();
+ RequestKey.RawOffset = RequestObject["RawOffset"sv].AsUInt64();
+ RequestKey.RawSize = RequestObject["RawSize"sv].AsUInt64(UINT64_MAX);
+ Request.RequestedSize = RequestKey.RawSize;
+ Request.RequestedOffset = RequestKey.RawOffset;
+ std::string_view PolicyText = RequestObject["Policy"sv].AsString();
+ Request.DownstreamPolicy = !PolicyText.empty() ? ParseCachePolicy(PolicyText) : DefaultPolicy;
+ Request.IsRecordRequest = (bool)RequestKey.ValueId;
+
+ if (!Request.IsRecordRequest)
+ {
+ ValueRequests.push_back(&Request);
+ }
+ else
+ {
+ RecordRequests.push_back(&Request);
+ CacheKeyRequest* RecordKey = nullptr;
+ RecordBody* Record = nullptr;
+
+ if (!PreviousRecordKey || PreviousRecordKey->Key < RequestKey.Key)
+ {
+ RecordKey = &RecordKeys.emplace_back();
+ PreviousRecordKey = RecordKey;
+ Record = &Records.emplace_back();
+ PreviousRecord = Record;
+ RecordKey->Key = RequestKey.Key;
+ }
+ else if (RequestKey.Key == PreviousRecordKey->Key)
+ {
+ RecordKey = PreviousRecordKey;
+ Record = PreviousRecord;
+ }
+ else
+ {
+ ZEN_WARN("GetCacheChunks: Keys in ChunkRequest are not sorted: {}/{} came after {}/{}.",
+ RequestKey.Key.Bucket,
+ RequestKey.Key.Hash,
+ PreviousRecordKey->Key.Bucket,
+ PreviousRecordKey->Key.Hash);
+ return false;
+ }
+ Request.Record = Record;
+ if (RequestKey.ChunkId == RequestKey.ChunkId.Zero)
+ {
+ Record->DownstreamPolicy =
+ Record->HasRequest ? Union(Record->DownstreamPolicy, Request.DownstreamPolicy) : Request.DownstreamPolicy;
+ Record->HasRequest = true;
+ }
+ }
+ }
+ if (Requests.empty())
+ {
+ return false;
+ }
+ return true;
+}
+
+void
+HttpStructuredCacheService::GetLocalCacheRecords(std::string_view Namespace,
+ std::vector<CacheKeyRequest>& RecordKeys,
+ std::vector<cache::detail::RecordBody>& Records,
+ std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+ std::vector<CacheChunkRequest*>& OutUpstreamChunks)
+{
+ using namespace cache::detail;
+
+ std::vector<CacheKeyRequest*> UpstreamRecordRequests;
+ for (size_t RecordIndex = 0; RecordIndex < Records.size(); ++RecordIndex)
+ {
+ Stopwatch Timer;
+ CacheKeyRequest& RecordKey = RecordKeys[RecordIndex];
+ RecordBody& Record = Records[RecordIndex];
+ if (Record.HasRequest)
+ {
+ Record.DownstreamPolicy |= CachePolicy::SkipData | CachePolicy::SkipMeta;
+
+ if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue CacheValue;
+ if (m_CacheStore.Get(Namespace, RecordKey.Key.Bucket, RecordKey.Key.Hash, CacheValue))
+ {
+ Record.Exists = true;
+ Record.CacheValue = std::move(CacheValue.Value);
+ }
+ }
+ if (!Record.Exists && EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ RecordKey.Policy = CacheRecordPolicy(ConvertToUpstream(Record.DownstreamPolicy));
+ UpstreamRecordRequests.push_back(&RecordKey);
+ }
+ RecordRequests[RecordIndex]->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+ }
+
+ if (!UpstreamRecordRequests.empty())
+ {
+ const auto OnCacheRecordGetComplete =
+ [this, Namespace, &RecordKeys, &Records, &RecordRequests](CacheRecordGetCompleteParams&& Params) {
+ if (!Params.Record)
+ {
+ return;
+ }
+ CacheKeyRequest& RecordKey = Params.Request;
+ size_t RecordIndex = std::distance(RecordKeys.data(), &RecordKey);
+ RecordRequests[RecordIndex]->ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ RecordBody& Record = Records[RecordIndex];
+
+ const CacheKey& Key = RecordKey.Key;
+ Record.Exists = true;
+ CbObject ObjectBuffer = CbObject::Clone(Params.Record);
+ Record.CacheValue = ObjectBuffer.GetBuffer().AsIoBuffer();
+ Record.CacheValue.SetContentType(ZenContentType::kCbObject);
+ Record.Source = Params.Source;
+
+ if (EnumHasAllFlags(Record.DownstreamPolicy, CachePolicy::StoreLocal))
+ {
+ m_CacheStore.Put(Namespace, Key.Bucket, Key.Hash, {.Value = Record.CacheValue});
+ }
+ };
+ m_UpstreamCache.GetCacheRecords(Namespace, UpstreamRecordRequests, std::move(OnCacheRecordGetComplete));
+ }
+
+ std::vector<CacheChunkRequest*> UpstreamPayloadRequests;
+ for (ChunkRequest* Request : RecordRequests)
+ {
+ Stopwatch Timer;
+ if (Request->Key->ChunkId == IoHash::Zero)
+ {
+ // Unreal uses a 12 byte ID to address cache record values. When the uncompressed hash (ChunkId)
+ // is missing, parse the cache record and try to find the raw hash from the ValueId.
+ RecordBody& Record = *Request->Record;
+ if (!Record.ValuesRead)
+ {
+ Record.ValuesRead = true;
+ if (Record.CacheValue && Record.CacheValue.GetContentType() == ZenContentType::kCbObject)
+ {
+ CbObjectView RecordObject = CbObjectView(Record.CacheValue.GetData());
+ CbArrayView ValuesArray = RecordObject["Values"sv].AsArrayView();
+ Record.Values.reserve(ValuesArray.Num());
+ for (CbFieldView ValueField : ValuesArray)
+ {
+ CbObjectView ValueObject = ValueField.AsObjectView();
+ Oid ValueId = ValueObject["Id"sv].AsObjectId();
+ CbFieldView RawHashField = ValueObject["RawHash"sv];
+ IoHash RawHash = RawHashField.AsBinaryAttachment();
+ if (ValueId && !RawHashField.HasError())
+ {
+ Record.Values.push_back({ValueId, RawHash, ValueObject["RawSize"sv].AsUInt64()});
+ }
+ }
+ }
+ }
+
+ for (const RecordValue& Value : Record.Values)
+ {
+ if (Value.ValueId == Request->Key->ValueId)
+ {
+ Request->Key->ChunkId = Value.ContentId;
+ Request->RawSize = Value.RawSize;
+ Request->RawSizeKnown = true;
+ break;
+ }
+ }
+ }
+
+ // Now load the ContentId from the local ContentIdStore or from the upstream
+ if (Request->Key->ChunkId != IoHash::Zero)
+ {
+ if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal))
+ {
+ if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData) && Request->RawSizeKnown)
+ {
+ if (m_CidStore.ContainsChunk(Request->Key->ChunkId))
+ {
+ Request->Exists = true;
+ }
+ }
+ else if (IoBuffer Payload = m_CidStore.FindChunkByCid(Request->Key->ChunkId))
+ {
+ if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(Payload));
+ if (Request->Value)
+ {
+ Request->Exists = true;
+ Request->RawSizeKnown = false;
+ }
+ }
+ else
+ {
+ IoHash _;
+ if (CompressedBuffer::ValidateCompressedHeader(Payload, _, Request->RawSize))
+ {
+ Request->Exists = true;
+ Request->RawSizeKnown = true;
+ }
+ }
+ }
+ }
+ if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ Request->Key->Policy = ConvertToUpstream(Request->DownstreamPolicy);
+ OutUpstreamChunks.push_back(Request->Key);
+ }
+ }
+ Request->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+}
+
+void
+HttpStructuredCacheService::GetLocalCacheValues(std::string_view Namespace,
+ std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+ std::vector<CacheChunkRequest*>& OutUpstreamChunks)
+{
+ using namespace cache::detail;
+
+ for (ChunkRequest* Request : ValueRequests)
+ {
+ Stopwatch Timer;
+ if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryLocal))
+ {
+ ZenCacheValue CacheValue;
+ if (m_CacheStore.Get(Namespace, Request->Key->Key.Bucket, Request->Key->Key.Hash, CacheValue))
+ {
+ if (IsCompressedBinary(CacheValue.Value.GetContentType()))
+ {
+ Request->Key->ChunkId = CacheValue.RawHash;
+ Request->Exists = true;
+ Request->RawSize = CacheValue.RawSize;
+ Request->RawSizeKnown = true;
+ if (!EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Request->Value = CompressedBuffer::FromCompressedNoValidate(std::move(CacheValue.Value));
+ }
+ }
+ }
+ }
+ if (!Request->Exists && EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::QueryRemote))
+ {
+ if (EnumHasAllFlags(Request->DownstreamPolicy, CachePolicy::StoreLocal))
+ {
+ // Convert the Offset,Size request into a request for the entire value; we will need it all to be able to store it locally
+ Request->Key->RawOffset = 0;
+ Request->Key->RawSize = UINT64_MAX;
+ }
+ OutUpstreamChunks.push_back(Request->Key);
+ }
+ Request->ElapsedTimeUs += Timer.GetElapsedTimeUs();
+ }
+}
+
+void
+HttpStructuredCacheService::GetUpstreamCacheChunks(std::string_view Namespace,
+ std::vector<CacheChunkRequest*>& UpstreamChunks,
+ std::vector<CacheChunkRequest>& RequestKeys,
+ std::vector<cache::detail::ChunkRequest>& Requests)
+{
+ using namespace cache::detail;
+
+ if (!UpstreamChunks.empty())
+ {
+ const auto OnCacheChunksGetComplete = [this, Namespace, &RequestKeys, &Requests](CacheChunkGetCompleteParams&& Params) {
+ if (Params.RawHash == Params.RawHash.Zero)
+ {
+ return;
+ }
+
+ CacheChunkRequest& Key = Params.Request;
+ size_t RequestIndex = std::distance(RequestKeys.data(), &Key);
+ ChunkRequest& Request = Requests[RequestIndex];
+ Request.ElapsedTimeUs += static_cast<uint64_t>(Params.ElapsedSeconds * 1000000.0);
+ if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal) ||
+ !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(IoBuffer(Params.Value));
+ if (!Compressed)
+ {
+ return;
+ }
+
+ if (EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::StoreLocal))
+ {
+ if (Request.IsRecordRequest)
+ {
+ m_CidStore.AddChunk(Params.Value, Params.RawHash);
+ }
+ else
+ {
+ m_CacheStore.Put(Namespace,
+ Key.Key.Bucket,
+ Key.Key.Hash,
+ {.Value = Params.Value, .RawSize = Params.RawSize, .RawHash = Params.RawHash});
+ }
+ }
+ if (!EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ Request.Value = std::move(Compressed);
+ }
+ }
+ Key.ChunkId = Params.RawHash;
+ Request.Exists = true;
+ Request.RawSize = Params.RawSize;
+ Request.RawSizeKnown = true;
+ Request.Source = Params.Source;
+
+ m_CacheStats.UpstreamHitCount++;
+ };
+
+ m_UpstreamCache.GetCacheChunks(Namespace, UpstreamChunks, std::move(OnCacheChunksGetComplete));
+ }
+}
+
+CbPackage
+HttpStructuredCacheService::WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests)
+{
+ using namespace cache::detail;
+
+ CbPackage RpcResponse;
+ CbObjectWriter Writer;
+
+ Writer.BeginArray("Result"sv);
+ for (ChunkRequest& Request : Requests)
+ {
+ Writer.BeginObject();
+ {
+ if (Request.Exists)
+ {
+ Writer.AddHash("RawHash"sv, Request.Key->ChunkId);
+ if (Request.Value && !EnumHasAllFlags(Request.DownstreamPolicy, CachePolicy::SkipData))
+ {
+ RpcResponse.AddAttachment(CbAttachment(Request.Value, Request.Key->ChunkId));
+ }
+ else
+ {
+ Writer.AddInteger("RawSize"sv, Request.RawSize);
+ }
+
+ ZEN_DEBUG("GETCACHECHUNKS HIT - '{}/{}/{}/{}' {} '{}' ({}) in {}",
+ Namespace,
+ Request.Key->Key.Bucket,
+ Request.Key->Key.Hash,
+ Request.Key->ValueId,
+ NiceBytes(Request.RawSize),
+ Request.IsRecordRequest ? "Record"sv : "Value"sv,
+ Request.Source ? Request.Source->Url : "LOCAL"sv,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.HitCount++;
+ }
+ else if (!EnumHasAnyFlags(Request.DownstreamPolicy, CachePolicy::Query))
+ {
+ ZEN_DEBUG("GETCACHECHUNKS DISABLEDQUERY - '{}/{}/{}/{}' in {}",
+ Namespace,
+ Request.Key->Key.Bucket,
+ Request.Key->Key.Hash,
+ Request.Key->ValueId,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ }
+ else
+ {
+ ZEN_DEBUG("GETCACHECHUNKS MISS - '{}/{}/{}/{}' in {}",
+ Namespace,
+ Request.Key->Key.Bucket,
+ Request.Key->Key.Hash,
+ Request.Key->ValueId,
+ NiceLatencyNs(Request.ElapsedTimeUs * 1000));
+ m_CacheStats.MissCount++;
+ }
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ RpcResponse.SetObject(Writer.Save());
+ return RpcResponse;
+}
+
+void
+HttpStructuredCacheService::HandleStatsRequest(HttpServerRequest& Request)
+{
+ CbObjectWriter Cbo;
+
+ EmitSnapshot("requests", m_HttpRequests, Cbo);
+ EmitSnapshot("upstream_gets", m_UpstreamGetRequestTiming, Cbo);
+
+ const uint64_t HitCount = m_CacheStats.HitCount;
+ const uint64_t UpstreamHitCount = m_CacheStats.UpstreamHitCount;
+ const uint64_t MissCount = m_CacheStats.MissCount;
+ const uint64_t TotalCount = HitCount + MissCount;
+
+ const CidStoreSize CidSize = m_CidStore.TotalSize();
+ const GcStorageSize CacheSize = m_CacheStore.StorageSize();
+
+ Cbo.BeginObject("cache");
+ {
+ Cbo.BeginObject("size");
+ {
+ Cbo << "disk" << CacheSize.DiskSize;
+ Cbo << "memory" << CacheSize.MemorySize;
+ }
+ Cbo.EndObject();
+
+ Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
+ Cbo << "hits" << HitCount << "misses" << MissCount;
+ Cbo << "hit_ratio" << (TotalCount > 0 ? (double(HitCount) / double(TotalCount)) : 0.0);
+ Cbo << "upstream_hits" << m_CacheStats.UpstreamHitCount;
+ Cbo << "upstream_ratio" << (HitCount > 0 ? (double(UpstreamHitCount) / double(HitCount)) : 0.0);
+ }
+ Cbo.EndObject();
+
+ Cbo.BeginObject("upstream");
+ {
+ m_UpstreamCache.GetStatus(Cbo);
+ }
+ Cbo.EndObject();
+
+ Cbo.BeginObject("cid");
+ {
+ Cbo.BeginObject("size");
+ {
+ Cbo << "tiny" << CidSize.TinySize;
+ Cbo << "small" << CidSize.SmallSize;
+ Cbo << "large" << CidSize.LargeSize;
+ Cbo << "total" << CidSize.TotalSize;
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndObject();
+
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
+void
+HttpStructuredCacheService::HandleStatusRequest(HttpServerRequest& Request)
+{
+ CbObjectWriter Cbo;
+ Cbo << "ok" << true;
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("z$service.parse.relative.Uri")
+{
+ HttpRequestData RootRequest;
+ CHECK(HttpRequestParseRelativeUri("", RootRequest));
+ CHECK(!RootRequest.Namespace.has_value());
+ CHECK(!RootRequest.Bucket.has_value());
+ CHECK(!RootRequest.HashKey.has_value());
+ CHECK(!RootRequest.ValueContentId.has_value());
+
+ RootRequest = {};
+ CHECK(HttpRequestParseRelativeUri("/", RootRequest));
+ CHECK(!RootRequest.Namespace.has_value());
+ CHECK(!RootRequest.Bucket.has_value());
+ CHECK(!RootRequest.HashKey.has_value());
+ CHECK(!RootRequest.ValueContentId.has_value());
+
+ HttpRequestData LegacyBucketRequestBecomesNamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest));
+ CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test"sv);
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value());
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value());
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData LegacyHashKeyRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest));
+ CHECK(LegacyHashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace);
+ CHECK(LegacyHashKeyRequest.Bucket == "test"sv);
+ CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+ CHECK(!LegacyHashKeyRequest.ValueContentId.has_value());
+
+ HttpRequestData LegacyValueContentIdRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+ LegacyValueContentIdRequest));
+ CHECK(LegacyValueContentIdRequest.Namespace == ZenCacheStore::DefaultNamespace);
+ CHECK(LegacyValueContentIdRequest.Bucket == "test"sv);
+ CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+ CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv));
+
+ HttpRequestData V2DefaultNamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest));
+ CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc");
+ CHECK(!V2DefaultNamespaceRequest.Bucket.has_value());
+ CHECK(!V2DefaultNamespaceRequest.HashKey.has_value());
+ CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData V2NamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest));
+ CHECK(V2NamespaceRequest.Namespace == "nicenamespace"sv);
+ CHECK(!V2NamespaceRequest.Bucket.has_value());
+ CHECK(!V2NamespaceRequest.HashKey.has_value());
+ CHECK(!V2NamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData V2BucketRequestWithDefaultNamespace;
+ CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace));
+ CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc");
+ CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test"sv);
+ CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value());
+ CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value());
+
+ HttpRequestData V2BucketRequestWithNamespace;
+ CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace));
+ CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace"sv);
+ CHECK(V2BucketRequestWithNamespace.Bucket == "test"sv);
+ CHECK(!V2BucketRequestWithNamespace.HashKey.has_value());
+ CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value());
+
+ HttpRequestData V2HashKeyRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest));
+ CHECK(V2HashKeyRequest.Namespace == ZenCacheStore::DefaultNamespace);
+ CHECK(V2HashKeyRequest.Bucket == "test");
+ CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+ CHECK(!V2HashKeyRequest.ValueContentId.has_value());
+
+ HttpRequestData V2ValueContentIdRequest;
+ CHECK(
+ HttpRequestParseRelativeUri("nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+ V2ValueContentIdRequest));
+ CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace"sv);
+ CHECK(V2ValueContentIdRequest.Bucket == "test"sv);
+ CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"sv));
+ CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"sv));
+
+ HttpRequestData Invalid;
+ CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789",
+ Invalid));
+}
+
+#endif
+
+void
+z$service_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcache.h b/src/zenserver/cache/structuredcache.h
new file mode 100644
index 000000000..4e7b98ac9
--- /dev/null
+++ b/src/zenserver/cache/structuredcache.h
@@ -0,0 +1,187 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/stats.h>
+#include <zenhttp/httpserver.h>
+
+#include "monitoring/httpstats.h"
+#include "monitoring/httpstatus.h"
+
+#include <memory>
+#include <vector>
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+struct CacheChunkRequest;
+struct CacheKeyRequest;
+class CidStore;
+class CbObjectView;
+struct PutRequestData;
+class ScrubContext;
+class UpstreamCache;
+class ZenCacheStore;
+enum class CachePolicy : uint32_t;
+enum class RpcAcceptOptions : uint16_t;
+
+namespace cache {
+ class IRpcRequestReplayer;
+ class IRpcRequestRecorder;
+ namespace detail {
+ struct RecordBody;
+ struct ChunkRequest;
+ } // namespace detail
+} // namespace cache
+
+/**
+ * Structured cache service. Imposes constraints on keys, supports blobs and
+ * structured values
+ *
+ * Keys are structured as:
+ *
+ * {BucketId}/{KeyHash}
+ *
+ * Where BucketId is a lower-case alphanumeric string, and KeyHash is a 40-character
+ * hexadecimal sequence. The hash value may be derived in any number of ways, it's
+ * up to the application to pick an approach.
+ *
+ * Values may be structured or unstructured. Structured values are encoded using Unreal
+ * Engine's compact binary encoding (see CbObject)
+ *
+ * Additionally, attachments may be addressed as:
+ *
+ * {BucketId}/{KeyHash}/{ValueHash}
+ *
+ * Where the two initial components are the same as for the main endpoint
+ *
+ * The storage strategy is as follows:
+ *
+ * - Structured values are stored in a dedicated backing store per bucket
+ * - Unstructured values and attachments are stored in the CAS pool
+ *
+ */
+
+class HttpStructuredCacheService : public HttpService, public IHttpStatsProvider, public IHttpStatusProvider
+{
+public:
+ HttpStructuredCacheService(ZenCacheStore& InCacheStore,
+ CidStore& InCidStore,
+ HttpStatsService& StatsService,
+ HttpStatusService& StatusService,
+ UpstreamCache& UpstreamCache);
+ ~HttpStructuredCacheService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+
+private:
+ struct CacheRef
+ {
+ std::string Namespace;
+ std::string BucketSegment;
+ IoHash HashKey;
+ IoHash ValueContentId;
+ };
+
+ struct CacheStats
+ {
+ std::atomic_uint64_t HitCount{};
+ std::atomic_uint64_t UpstreamHitCount{};
+ std::atomic_uint64_t MissCount{};
+ };
+ enum class PutResult
+ {
+ Success,
+ Fail,
+ Invalid,
+ };
+
+ void HandleCacheRecordRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+ void HandleGetCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+ void HandlePutCacheRecord(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+ void HandleCacheChunkRequest(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+ void HandleGetCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+ void HandlePutCacheChunk(HttpServerRequest& Request, const CacheRef& Ref, CachePolicy PolicyFromUrl);
+ void HandleRpcRequest(HttpServerRequest& Request);
+ void HandleDetailsRequest(HttpServerRequest& Request);
+
+ CbPackage HandleRpcPutCacheRecords(const CbPackage& BatchRequest);
+ CbPackage HandleRpcGetCacheRecords(CbObjectView BatchRequest);
+ CbPackage HandleRpcPutCacheValues(const CbPackage& BatchRequest);
+ CbPackage HandleRpcGetCacheValues(CbObjectView BatchRequest);
+ CbPackage HandleRpcGetCacheChunks(CbObjectView BatchRequest);
+ CbPackage HandleRpcRequest(const ZenContentType ContentType,
+ IoBuffer&& Body,
+ uint32_t& OutAcceptMagic,
+ RpcAcceptOptions& OutAcceptFlags,
+ int& OutTargetProcessId);
+
+ void HandleCacheRequest(HttpServerRequest& Request);
+ void HandleCacheNamespaceRequest(HttpServerRequest& Request, std::string_view Namespace);
+ void HandleCacheBucketRequest(HttpServerRequest& Request, std::string_view Namespace, std::string_view Bucket);
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+ virtual void HandleStatusRequest(HttpServerRequest& Request) override;
+ PutResult PutCacheRecord(PutRequestData& Request, const CbPackage* Package);
+
+ /** HandleRpcGetCacheChunks Helper: Parse the Body object into RecordValue Requests and Value Requests. */
+ bool ParseGetCacheChunksRequest(std::string& Namespace,
+ std::vector<CacheKeyRequest>& RecordKeys,
+ std::vector<cache::detail::RecordBody>& Records,
+ std::vector<CacheChunkRequest>& RequestKeys,
+ std::vector<cache::detail::ChunkRequest>& Requests,
+ std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+ std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+ CbObjectView RpcRequest);
+ /** HandleRpcGetCacheChunks Helper: Load records to get ContentId for RecordRequests, and load their payloads if they exist locally. */
+ void GetLocalCacheRecords(std::string_view Namespace,
+ std::vector<CacheKeyRequest>& RecordKeys,
+ std::vector<cache::detail::RecordBody>& Records,
+ std::vector<cache::detail::ChunkRequest*>& RecordRequests,
+ std::vector<CacheChunkRequest*>& OutUpstreamChunks);
+ /** HandleRpcGetCacheChunks Helper: For ValueRequests, load their payloads if they exist locally. */
+ void GetLocalCacheValues(std::string_view Namespace,
+ std::vector<cache::detail::ChunkRequest*>& ValueRequests,
+ std::vector<CacheChunkRequest*>& OutUpstreamChunks);
+ /** HandleRpcGetCacheChunks Helper: Load payloads from upstream that did not exist locally. */
+ void GetUpstreamCacheChunks(std::string_view Namespace,
+ std::vector<CacheChunkRequest*>& UpstreamChunks,
+ std::vector<CacheChunkRequest>& RequestKeys,
+ std::vector<cache::detail::ChunkRequest>& Requests);
+ /** HandleRpcGetCacheChunks Helper: Send response message containing all chunk results. */
+ CbPackage WriteGetCacheChunksResponse(std::string_view Namespace, std::vector<cache::detail::ChunkRequest>& Requests);
+
+ spdlog::logger& Log() { return m_Log; }
+ spdlog::logger& m_Log;
+ ZenCacheStore& m_CacheStore;
+ HttpStatsService& m_StatsService;
+ HttpStatusService& m_StatusService;
+ CidStore& m_CidStore;
+ UpstreamCache& m_UpstreamCache;
+ uint64_t m_LastScrubTime = 0;
+ metrics::OperationTiming m_HttpRequests;
+ metrics::OperationTiming m_UpstreamGetRequestTiming;
+ CacheStats m_CacheStats;
+
+ void ReplayRequestRecorder(cache::IRpcRequestReplayer& Replayer, uint32_t ThreadCount);
+
+ std::unique_ptr<cache::IRpcRequestRecorder> m_RequestRecorder;
+};
+
+/** Recognize both kBinary and kCompressedBinary as kCompressedBinary for structured cache value keys.
+ * We need this until the content type is preserved for kCompressedBinary when passing to and from upstream servers. */
+inline bool
+IsCompressedBinary(ZenContentType Type)
+{
+ return Type == ZenContentType::kBinary || Type == ZenContentType::kCompressedBinary;
+}
+
+void z$service_forcelink();
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcachestore.cpp b/src/zenserver/cache/structuredcachestore.cpp
new file mode 100644
index 000000000..26e970073
--- /dev/null
+++ b/src/zenserver/cache/structuredcachestore.cpp
@@ -0,0 +1,3648 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "structuredcachestore.h"
+
+#include <zencore/except.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/compress.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/scrubcontext.h>
+
+#include <xxhash.h>
+
+#include <limits>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/core.h>
+#include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+# include <zencore/workthreadpool.h>
+# include <random>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+namespace {
+
+#pragma pack(push)
+#pragma pack(1)
+
+ struct CacheBucketIndexHeader
+ {
+ static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx';
+ static constexpr uint32_t Version2 = 2;
+ static constexpr uint32_t CurrentVersion = Version2;
+
+ uint32_t Magic = ExpectedMagic;
+ uint32_t Version = CurrentVersion;
+ uint64_t EntryCount = 0;
+ uint64_t LogPosition = 0;
+ uint32_t PayloadAlignment = 0;
+ uint32_t Checksum = 0;
+
+ static uint32_t ComputeChecksum(const CacheBucketIndexHeader& Header)
+ {
+ return XXH32(&Header.Magic, sizeof(CacheBucketIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA);
+ }
+ };
+
+ static_assert(sizeof(CacheBucketIndexHeader) == 32);
+
+#pragma pack(pop)
+
+ const char* IndexExtension = ".uidx";
+ const char* LogExtension = ".slog";
+
+ std::filesystem::path GetIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName)
+ {
+ return BucketDir / (BucketName + IndexExtension);
+ }
+
+ std::filesystem::path GetTempIndexPath(const std::filesystem::path& BucketDir, const std::string& BucketName)
+ {
+ return BucketDir / (BucketName + ".tmp");
+ }
+
+ std::filesystem::path GetLogPath(const std::filesystem::path& BucketDir, const std::string& BucketName)
+ {
+ return BucketDir / (BucketName + LogExtension);
+ }
+
+ bool ValidateEntry(const DiskIndexEntry& Entry, std::string& OutReason)
+ {
+ if (Entry.Key == IoHash::Zero)
+ {
+ OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString());
+ return false;
+ }
+ if (Entry.Location.GetFlags() &
+ ~(DiskLocation::kStandaloneFile | DiskLocation::kStructured | DiskLocation::kTombStone | DiskLocation::kCompressed))
+ {
+ OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Location.GetFlags(), Entry.Key.ToHexString());
+ return false;
+ }
+ if (Entry.Location.IsFlagSet(DiskLocation::kTombStone))
+ {
+ return true;
+ }
+ if (Entry.Location.Reserved != 0)
+ {
+ OutReason = fmt::format("Invalid reserved field {} for entry {}", Entry.Location.Reserved, Entry.Key.ToHexString());
+ return false;
+ }
+ uint64_t Size = Entry.Location.Size();
+ if (Size == 0)
+ {
+ OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString());
+ return false;
+ }
+ return true;
+ }
+
+ bool MoveAndDeleteDirectory(const std::filesystem::path& Dir)
+ {
+ int DropIndex = 0;
+ do
+ {
+ if (!std::filesystem::exists(Dir))
+ {
+ return false;
+ }
+
+ std::string DroppedName = fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex);
+ std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName;
+ if (std::filesystem::exists(DroppedBucketPath))
+ {
+ DropIndex++;
+ continue;
+ }
+
+ std::error_code Ec;
+ std::filesystem::rename(Dir, DroppedBucketPath, Ec);
+ if (!Ec)
+ {
+ DeleteDirectories(DroppedBucketPath);
+ return true;
+ }
+ // TODO: Do we need to bail at some point?
+ zen::Sleep(100);
+ } while (true);
+ }
+
+} // namespace
+
+namespace fs = std::filesystem;
+
+static CbObject
+LoadCompactBinaryObject(const fs::path& Path)
+{
+ FileContents Result = ReadFile(Path);
+
+ if (!Result.ErrorCode)
+ {
+ IoBuffer Buffer = Result.Flatten();
+ if (CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); Error == CbValidateError::None)
+ {
+ return LoadCompactBinaryObject(Buffer);
+ }
+ }
+
+ return CbObject();
+}
+
+static void
+SaveCompactBinaryObject(const fs::path& Path, const CbObject& Object)
+{
+ WriteFile(Path, Object.GetBuffer().AsIoBuffer());
+}
+
+ZenCacheNamespace::ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir)
+: GcStorage(Gc)
+, GcContributor(Gc)
+, m_RootDir(RootDir)
+, m_DiskLayer(RootDir)
+{
+ ZEN_INFO("initializing structured cache at '{}'", RootDir);
+ CreateDirectories(RootDir);
+
+ m_DiskLayer.DiscoverBuckets();
+
+#if ZEN_USE_CACHE_TRACKER
+ m_AccessTracker.reset(new ZenCacheTracker(RootDir));
+#endif
+}
+
+ZenCacheNamespace::~ZenCacheNamespace()
+{
+}
+
+bool
+ZenCacheNamespace::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ ZEN_TRACE_CPU("Z$::Get");
+
+ bool Ok = m_MemLayer.Get(InBucket, HashKey, OutValue);
+
+#if ZEN_USE_CACHE_TRACKER
+ auto _ = MakeGuard([&] {
+ if (!Ok)
+ return;
+
+ m_AccessTracker->TrackAccess(InBucket, HashKey);
+ });
+#endif
+
+ if (Ok)
+ {
+ ZEN_ASSERT(OutValue.Value.Size());
+
+ return true;
+ }
+
+ Ok = m_DiskLayer.Get(InBucket, HashKey, OutValue);
+
+ if (Ok)
+ {
+ ZEN_ASSERT(OutValue.Value.Size());
+
+ if (OutValue.Value.Size() <= m_DiskLayerSizeThreshold)
+ {
+ m_MemLayer.Put(InBucket, HashKey, OutValue);
+ }
+ }
+
+ return Ok;
+}
+
+void
+ZenCacheNamespace::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ ZEN_TRACE_CPU("Z$::Put");
+
+ // Store value and index
+
+ ZEN_ASSERT(Value.Value.Size());
+
+ m_DiskLayer.Put(InBucket, HashKey, Value);
+
+#if ZEN_USE_REF_TRACKING
+ if (Value.Value.GetContentType() == ZenContentType::kCbObject)
+ {
+ if (ValidateCompactBinary(Value.Value, CbValidateMode::All) == CbValidateError::None)
+ {
+ CbObject Object{SharedBuffer(Value.Value)};
+
+ uint8_t TempBuffer[8 * sizeof(IoHash)];
+ std::pmr::monotonic_buffer_resource Linear{TempBuffer, sizeof TempBuffer};
+ std::pmr::polymorphic_allocator Allocator{&Linear};
+ std::pmr::vector<IoHash> CidReferences{Allocator};
+
+ Object.IterateAttachments([&](CbFieldView Field) { CidReferences.push_back(Field.AsAttachment()); });
+
+ m_Gc.OnNewCidReferences(CidReferences);
+ }
+ }
+#endif
+
+ if (Value.Value.Size() <= m_DiskLayerSizeThreshold)
+ {
+ m_MemLayer.Put(InBucket, HashKey, Value);
+ }
+}
+
+bool
+ZenCacheNamespace::DropBucket(std::string_view Bucket)
+{
+ ZEN_INFO("dropping bucket '{}'", Bucket);
+
+ // TODO: should ensure this is done atomically across all layers
+
+ const bool MemDropped = m_MemLayer.DropBucket(Bucket);
+ const bool DiskDropped = m_DiskLayer.DropBucket(Bucket);
+ const bool AnyDropped = MemDropped || DiskDropped;
+
+ ZEN_INFO("bucket '{}' was {}", Bucket, AnyDropped ? "dropped" : "not found");
+
+ return AnyDropped;
+}
+
+bool
+ZenCacheNamespace::Drop()
+{
+ m_MemLayer.Drop();
+ return m_DiskLayer.Drop();
+}
+
+void
+ZenCacheNamespace::Flush()
+{
+ m_DiskLayer.Flush();
+}
+
+void
+ZenCacheNamespace::Scrub(ScrubContext& Ctx)
+{
+ if (m_LastScrubTime == Ctx.ScrubTimestamp())
+ {
+ return;
+ }
+
+ m_LastScrubTime = Ctx.ScrubTimestamp();
+
+ m_DiskLayer.Scrub(Ctx);
+ m_MemLayer.Scrub(Ctx);
+}
+
+void
+ZenCacheNamespace::GatherReferences(GcContext& GcCtx)
+{
+ Stopwatch Timer;
+ const auto Guard =
+ MakeGuard([&] { ZEN_DEBUG("cache gathered all references from '{}' in {}", m_RootDir, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); });
+
+ access_tracking::AccessTimes AccessTimes;
+ m_MemLayer.GatherAccessTimes(AccessTimes);
+
+ m_DiskLayer.UpdateAccessTimes(AccessTimes);
+ m_DiskLayer.GatherReferences(GcCtx);
+}
+
+void
+ZenCacheNamespace::CollectGarbage(GcContext& GcCtx)
+{
+ m_MemLayer.Reset();
+ m_DiskLayer.CollectGarbage(GcCtx);
+}
+
+GcStorageSize
+ZenCacheNamespace::StorageSize() const
+{
+ return {.DiskSize = m_DiskLayer.TotalSize(), .MemorySize = m_MemLayer.TotalSize()};
+}
+
+ZenCacheNamespace::Info
+ZenCacheNamespace::GetInfo() const
+{
+ ZenCacheNamespace::Info Info = {.Config = {.RootDir = m_RootDir, .DiskLayerThreshold = m_DiskLayerSizeThreshold},
+ .DiskLayerInfo = m_DiskLayer.GetInfo(),
+ .MemoryLayerInfo = m_MemLayer.GetInfo()};
+ std::unordered_set<std::string> BucketNames;
+ for (const std::string& BucketName : Info.DiskLayerInfo.BucketNames)
+ {
+ BucketNames.insert(BucketName);
+ }
+ for (const std::string& BucketName : Info.MemoryLayerInfo.BucketNames)
+ {
+ BucketNames.insert(BucketName);
+ }
+ Info.BucketNames.insert(Info.BucketNames.end(), BucketNames.begin(), BucketNames.end());
+ return Info;
+}
+
+std::optional<ZenCacheNamespace::BucketInfo>
+ZenCacheNamespace::GetBucketInfo(std::string_view Bucket) const
+{
+ std::optional<ZenCacheDiskLayer::BucketInfo> DiskBucketInfo = m_DiskLayer.GetBucketInfo(Bucket);
+ if (!DiskBucketInfo.has_value())
+ {
+ return {};
+ }
+ ZenCacheNamespace::BucketInfo Info = {.DiskLayerInfo = *DiskBucketInfo,
+ .MemoryLayerInfo = m_MemLayer.GetBucketInfo(Bucket).value_or(ZenCacheMemoryLayer::BucketInfo{})};
+ return Info;
+}
+
+CacheValueDetails::NamespaceDetails
+ZenCacheNamespace::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const
+{
+ return m_DiskLayer.GetValueDetails(BucketFilter, ValueFilter);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ZenCacheMemoryLayer::ZenCacheMemoryLayer()
+{
+}
+
+ZenCacheMemoryLayer::~ZenCacheMemoryLayer()
+{
+}
+
+bool
+ZenCacheMemoryLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(std::string(InBucket));
+
+ if (It == m_Buckets.end())
+ {
+ return false;
+ }
+
+ CacheBucket* Bucket = It->second.get();
+
+ _.ReleaseNow();
+
+ // There's a race here. Since the lock is released early to allow
+ // inserts, the bucket delete path could end up deleting the
+ // underlying data structure
+
+ return Bucket->Get(HashKey, OutValue);
+}
+
+void
+ZenCacheMemoryLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ const auto BucketName = std::string(InBucket);
+ CacheBucket* Bucket = nullptr;
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ }
+
+ if (Bucket == nullptr)
+ {
+ // New bucket
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(std::string(InBucket)); It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ else
+ {
+ auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>());
+ Bucket = InsertResult.first->second.get();
+ }
+ }
+
+ // Note that since the underlying IoBuffer is retained, the content type is also
+
+ Bucket->Put(HashKey, Value);
+}
+
+bool
+ZenCacheMemoryLayer::DropBucket(std::string_view InBucket)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(std::string(InBucket));
+
+ if (It != m_Buckets.end())
+ {
+ CacheBucket& Bucket = *It->second;
+ m_DroppedBuckets.push_back(std::move(It->second));
+ m_Buckets.erase(It);
+ Bucket.Drop();
+ return true;
+ }
+ return false;
+}
+
+void
+ZenCacheMemoryLayer::Drop()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ std::vector<std::unique_ptr<CacheBucket>> Buckets;
+ Buckets.reserve(m_Buckets.size());
+ while (!m_Buckets.empty())
+ {
+ const auto& It = m_Buckets.begin();
+ CacheBucket& Bucket = *It->second;
+ m_DroppedBuckets.push_back(std::move(It->second));
+ m_Buckets.erase(It->first);
+ Bucket.Drop();
+ }
+}
+
+void
+ZenCacheMemoryLayer::Scrub(ScrubContext& Ctx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ Kv.second->Scrub(Ctx);
+ }
+}
+
+void
+ZenCacheMemoryLayer::GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes)
+{
+ using namespace zen::access_tracking;
+
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ std::vector<KeyAccessTime>& Bucket = AccessTimes.Buckets[Kv.first];
+ Kv.second->GatherAccessTimes(Bucket);
+ }
+}
+
+void
+ZenCacheMemoryLayer::Reset()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_Buckets.clear();
+}
+
+uint64_t
+ZenCacheMemoryLayer::TotalSize() const
+{
+ uint64_t TotalSize{};
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ TotalSize += Kv.second->TotalSize();
+ }
+
+ return TotalSize;
+}
+
+ZenCacheMemoryLayer::Info
+ZenCacheMemoryLayer::GetInfo() const
+{
+ ZenCacheMemoryLayer::Info Info = {.Config = m_Configuration, .TotalSize = TotalSize()};
+
+ RwLock::SharedLockScope _(m_Lock);
+ Info.BucketNames.reserve(m_Buckets.size());
+ for (auto& Kv : m_Buckets)
+ {
+ Info.BucketNames.push_back(Kv.first);
+ Info.EntryCount += Kv.second->EntryCount();
+ }
+ return Info;
+}
+
+std::optional<ZenCacheMemoryLayer::BucketInfo>
+ZenCacheMemoryLayer::GetBucketInfo(std::string_view Bucket) const
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end())
+ {
+ return ZenCacheMemoryLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()};
+ }
+ return {};
+}
+
+void
+ZenCacheMemoryLayer::CacheBucket::Scrub(ScrubContext& Ctx)
+{
+ RwLock::SharedLockScope _(m_BucketLock);
+
+ std::vector<IoHash> BadHashes;
+
+ auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) {
+ if (ContentType == ZenContentType::kCbObject)
+ {
+ CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All);
+ return Error == CbValidateError::None;
+ }
+ if (ContentType == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize))
+ {
+ return false;
+ }
+ if (Hash != RawHash)
+ {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ for (auto& Kv : m_CacheMap)
+ {
+ const BucketPayload& Payload = m_Payloads[Kv.second];
+ if (!ValidateEntry(Kv.first, Payload.Payload.GetContentType(), Payload.Payload))
+ {
+ BadHashes.push_back(Kv.first);
+ }
+ }
+
+ if (!BadHashes.empty())
+ {
+ Ctx.ReportBadCidChunks(BadHashes);
+ }
+}
+
+void
+ZenCacheMemoryLayer::CacheBucket::GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes)
+{
+ RwLock::SharedLockScope _(m_BucketLock);
+ std::transform(m_CacheMap.begin(), m_CacheMap.end(), std::back_inserter(AccessTimes), [this](const auto& Kv) {
+ return access_tracking::KeyAccessTime{.Key = Kv.first, .LastAccess = m_AccessTimes[Kv.second]};
+ });
+}
+
+bool
+ZenCacheMemoryLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ RwLock::SharedLockScope _(m_BucketLock);
+
+ if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end())
+ {
+ uint32_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size());
+ ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size());
+
+ const BucketPayload& Payload = m_Payloads[EntryIndex];
+ OutValue = {.Value = Payload.Payload, .RawSize = Payload.RawSize, .RawHash = Payload.RawHash};
+ m_AccessTimes[EntryIndex] = GcClock::TickCount();
+
+ return true;
+ }
+
+ return false;
+}
+
+void
+ZenCacheMemoryLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ size_t PayloadSize = Value.Value.GetSize();
+ {
+ GcClock::Tick AccessTime = GcClock::TickCount();
+ RwLock::ExclusiveLockScope _(m_BucketLock);
+ if (m_CacheMap.size() == std::numeric_limits<uint32_t>::max())
+ {
+ // No more space in our memory cache!
+ return;
+ }
+ if (auto It = m_CacheMap.find(HashKey); It != m_CacheMap.end())
+ {
+ uint32_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size());
+
+ m_TotalSize.fetch_sub(PayloadSize, std::memory_order::relaxed);
+ BucketPayload& Payload = m_Payloads[EntryIndex];
+ Payload.Payload = Value.Value;
+ Payload.RawHash = Value.RawHash;
+ Payload.RawSize = gsl::narrow<uint32_t>(Value.RawSize);
+ m_AccessTimes[EntryIndex] = AccessTime;
+ }
+ else
+ {
+ uint32_t EntryIndex = gsl::narrow<uint32_t>(m_Payloads.size());
+ m_Payloads.emplace_back(
+ BucketPayload{.Payload = Value.Value, .RawSize = gsl::narrow<uint32_t>(Value.RawSize), .RawHash = Value.RawHash});
+ m_AccessTimes.emplace_back(AccessTime);
+ m_CacheMap.insert_or_assign(HashKey, EntryIndex);
+ }
+ ZEN_ASSERT_SLOW(m_Payloads.size() == m_CacheMap.size());
+ ZEN_ASSERT_SLOW(m_AccessTimes.size() == m_Payloads.size());
+ }
+
+ m_TotalSize.fetch_add(PayloadSize, std::memory_order::relaxed);
+}
+
+void
+ZenCacheMemoryLayer::CacheBucket::Drop()
+{
+ RwLock::ExclusiveLockScope _(m_BucketLock);
+ m_CacheMap.clear();
+ m_AccessTimes.clear();
+ m_Payloads.clear();
+ m_TotalSize.store(0);
+}
+
+uint64_t
+ZenCacheMemoryLayer::CacheBucket::EntryCount() const
+{
+ RwLock::SharedLockScope _(m_BucketLock);
+ return static_cast<uint64_t>(m_CacheMap.size());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ZenCacheDiskLayer::CacheBucket::CacheBucket(std::string BucketName) : m_BucketName(std::move(BucketName)), m_BucketId(Oid::Zero)
+{
+}
+
+ZenCacheDiskLayer::CacheBucket::~CacheBucket()
+{
+}
+
+bool
+ZenCacheDiskLayer::CacheBucket::OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate)
+{
+ using namespace std::literals;
+
+ m_BlocksBasePath = BucketDir / "blocks";
+ m_BucketDir = BucketDir;
+
+ CreateDirectories(m_BucketDir);
+
+ std::filesystem::path ManifestPath{m_BucketDir / "zen_manifest"};
+
+ bool IsNew = false;
+
+ CbObject Manifest = LoadCompactBinaryObject(ManifestPath);
+
+ if (Manifest)
+ {
+ m_BucketId = Manifest["BucketId"sv].AsObjectId();
+ if (m_BucketId == Oid::Zero)
+ {
+ return false;
+ }
+ }
+ else if (AllowCreate)
+ {
+ m_BucketId.Generate();
+
+ CbObjectWriter Writer;
+ Writer << "BucketId"sv << m_BucketId;
+ Manifest = Writer.Save();
+ SaveCompactBinaryObject(ManifestPath, Manifest);
+ IsNew = true;
+ }
+ else
+ {
+ return false;
+ }
+
+ OpenLog(IsNew);
+
+ if (!IsNew)
+ {
+ Stopwatch Timer;
+ const auto _ =
+ MakeGuard([&] { ZEN_INFO("read store manifest '{}' in {}", ManifestPath, NiceTimeSpanMs(Timer.GetElapsedTimeMs())); });
+
+ for (CbFieldView Entry : Manifest["Timestamps"sv])
+ {
+ const CbObjectView Obj = Entry.AsObjectView();
+ const IoHash Key = Obj["Key"sv].AsHash();
+
+ if (auto It = m_Index.find(Key); It != m_Index.end())
+ {
+ size_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
+ m_AccessTimes[EntryIndex] = Obj["LastAccess"sv].AsInt64();
+ }
+ }
+ for (CbFieldView Entry : Manifest["RawInfo"sv])
+ {
+ const CbObjectView Obj = Entry.AsObjectView();
+ const IoHash Key = Obj["Key"sv].AsHash();
+ if (auto It = m_Index.find(Key); It != m_Index.end())
+ {
+ size_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_Payloads.size());
+ m_Payloads[EntryIndex].RawHash = Obj["RawHash"sv].AsHash();
+ m_Payloads[EntryIndex].RawSize = Obj["RawSize"sv].AsUInt64();
+ }
+ }
+ }
+
+ return true;
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::MakeIndexSnapshot()
+{
+ uint64_t LogCount = m_SlogFile.GetLogCount();
+ if (m_LogFlushPosition == LogCount)
+ {
+ return;
+ }
+
+ ZEN_DEBUG("write store snapshot for '{}'", m_BucketDir / m_BucketName);
+ uint64_t EntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}",
+ m_BucketDir / m_BucketName,
+ EntryCount,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ namespace fs = std::filesystem;
+
+ fs::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName);
+ fs::path STmpIndexPath = GetTempIndexPath(m_BucketDir, m_BucketName);
+
+ // Move index away, we keep it if something goes wrong
+ if (fs::is_regular_file(STmpIndexPath))
+ {
+ fs::remove(STmpIndexPath);
+ }
+ if (fs::is_regular_file(IndexPath))
+ {
+ fs::rename(IndexPath, STmpIndexPath);
+ }
+
+ try
+ {
+ // Write the current state of the location map to a new index state
+ std::vector<DiskIndexEntry> Entries;
+
+ {
+ Entries.resize(m_Index.size());
+
+ uint64_t EntryIndex = 0;
+ for (auto& Entry : m_Index)
+ {
+ DiskIndexEntry& IndexEntry = Entries[EntryIndex++];
+ IndexEntry.Key = Entry.first;
+ IndexEntry.Location = m_Payloads[Entry.second].Location;
+ }
+ }
+
+ BasicFile ObjectIndexFile;
+ ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate);
+ CacheBucketIndexHeader Header = {.EntryCount = Entries.size(),
+ .LogPosition = LogCount,
+ .PayloadAlignment = gsl::narrow<uint32_t>(m_PayloadAlignment)};
+
+ Header.Checksum = CacheBucketIndexHeader::ComputeChecksum(Header);
+
+ ObjectIndexFile.Write(&Header, sizeof(CacheBucketIndexHeader), 0);
+ ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(DiskIndexEntry), sizeof(CacheBucketIndexHeader));
+ ObjectIndexFile.Flush();
+ ObjectIndexFile.Close();
+ EntryCount = Entries.size();
+ m_LogFlushPosition = LogCount;
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what());
+
+ // Restore any previous snapshot
+
+ if (fs::is_regular_file(STmpIndexPath))
+ {
+ fs::remove(IndexPath);
+ fs::rename(STmpIndexPath, IndexPath);
+ }
+ }
+ if (fs::is_regular_file(STmpIndexPath))
+ {
+ fs::remove(STmpIndexPath);
+ }
+}
+
+uint64_t
+ZenCacheDiskLayer::CacheBucket::ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion)
+{
+ if (std::filesystem::is_regular_file(IndexPath))
+ {
+ BasicFile ObjectIndexFile;
+ ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead);
+ uint64_t Size = ObjectIndexFile.FileSize();
+ if (Size >= sizeof(CacheBucketIndexHeader))
+ {
+ CacheBucketIndexHeader Header;
+ ObjectIndexFile.Read(&Header, sizeof(Header), 0);
+ if ((Header.Magic == CacheBucketIndexHeader::ExpectedMagic) &&
+ (Header.Checksum == CacheBucketIndexHeader::ComputeChecksum(Header)) && (Header.PayloadAlignment > 0))
+ {
+ switch (Header.Version)
+ {
+ case CacheBucketIndexHeader::Version2:
+ {
+ uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CacheBucketIndexHeader))) / sizeof(DiskIndexEntry);
+ if (Header.EntryCount > ExpectedEntryCount)
+ {
+ break;
+ }
+ size_t EntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("read store '{}' index containing {} entries in {}",
+ IndexPath,
+ EntryCount,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ m_PayloadAlignment = Header.PayloadAlignment;
+
+ std::vector<DiskIndexEntry> Entries;
+ Entries.resize(Header.EntryCount);
+ ObjectIndexFile.Read(Entries.data(),
+ Header.EntryCount * sizeof(DiskIndexEntry),
+ sizeof(CacheBucketIndexHeader));
+
+ m_Payloads.reserve(Header.EntryCount);
+ m_AccessTimes.reserve(Header.EntryCount);
+ m_Index.reserve(Header.EntryCount);
+
+ std::string InvalidEntryReason;
+ for (const DiskIndexEntry& Entry : Entries)
+ {
+ if (!ValidateEntry(Entry, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason);
+ continue;
+ }
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = Entry.Location, .RawSize = 0, .RawHash = IoHash::Zero});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert_or_assign(Entry.Key, EntryIndex);
+ EntryCount++;
+ }
+ OutVersion = CacheBucketIndexHeader::Version2;
+ return Header.LogPosition;
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ ZEN_WARN("skipping invalid index file '{}'", IndexPath);
+ }
+ return 0;
+}
+
+uint64_t
+ZenCacheDiskLayer::CacheBucket::ReadLog(const std::filesystem::path& LogPath, uint64_t SkipEntryCount)
+{
+ if (std::filesystem::is_regular_file(LogPath))
+ {
+ uint64_t LogEntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("read store '{}' log containing {} entries in {}", LogPath, LogEntryCount, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+ TCasLogFile<DiskIndexEntry> CasLog;
+ CasLog.Open(LogPath, CasLogFile::Mode::kRead);
+ if (CasLog.Initialize())
+ {
+ uint64_t EntryCount = CasLog.GetLogCount();
+ if (EntryCount < SkipEntryCount)
+ {
+ ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath);
+ SkipEntryCount = 0;
+ }
+ LogEntryCount = EntryCount - SkipEntryCount;
+ m_Index.reserve(LogEntryCount);
+ uint64_t InvalidEntryCount = 0;
+ CasLog.Replay(
+ [&](const DiskIndexEntry& Record) {
+ std::string InvalidEntryReason;
+ if (Record.Location.Flags & DiskLocation::kTombStone)
+ {
+ m_Index.erase(Record.Key);
+ return;
+ }
+ if (!ValidateEntry(Record, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason);
+ ++InvalidEntryCount;
+ return;
+ }
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = Record.Location, .RawSize = 0u, .RawHash = IoHash::Zero});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert_or_assign(Record.Key, EntryIndex);
+ },
+ SkipEntryCount);
+ if (InvalidEntryCount)
+ {
+ ZEN_WARN("found {} invalid entries in '{}'", InvalidEntryCount, m_BucketDir / m_BucketName);
+ }
+ return LogEntryCount;
+ }
+ }
+ return 0;
+};
+
+void
+ZenCacheDiskLayer::CacheBucket::OpenLog(const bool IsNew)
+{
+ m_TotalStandaloneSize = 0;
+
+ m_Index.clear();
+ m_Payloads.clear();
+ m_AccessTimes.clear();
+
+ std::filesystem::path LogPath = GetLogPath(m_BucketDir, m_BucketName);
+ std::filesystem::path IndexPath = GetIndexPath(m_BucketDir, m_BucketName);
+
+ if (IsNew)
+ {
+ fs::remove(LogPath);
+ fs::remove(IndexPath);
+ fs::remove_all(m_BlocksBasePath);
+ }
+
+ uint64_t LogEntryCount = 0;
+ {
+ uint32_t IndexVersion = 0;
+ m_LogFlushPosition = ReadIndexFile(IndexPath, IndexVersion);
+ if (IndexVersion == 0 && std::filesystem::is_regular_file(IndexPath))
+ {
+ ZEN_WARN("removing invalid index file at '{}'", IndexPath);
+ fs::remove(IndexPath);
+ }
+
+ if (TCasLogFile<DiskIndexEntry>::IsValid(LogPath))
+ {
+ LogEntryCount = ReadLog(LogPath, m_LogFlushPosition);
+ }
+ else
+ {
+ ZEN_WARN("removing invalid cas log at '{}'", LogPath);
+ fs::remove(LogPath);
+ }
+ }
+
+ CreateDirectories(m_BucketDir);
+
+ m_SlogFile.Open(LogPath, CasLogFile::Mode::kWrite);
+
+ std::vector<BlockStoreLocation> KnownLocations;
+ KnownLocations.reserve(m_Index.size());
+ for (const auto& Entry : m_Index)
+ {
+ size_t EntryIndex = Entry.second;
+ const BucketPayload& Payload = m_Payloads[EntryIndex];
+ const DiskLocation& Location = Payload.Location;
+
+ if (Location.IsFlagSet(DiskLocation::kStandaloneFile))
+ {
+ m_TotalStandaloneSize.fetch_add(Location.Size(), std::memory_order::relaxed);
+ continue;
+ }
+ const BlockStoreLocation& BlockLocation = Location.GetBlockLocation(m_PayloadAlignment);
+ KnownLocations.push_back(BlockLocation);
+ }
+
+ m_BlockStore.Initialize(m_BlocksBasePath, MaxBlockSize, BlockStoreDiskLocation::MaxBlockIndex + 1, KnownLocations);
+ if (IsNew || LogEntryCount > 0)
+ {
+ MakeIndexSnapshot();
+ }
+ // TODO: should validate integrity of container files here
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const
+{
+ char HexString[sizeof(HashKey.Hash) * 2];
+ ToHexBytes(HashKey.Hash, sizeof HashKey.Hash, HexString);
+
+ Path.Append(m_BucketDir);
+ Path.AppendSeparator();
+ Path.Append(L"blob");
+ Path.AppendSeparator();
+ Path.AppendAsciiRange(HexString, HexString + 3);
+ Path.AppendSeparator();
+ Path.AppendAsciiRange(HexString + 3, HexString + 5);
+ Path.AppendSeparator();
+ Path.AppendAsciiRange(HexString + 5, HexString + sizeof(HexString));
+}
+
+IoBuffer
+ZenCacheDiskLayer::CacheBucket::GetInlineCacheValue(const DiskLocation& Loc) const
+{
+ BlockStoreLocation Location = Loc.GetBlockLocation(m_PayloadAlignment);
+
+ IoBuffer Value = m_BlockStore.TryGetChunk(Location);
+ if (Value)
+ {
+ Value.SetContentType(Loc.GetContentType());
+ }
+
+ return Value;
+}
+
+IoBuffer
+ZenCacheDiskLayer::CacheBucket::GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const
+{
+ ExtendablePathBuilder<256> DataFilePath;
+ BuildPath(DataFilePath, HashKey);
+
+ RwLock::SharedLockScope ValueLock(LockForHash(HashKey));
+
+ if (IoBuffer Data = IoBufferBuilder::MakeFromFile(DataFilePath.ToPath()))
+ {
+ Data.SetContentType(Loc.GetContentType());
+
+ return Data;
+ }
+
+ return {};
+}
+
+bool
+ZenCacheDiskLayer::CacheBucket::Get(const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ RwLock::SharedLockScope _(m_IndexLock);
+ auto It = m_Index.find(HashKey);
+ if (It == m_Index.end())
+ {
+ return false;
+ }
+ size_t EntryIndex = It.value();
+ const BucketPayload& Payload = m_Payloads[EntryIndex];
+ m_AccessTimes[EntryIndex] = GcClock::TickCount();
+ DiskLocation Location = Payload.Location;
+ OutValue.RawSize = Payload.RawSize;
+ OutValue.RawHash = Payload.RawHash;
+ if (Location.IsFlagSet(DiskLocation::kStandaloneFile))
+ {
+ // We don't need to hold the index lock when we read a standalone file
+ _.ReleaseNow();
+ OutValue.Value = GetStandaloneCacheValue(Location, HashKey);
+ }
+ else
+ {
+ OutValue.Value = GetInlineCacheValue(Location);
+ }
+ _.ReleaseNow();
+
+ if (!Location.IsFlagSet(DiskLocation::kStructured))
+ {
+ if (OutValue.RawHash == IoHash::Zero && OutValue.RawSize == 0 && OutValue.Value.GetSize() > 0)
+ {
+ if (Location.IsFlagSet(DiskLocation::kCompressed))
+ {
+ (void)CompressedBuffer::FromCompressed(SharedBuffer(OutValue.Value), OutValue.RawHash, OutValue.RawSize);
+ }
+ else
+ {
+ OutValue.RawHash = IoHash::HashBuffer(OutValue.Value);
+ OutValue.RawSize = OutValue.Value.GetSize();
+ }
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ if (auto WriteIt = m_Index.find(HashKey); WriteIt != m_Index.end())
+ {
+ BucketPayload& WritePayload = m_Payloads[WriteIt.value()];
+ WritePayload.RawHash = OutValue.RawHash;
+ WritePayload.RawSize = OutValue.RawSize;
+
+ m_LogFlushPosition = 0; // Force resave of index on exit
+ }
+ }
+ }
+
+ return (bool)OutValue.Value;
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::Put(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ if (Value.Value.Size() >= m_LargeObjectThreshold)
+ {
+ return PutStandaloneCacheValue(HashKey, Value);
+ }
+ PutInlineCacheValue(HashKey, Value);
+}
+
+bool
+ZenCacheDiskLayer::CacheBucket::Drop()
+{
+ RwLock::ExclusiveLockScope _(m_IndexLock);
+
+ std::vector<std::unique_ptr<RwLock::ExclusiveLockScope>> ShardLocks;
+ ShardLocks.reserve(256);
+ for (RwLock& Lock : m_ShardedLocks)
+ {
+ ShardLocks.push_back(std::make_unique<RwLock::ExclusiveLockScope>(Lock));
+ }
+ m_BlockStore.Close();
+ m_SlogFile.Close();
+
+ bool Deleted = MoveAndDeleteDirectory(m_BucketDir);
+
+ m_Index.clear();
+ m_Payloads.clear();
+ m_AccessTimes.clear();
+ return Deleted;
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::Flush()
+{
+ m_BlockStore.Flush();
+
+ RwLock::SharedLockScope _(m_IndexLock);
+ m_SlogFile.Flush();
+ MakeIndexSnapshot();
+ SaveManifest();
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::SaveManifest()
+{
+ using namespace std::literals;
+
+ CbObjectWriter Writer;
+ Writer << "BucketId"sv << m_BucketId;
+
+ if (!m_Index.empty())
+ {
+ Writer.BeginArray("Timestamps"sv);
+ for (auto& Kv : m_Index)
+ {
+ const IoHash& Key = Kv.first;
+ GcClock::Tick AccessTime = m_AccessTimes[Kv.second];
+
+ Writer.BeginObject();
+ Writer << "Key"sv << Key;
+ Writer << "LastAccess"sv << AccessTime;
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ Writer.BeginArray("RawInfo"sv);
+ {
+ for (auto& Kv : m_Index)
+ {
+ const IoHash& Key = Kv.first;
+ const BucketPayload& Payload = m_Payloads[Kv.second];
+ if (Payload.RawHash != IoHash::Zero)
+ {
+ Writer.BeginObject();
+ Writer << "Key"sv << Key;
+ Writer << "RawHash"sv << Payload.RawHash;
+ Writer << "RawSize"sv << Payload.RawSize;
+ Writer.EndObject();
+ }
+ }
+ }
+ Writer.EndArray();
+ }
+
+ SaveCompactBinaryObject(m_BucketDir / "zen_manifest", Writer.Save());
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::Scrub(ScrubContext& Ctx)
+{
+ std::vector<IoHash> BadKeys;
+ uint64_t ChunkCount{0}, ChunkBytes{0};
+ std::vector<BlockStoreLocation> ChunkLocations;
+ std::vector<IoHash> ChunkIndexToChunkHash;
+
+ auto ValidateEntry = [](const IoHash& Hash, ZenContentType ContentType, IoBuffer Buffer) {
+ if (ContentType == ZenContentType::kCbObject)
+ {
+ CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All);
+ return Error == CbValidateError::None;
+ }
+ if (ContentType == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize))
+ {
+ return false;
+ }
+ if (RawHash != Hash)
+ {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ RwLock::SharedLockScope _(m_IndexLock);
+
+ const size_t BlockChunkInitialCount = m_Index.size() / 4;
+ ChunkLocations.reserve(BlockChunkInitialCount);
+ ChunkIndexToChunkHash.reserve(BlockChunkInitialCount);
+
+ for (auto& Kv : m_Index)
+ {
+ const IoHash& HashKey = Kv.first;
+ const BucketPayload& Payload = m_Payloads[Kv.second];
+ const DiskLocation& Loc = Payload.Location;
+
+ if (Loc.IsFlagSet(DiskLocation::kStandaloneFile))
+ {
+ ++ChunkCount;
+ ChunkBytes += Loc.Size();
+ if (Loc.GetContentType() == ZenContentType::kBinary)
+ {
+ ExtendablePathBuilder<256> DataFilePath;
+ BuildPath(DataFilePath, HashKey);
+
+ RwLock::SharedLockScope ValueLock(LockForHash(HashKey));
+
+ std::error_code Ec;
+ uintmax_t size = std::filesystem::file_size(DataFilePath.ToPath(), Ec);
+ if (Ec)
+ {
+ BadKeys.push_back(HashKey);
+ }
+ if (size != Loc.Size())
+ {
+ BadKeys.push_back(HashKey);
+ }
+ continue;
+ }
+ IoBuffer Buffer = GetStandaloneCacheValue(Loc, HashKey);
+ if (!Buffer)
+ {
+ BadKeys.push_back(HashKey);
+ continue;
+ }
+ if (!ValidateEntry(HashKey, Loc.GetContentType(), Buffer))
+ {
+ BadKeys.push_back(HashKey);
+ continue;
+ }
+ }
+ else
+ {
+ ChunkLocations.emplace_back(Loc.GetBlockLocation(m_PayloadAlignment));
+ ChunkIndexToChunkHash.push_back(HashKey);
+ continue;
+ }
+ }
+
+ const auto ValidateSmallChunk = [&](size_t ChunkIndex, const void* Data, uint64_t Size) {
+ ++ChunkCount;
+ ChunkBytes += Size;
+ const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex];
+ if (!Data)
+ {
+ // ChunkLocation out of range of stored blocks
+ BadKeys.push_back(Hash);
+ return;
+ }
+ IoBuffer Buffer(IoBuffer::Wrap, Data, Size);
+ if (!Buffer)
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)];
+ ZenContentType ContentType = Payload.Location.GetContentType();
+ if (!ValidateEntry(Hash, ContentType, Buffer))
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ };
+
+ const auto ValidateLargeChunk = [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) {
+ ++ChunkCount;
+ ChunkBytes += Size;
+ const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex];
+ // TODO: Add API to verify compressed buffer and possible structure data without having to memorymap the whole file
+ IoBuffer Buffer(IoBuffer::BorrowedFile, File.GetBasicFile().Handle(), Offset, Size);
+ if (!Buffer)
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ const BucketPayload& Payload = m_Payloads[m_Index.at(Hash)];
+ ZenContentType ContentType = Payload.Location.GetContentType();
+ if (!ValidateEntry(Hash, ContentType, Buffer))
+ {
+ BadKeys.push_back(Hash);
+ return;
+ }
+ };
+
+ m_BlockStore.IterateChunks(ChunkLocations, ValidateSmallChunk, ValidateLargeChunk);
+
+ _.ReleaseNow();
+
+ Ctx.ReportScrubbed(ChunkCount, ChunkBytes);
+
+ if (!BadKeys.empty())
+ {
+ ZEN_WARN("Scrubbing found {} bad chunks in '{}'", BadKeys.size(), m_BucketDir / m_BucketName);
+
+ if (Ctx.RunRecovery())
+ {
+ // Deal with bad chunks by removing them from our lookup map
+
+ std::vector<DiskIndexEntry> LogEntries;
+ LogEntries.reserve(BadKeys.size());
+
+ {
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ for (const IoHash& BadKey : BadKeys)
+ {
+ // Log a tombstone and delete the in-memory index for the bad entry
+ const auto It = m_Index.find(BadKey);
+ const BucketPayload& Payload = m_Payloads[It->second];
+ DiskLocation Location = Payload.Location;
+ Location.Flags |= DiskLocation::kTombStone;
+ LogEntries.push_back(DiskIndexEntry{.Key = BadKey, .Location = Location});
+ m_Index.erase(BadKey);
+ }
+ }
+ for (const DiskIndexEntry& Entry : LogEntries)
+ {
+ if (Entry.Location.IsFlagSet(DiskLocation::kStandaloneFile))
+ {
+ ExtendablePathBuilder<256> Path;
+ BuildPath(Path, Entry.Key);
+ fs::path FilePath = Path.ToPath();
+ RwLock::ExclusiveLockScope ValueLock(LockForHash(Entry.Key));
+ if (fs::is_regular_file(FilePath))
+ {
+ ZEN_DEBUG("deleting bad standalone cache file '{}'", Path.ToUtf8());
+ std::error_code Ec;
+ fs::remove(FilePath, Ec); // We don't care if we fail, we are no longer tracking this file...
+ }
+ m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed);
+ }
+ }
+ m_SlogFile.Append(LogEntries);
+
+ // Clean up m_AccessTimes and m_Payloads vectors
+ {
+ std::vector<BucketPayload> Payloads;
+ std::vector<AccessTime> AccessTimes;
+ IndexMap Index;
+
+ {
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ size_t EntryCount = m_Index.size();
+ Payloads.reserve(EntryCount);
+ AccessTimes.reserve(EntryCount);
+ Index.reserve(EntryCount);
+ for (auto It : m_Index)
+ {
+ size_t EntryIndex = Payloads.size();
+ Payloads.push_back(m_Payloads[EntryIndex]);
+ AccessTimes.push_back(m_AccessTimes[EntryIndex]);
+ Index.insert({It.first, EntryIndex});
+ }
+ m_Index.swap(Index);
+ m_Payloads.swap(Payloads);
+ m_AccessTimes.swap(AccessTimes);
+ }
+ }
+ }
+ }
+
+ // Let whomever it concerns know about the bad chunks. This could
+ // be used to invalidate higher level data structures more efficiently
+ // than a full validation pass might be able to do
+ Ctx.ReportBadCidChunks(BadKeys);
+
+ ZEN_INFO("cache bucket scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes));
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::GatherReferences(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::GatherReferences");
+
+ uint64_t WriteBlockTimeUs = 0;
+ uint64_t WriteBlockLongestTimeUs = 0;
+ uint64_t ReadBlockTimeUs = 0;
+ uint64_t ReadBlockLongestTimeUs = 0;
+
+ Stopwatch TotalTimer;
+ const auto _ = MakeGuard([&] {
+ ZEN_DEBUG("gathered references from '{}' in {} write lock: {} ({}), read lock: {} ({})",
+ m_BucketDir / m_BucketName,
+ NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()),
+ NiceLatencyNs(WriteBlockTimeUs),
+ NiceLatencyNs(WriteBlockLongestTimeUs),
+ NiceLatencyNs(ReadBlockTimeUs),
+ NiceLatencyNs(ReadBlockLongestTimeUs));
+ });
+
+ const GcClock::TimePoint ExpireTime = GcCtx.ExpireTime();
+
+ const GcClock::Tick ExpireTicks = ExpireTime.time_since_epoch().count();
+
+ IndexMap Index;
+ std::vector<AccessTime> AccessTimes;
+ std::vector<BucketPayload> Payloads;
+ {
+ RwLock::SharedLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ Index = m_Index;
+ AccessTimes = m_AccessTimes;
+ Payloads = m_Payloads;
+ }
+
+ std::vector<IoHash> ExpiredKeys;
+ ExpiredKeys.reserve(1024);
+
+ std::vector<IoHash> Cids;
+ Cids.reserve(1024);
+
+ for (const auto& Entry : Index)
+ {
+ const IoHash& Key = Entry.first;
+ GcClock::Tick AccessTime = AccessTimes[Entry.second];
+ if (AccessTime < ExpireTicks)
+ {
+ ExpiredKeys.push_back(Key);
+ continue;
+ }
+
+ const DiskLocation& Loc = Payloads[Entry.second].Location;
+
+ if (Loc.IsFlagSet(DiskLocation::kStructured))
+ {
+ if (Cids.size() > 1024)
+ {
+ GcCtx.AddRetainedCids(Cids);
+ Cids.clear();
+ }
+
+ IoBuffer Buffer;
+ {
+ RwLock::SharedLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ if (Loc.IsFlagSet(DiskLocation::kStandaloneFile))
+ {
+ // We don't need to hold the index lock when we read a standalone file
+ __.ReleaseNow();
+ if (Buffer = GetStandaloneCacheValue(Loc, Key); !Buffer)
+ {
+ continue;
+ }
+ }
+ else if (Buffer = GetInlineCacheValue(Loc); !Buffer)
+ {
+ continue;
+ }
+ }
+
+ ZEN_ASSERT(Buffer);
+ ZEN_ASSERT(Buffer.GetContentType() == ZenContentType::kCbObject);
+ CbObject Obj(SharedBuffer{Buffer});
+ Obj.IterateAttachments([&Cids](CbFieldView Field) { Cids.push_back(Field.AsAttachment()); });
+ }
+ }
+
+ GcCtx.AddRetainedCids(Cids);
+ GcCtx.SetExpiredCacheKeys(m_BucketDir.string(), std::move(ExpiredKeys));
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::CollectGarbage(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("Z$::DiskLayer::CacheBucket::CollectGarbage");
+
+ ZEN_DEBUG("collecting garbage from '{}'", m_BucketDir / m_BucketName);
+
+ Stopwatch TotalTimer;
+ uint64_t WriteBlockTimeUs = 0;
+ uint64_t WriteBlockLongestTimeUs = 0;
+ uint64_t ReadBlockTimeUs = 0;
+ uint64_t ReadBlockLongestTimeUs = 0;
+ uint64_t TotalChunkCount = 0;
+ uint64_t DeletedSize = 0;
+ uint64_t OldTotalSize = TotalSize();
+
+ std::unordered_set<IoHash> DeletedChunks;
+ uint64_t MovedCount = 0;
+
+ const auto _ = MakeGuard([&] {
+ ZEN_DEBUG(
+ "garbage collect from '{}' DONE after {}, write lock: {} ({}), read lock: {} ({}), collected {} bytes, deleted {} and moved "
+ "{} "
+ "of {} "
+ "entires ({}).",
+ m_BucketDir / m_BucketName,
+ NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()),
+ NiceLatencyNs(WriteBlockTimeUs),
+ NiceLatencyNs(WriteBlockLongestTimeUs),
+ NiceLatencyNs(ReadBlockTimeUs),
+ NiceLatencyNs(ReadBlockLongestTimeUs),
+ NiceBytes(DeletedSize),
+ DeletedChunks.size(),
+ MovedCount,
+ TotalChunkCount,
+ NiceBytes(OldTotalSize));
+ RwLock::SharedLockScope _(m_IndexLock);
+ SaveManifest();
+ });
+
+ m_SlogFile.Flush();
+
+ std::span<const IoHash> ExpiredCacheKeys = GcCtx.ExpiredCacheKeys(m_BucketDir.string());
+ std::vector<IoHash> DeleteCacheKeys;
+ DeleteCacheKeys.reserve(ExpiredCacheKeys.size());
+ GcCtx.FilterCids(ExpiredCacheKeys, [&](const IoHash& ChunkHash, bool Keep) {
+ if (Keep)
+ {
+ return;
+ }
+ DeleteCacheKeys.push_back(ChunkHash);
+ });
+ if (DeleteCacheKeys.empty())
+ {
+ ZEN_DEBUG("garbage collect SKIPPED, for '{}', no expired cache keys found", m_BucketDir / m_BucketName);
+ return;
+ }
+
+ auto __ = MakeGuard([&]() {
+ if (!DeletedChunks.empty())
+ {
+ // Clean up m_AccessTimes and m_Payloads vectors
+ std::vector<BucketPayload> Payloads;
+ std::vector<AccessTime> AccessTimes;
+ IndexMap Index;
+
+ {
+ RwLock::ExclusiveLockScope _(m_IndexLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ size_t EntryCount = m_Index.size();
+ Payloads.reserve(EntryCount);
+ AccessTimes.reserve(EntryCount);
+ Index.reserve(EntryCount);
+ for (auto It : m_Index)
+ {
+ size_t EntryIndex = Payloads.size();
+ Payloads.push_back(m_Payloads[EntryIndex]);
+ AccessTimes.push_back(m_AccessTimes[EntryIndex]);
+ Index.insert({It.first, EntryIndex});
+ }
+ m_Index.swap(Index);
+ m_Payloads.swap(Payloads);
+ m_AccessTimes.swap(AccessTimes);
+ }
+ GcCtx.AddDeletedCids(std::vector<IoHash>(DeletedChunks.begin(), DeletedChunks.end()));
+ }
+ });
+
+ std::vector<DiskIndexEntry> ExpiredStandaloneEntries;
+ IndexMap Index;
+ BlockStore::ReclaimSnapshotState BlockStoreState;
+ {
+ RwLock::SharedLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ if (m_Index.empty())
+ {
+ ZEN_DEBUG("garbage collect SKIPPED, for '{}', container is empty", m_BucketDir / m_BucketName);
+ return;
+ }
+ BlockStoreState = m_BlockStore.GetReclaimSnapshotState();
+
+ SaveManifest();
+ Index = m_Index;
+
+ for (const IoHash& Key : DeleteCacheKeys)
+ {
+ if (auto It = Index.find(Key); It != Index.end())
+ {
+ const BucketPayload& Payload = m_Payloads[It->second];
+ DiskIndexEntry Entry = {.Key = It->first, .Location = Payload.Location};
+ if (Entry.Location.Flags & DiskLocation::kStandaloneFile)
+ {
+ Entry.Location.Flags |= DiskLocation::kTombStone;
+ ExpiredStandaloneEntries.push_back(Entry);
+ }
+ }
+ }
+ if (GcCtx.IsDeletionMode())
+ {
+ for (const auto& Entry : ExpiredStandaloneEntries)
+ {
+ m_Index.erase(Entry.Key);
+ m_TotalStandaloneSize.fetch_sub(Entry.Location.Size(), std::memory_order::relaxed);
+ DeletedChunks.insert(Entry.Key);
+ }
+ m_SlogFile.Append(ExpiredStandaloneEntries);
+ }
+ }
+
+ if (GcCtx.IsDeletionMode())
+ {
+ std::error_code Ec;
+ ExtendablePathBuilder<256> Path;
+
+ for (const auto& Entry : ExpiredStandaloneEntries)
+ {
+ const IoHash& Key = Entry.Key;
+ const DiskLocation& Loc = Entry.Location;
+
+ Path.Reset();
+ BuildPath(Path, Key);
+ fs::path FilePath = Path.ToPath();
+
+ {
+ RwLock::SharedLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ if (m_Index.contains(Key))
+ {
+ // Someone added it back, let the file on disk be
+ ZEN_DEBUG("skipping z$ delete standalone of file '{}' FAILED, it has been added back", Path.ToUtf8());
+ continue;
+ }
+ __.ReleaseNow();
+
+ RwLock::ExclusiveLockScope ValueLock(LockForHash(Key));
+ if (fs::is_regular_file(FilePath))
+ {
+ ZEN_DEBUG("deleting standalone cache file '{}'", Path.ToUtf8());
+ fs::remove(FilePath, Ec);
+ }
+ }
+
+ if (Ec)
+ {
+ ZEN_WARN("delete expired z$ standalone file '{}' FAILED, reason: '{}'", Path.ToUtf8(), Ec.message());
+ Ec.clear();
+ DiskLocation RestoreLocation = Loc;
+ RestoreLocation.Flags &= ~DiskLocation::kTombStone;
+
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ if (m_Index.contains(Key))
+ {
+ continue;
+ }
+ m_SlogFile.Append(DiskIndexEntry{.Key = Key, .Location = RestoreLocation});
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = RestoreLocation});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert({Key, EntryIndex});
+ m_TotalStandaloneSize.fetch_add(RestoreLocation.Size(), std::memory_order::relaxed);
+ DeletedChunks.erase(Key);
+ continue;
+ }
+ DeletedSize += Entry.Location.Size();
+ }
+ }
+
+ TotalChunkCount = Index.size();
+
+ std::vector<IoHash> TotalChunkHashes;
+ TotalChunkHashes.reserve(TotalChunkCount);
+ for (const auto& Entry : Index)
+ {
+ const DiskLocation& Location = m_Payloads[Entry.second].Location;
+
+ if (Location.Flags & DiskLocation::kStandaloneFile)
+ {
+ continue;
+ }
+ TotalChunkHashes.push_back(Entry.first);
+ }
+
+ if (TotalChunkHashes.empty())
+ {
+ return;
+ }
+ TotalChunkCount = TotalChunkHashes.size();
+
+ std::vector<BlockStoreLocation> ChunkLocations;
+ BlockStore::ChunkIndexArray KeepChunkIndexes;
+ std::vector<IoHash> ChunkIndexToChunkHash;
+ ChunkLocations.reserve(TotalChunkCount);
+ ChunkLocations.reserve(TotalChunkCount);
+ ChunkIndexToChunkHash.reserve(TotalChunkCount);
+
+ GcCtx.FilterCids(TotalChunkHashes, [&](const IoHash& ChunkHash, bool Keep) {
+ auto KeyIt = Index.find(ChunkHash);
+ const DiskLocation& DiskLocation = m_Payloads[KeyIt->second].Location;
+ BlockStoreLocation Location = DiskLocation.GetBlockLocation(m_PayloadAlignment);
+ size_t ChunkIndex = ChunkLocations.size();
+ ChunkLocations.push_back(Location);
+ ChunkIndexToChunkHash[ChunkIndex] = ChunkHash;
+ if (Keep)
+ {
+ KeepChunkIndexes.push_back(ChunkIndex);
+ }
+ });
+
+ size_t DeleteCount = TotalChunkCount - KeepChunkIndexes.size();
+
+ const bool PerformDelete = GcCtx.IsDeletionMode() && GcCtx.CollectSmallObjects();
+ if (!PerformDelete)
+ {
+ m_BlockStore.ReclaimSpace(BlockStoreState, ChunkLocations, KeepChunkIndexes, m_PayloadAlignment, true);
+ uint64_t CurrentTotalSize = TotalSize();
+ ZEN_DEBUG("garbage collect from '{}' DISABLED, found {} chunks of total {} {}",
+ m_BucketDir / m_BucketName,
+ DeleteCount,
+ TotalChunkCount,
+ NiceBytes(CurrentTotalSize));
+ return;
+ }
+
+ m_BlockStore.ReclaimSpace(
+ BlockStoreState,
+ ChunkLocations,
+ KeepChunkIndexes,
+ m_PayloadAlignment,
+ false,
+ [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& RemovedChunks) {
+ std::vector<DiskIndexEntry> LogEntries;
+ LogEntries.reserve(MovedChunks.size() + RemovedChunks.size());
+ for (const auto& Entry : MovedChunks)
+ {
+ size_t ChunkIndex = Entry.first;
+ const BlockStoreLocation& NewLocation = Entry.second;
+ const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex];
+ const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]];
+ const DiskLocation& OldDiskLocation = OldPayload.Location;
+ LogEntries.push_back(
+ {.Key = ChunkHash, .Location = DiskLocation(NewLocation, m_PayloadAlignment, OldDiskLocation.GetFlags())});
+ }
+ for (const size_t ChunkIndex : RemovedChunks)
+ {
+ const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex];
+ const BucketPayload& OldPayload = m_Payloads[Index[ChunkHash]];
+ const DiskLocation& OldDiskLocation = OldPayload.Location;
+ LogEntries.push_back({.Key = ChunkHash,
+ .Location = DiskLocation(OldDiskLocation.GetBlockLocation(m_PayloadAlignment),
+ m_PayloadAlignment,
+ OldDiskLocation.GetFlags() | DiskLocation::kTombStone)});
+ DeletedChunks.insert(ChunkHash);
+ }
+
+ m_SlogFile.Append(LogEntries);
+ m_SlogFile.Flush();
+ {
+ RwLock::ExclusiveLockScope __(m_IndexLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ for (const DiskIndexEntry& Entry : LogEntries)
+ {
+ if (Entry.Location.GetFlags() & DiskLocation::kTombStone)
+ {
+ m_Index.erase(Entry.Key);
+ continue;
+ }
+ m_Payloads[m_Index[Entry.Key]].Location = Entry.Location;
+ }
+ }
+ },
+ [&]() { return GcCtx.CollectSmallObjects(); });
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes)
+{
+ using namespace access_tracking;
+
+ for (const KeyAccessTime& KeyTime : AccessTimes)
+ {
+ if (auto It = m_Index.find(KeyTime.Key); It != m_Index.end())
+ {
+ size_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
+ m_AccessTimes[EntryIndex] = KeyTime.LastAccess;
+ }
+ }
+}
+
+uint64_t
+ZenCacheDiskLayer::CacheBucket::EntryCount() const
+{
+ RwLock::SharedLockScope _(m_IndexLock);
+ return static_cast<uint64_t>(m_Index.size());
+}
+
+CacheValueDetails::ValueDetails
+ZenCacheDiskLayer::CacheBucket::GetValueDetails(const IoHash& Key, size_t Index) const
+{
+ std::vector<IoHash> Attachments;
+ const BucketPayload& Payload = m_Payloads[Index];
+ if (Payload.Location.IsFlagSet(DiskLocation::kStructured))
+ {
+ IoBuffer Value = Payload.Location.IsFlagSet(DiskLocation::kStandaloneFile) ? GetStandaloneCacheValue(Payload.Location, Key)
+ : GetInlineCacheValue(Payload.Location);
+ CbObject Obj(SharedBuffer{Value});
+ Obj.IterateAttachments([&Attachments](CbFieldView Field) { Attachments.emplace_back(Field.AsAttachment()); });
+ }
+ return CacheValueDetails::ValueDetails{.Size = Payload.Location.Size(),
+ .RawSize = Payload.RawSize,
+ .RawHash = Payload.RawHash,
+ .LastAccess = m_AccessTimes[Index],
+ .Attachments = std::move(Attachments),
+ .ContentType = Payload.Location.GetContentType()};
+}
+
+CacheValueDetails::BucketDetails
+ZenCacheDiskLayer::CacheBucket::GetValueDetails(const std::string_view ValueFilter) const
+{
+ CacheValueDetails::BucketDetails Details;
+ RwLock::SharedLockScope _(m_IndexLock);
+ if (ValueFilter.empty())
+ {
+ Details.Values.reserve(m_Index.size());
+ for (const auto& It : m_Index)
+ {
+ Details.Values.insert_or_assign(It.first, GetValueDetails(It.first, It.second));
+ }
+ }
+ else
+ {
+ IoHash Key = IoHash::FromHexString(ValueFilter);
+ if (auto It = m_Index.find(Key); It != m_Index.end())
+ {
+ Details.Values.insert_or_assign(It->first, GetValueDetails(It->first, It->second));
+ }
+ }
+ return Details;
+}
+
+void
+ZenCacheDiskLayer::CollectGarbage(GcContext& GcCtx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket& Bucket = *Kv.second;
+ Bucket.CollectGarbage(GcCtx);
+ }
+}
+
+void
+ZenCacheDiskLayer::UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (const auto& Kv : AccessTimes.Buckets)
+ {
+ if (auto It = m_Buckets.find(Kv.first); It != m_Buckets.end())
+ {
+ CacheBucket& Bucket = *It->second;
+ Bucket.UpdateAccessTimes(Kv.second);
+ }
+ }
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ uint64_t NewFileSize = Value.Value.Size();
+
+ TemporaryFile DataFile;
+
+ std::error_code Ec;
+ DataFile.CreateTemporary(m_BucketDir.c_str(), Ec);
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to open temporary file for put in '{}'", m_BucketDir));
+ }
+
+ bool CleanUpTempFile = false;
+ auto __ = MakeGuard([&] {
+ if (CleanUpTempFile)
+ {
+ std::error_code Ec;
+ std::filesystem::remove(DataFile.GetPath(), Ec);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to clean up temporary file '{}' for put in '{}', reason '{}'",
+ DataFile.GetPath(),
+ m_BucketDir,
+ Ec.message());
+ }
+ }
+ });
+
+ DataFile.WriteAll(Value.Value, Ec);
+ if (Ec)
+ {
+ throw std::system_error(Ec,
+ fmt::format("Failed to write payload ({} bytes) to temporary file '{}' for put in '{}'",
+ NiceBytes(NewFileSize),
+ DataFile.GetPath().string(),
+ m_BucketDir));
+ }
+
+ ExtendablePathBuilder<256> DataFilePath;
+ BuildPath(DataFilePath, HashKey);
+ std::filesystem::path FsPath{DataFilePath.ToPath()};
+
+ RwLock::ExclusiveLockScope ValueLock(LockForHash(HashKey));
+
+ // We do a speculative remove of the file instead of probing with a exists call and check the error code instead
+ std::filesystem::remove(FsPath, Ec);
+ if (Ec)
+ {
+ if (Ec.value() != ENOENT)
+ {
+ ZEN_WARN("Failed to remove file '{}' for put in '{}', reason: '{}', retrying.", FsPath, m_BucketDir, Ec.message());
+ Sleep(100);
+ Ec.clear();
+ std::filesystem::remove(FsPath, Ec);
+ if (Ec && Ec.value() != ENOENT)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to remove file '{}' for put in '{}'", FsPath, m_BucketDir));
+ }
+ }
+ }
+
+ DataFile.MoveTemporaryIntoPlace(FsPath, Ec);
+ if (Ec)
+ {
+ CreateDirectories(FsPath.parent_path());
+ Ec.clear();
+
+ // Try again
+ DataFile.MoveTemporaryIntoPlace(FsPath, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to finalize file '{}', moving from '{}' for put in '{}', reason: '{}', retrying.",
+ FsPath,
+ DataFile.GetPath(),
+ m_BucketDir,
+ Ec.message());
+ Sleep(100);
+ Ec.clear();
+ DataFile.MoveTemporaryIntoPlace(FsPath, Ec);
+ if (Ec)
+ {
+ throw std::system_error(
+ Ec,
+ fmt::format("Failed to finalize file '{}', moving from '{}' for put in '{}'", FsPath, DataFile.GetPath(), m_BucketDir));
+ }
+ }
+ }
+
+ // Once we have called MoveTemporaryIntoPlace automatic clean up the temp file
+ // will be disabled as the file handle has already been closed
+ CleanUpTempFile = false;
+
+ uint8_t EntryFlags = DiskLocation::kStandaloneFile;
+
+ if (Value.Value.GetContentType() == ZenContentType::kCbObject)
+ {
+ EntryFlags |= DiskLocation::kStructured;
+ }
+ else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary)
+ {
+ EntryFlags |= DiskLocation::kCompressed;
+ }
+
+ DiskLocation Loc(NewFileSize, EntryFlags);
+
+ RwLock::ExclusiveLockScope _(m_IndexLock);
+ if (auto It = m_Index.find(HashKey); It == m_Index.end())
+ {
+ // Previously unknown object
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert_or_assign(HashKey, EntryIndex);
+ }
+ else
+ {
+ // TODO: should check if write is idempotent and bail out if it is?
+ size_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
+ m_Payloads[EntryIndex] = BucketPayload{.Location = Loc, .RawSize = Value.RawSize, .RawHash = Value.RawHash};
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_TotalStandaloneSize.fetch_sub(Loc.Size(), std::memory_order::relaxed);
+ }
+
+ m_SlogFile.Append({.Key = HashKey, .Location = Loc});
+ m_TotalStandaloneSize.fetch_add(NewFileSize, std::memory_order::relaxed);
+}
+
+void
+ZenCacheDiskLayer::CacheBucket::PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ uint8_t EntryFlags = 0;
+
+ if (Value.Value.GetContentType() == ZenContentType::kCbObject)
+ {
+ EntryFlags |= DiskLocation::kStructured;
+ }
+ else if (Value.Value.GetContentType() == ZenContentType::kCompressedBinary)
+ {
+ EntryFlags |= DiskLocation::kCompressed;
+ }
+
+ m_BlockStore.WriteChunk(Value.Value.Data(), Value.Value.Size(), m_PayloadAlignment, [&](const BlockStoreLocation& BlockStoreLocation) {
+ DiskLocation Location(BlockStoreLocation, m_PayloadAlignment, EntryFlags);
+ m_SlogFile.Append({.Key = HashKey, .Location = Location});
+
+ RwLock::ExclusiveLockScope _(m_IndexLock);
+ if (auto It = m_Index.find(HashKey); It != m_Index.end())
+ {
+ // TODO: should check if write is idempotent and bail out if it is?
+ // this would requiring comparing contents on disk unless we add a
+ // content hash to the index entry
+ size_t EntryIndex = It.value();
+ ZEN_ASSERT_SLOW(EntryIndex < m_AccessTimes.size());
+ m_Payloads[EntryIndex] = (BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash});
+ m_AccessTimes[EntryIndex] = GcClock::TickCount();
+ }
+ else
+ {
+ size_t EntryIndex = m_Payloads.size();
+ m_Payloads.emplace_back(BucketPayload{.Location = Location, .RawSize = Value.RawSize, .RawHash = Value.RawHash});
+ m_AccessTimes.emplace_back(GcClock::TickCount());
+ m_Index.insert_or_assign(HashKey, EntryIndex);
+ }
+ });
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ZenCacheDiskLayer::ZenCacheDiskLayer(const std::filesystem::path& RootDir) : m_RootDir(RootDir)
+{
+}
+
+ZenCacheDiskLayer::~ZenCacheDiskLayer() = default;
+
+bool
+ZenCacheDiskLayer::Get(std::string_view InBucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ const auto BucketName = std::string(InBucket);
+ CacheBucket* Bucket = nullptr;
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(BucketName);
+
+ if (It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ }
+
+ if (Bucket == nullptr)
+ {
+ // Bucket needs to be opened/created
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ else
+ {
+ auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName));
+ Bucket = InsertResult.first->second.get();
+
+ std::filesystem::path BucketPath = m_RootDir;
+ BucketPath /= BucketName;
+
+ if (!Bucket->OpenOrCreate(BucketPath))
+ {
+ m_Buckets.erase(InsertResult.first);
+ return false;
+ }
+ }
+ }
+
+ ZEN_ASSERT(Bucket != nullptr);
+ return Bucket->Get(HashKey, OutValue);
+}
+
+void
+ZenCacheDiskLayer::Put(std::string_view InBucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ const auto BucketName = std::string(InBucket);
+ CacheBucket* Bucket = nullptr;
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(BucketName);
+
+ if (It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ }
+
+ if (Bucket == nullptr)
+ {
+ // New bucket needs to be created
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end())
+ {
+ Bucket = It->second.get();
+ }
+ else
+ {
+ auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName));
+ Bucket = InsertResult.first->second.get();
+
+ std::filesystem::path BucketPath = m_RootDir;
+ BucketPath /= BucketName;
+
+ try
+ {
+ if (!Bucket->OpenOrCreate(BucketPath))
+ {
+ ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir);
+ m_Buckets.erase(InsertResult.first);
+ return;
+ }
+ }
+ catch (const std::exception& Err)
+ {
+ ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what());
+ return;
+ }
+ }
+ }
+
+ ZEN_ASSERT(Bucket != nullptr);
+
+ Bucket->Put(HashKey, Value);
+}
+
+void
+ZenCacheDiskLayer::DiscoverBuckets()
+{
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_RootDir, DirectoryContent::IncludeDirsFlag, DirContent);
+
+ // Initialize buckets
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ for (const std::filesystem::path& BucketPath : DirContent.Directories)
+ {
+ const std::string BucketName = PathToUtf8(BucketPath.stem());
+ // New bucket needs to be created
+ if (auto It = m_Buckets.find(BucketName); It != m_Buckets.end())
+ {
+ continue;
+ }
+
+ auto InsertResult = m_Buckets.emplace(BucketName, std::make_unique<CacheBucket>(BucketName));
+ CacheBucket& Bucket = *InsertResult.first->second;
+
+ try
+ {
+ if (!Bucket.OpenOrCreate(BucketPath, /* AllowCreate */ false))
+ {
+ ZEN_WARN("Found directory '{}' in our base directory '{}' but it is not a valid bucket", BucketName, m_RootDir);
+
+ m_Buckets.erase(InsertResult.first);
+ continue;
+ }
+ }
+ catch (const std::exception& Err)
+ {
+ ZEN_ERROR("creating bucket '{}' in '{}' FAILED, reason: '{}'", BucketName, BucketPath, Err.what());
+ return;
+ }
+ ZEN_INFO("Discovered bucket '{}'", BucketName);
+ }
+}
+
+bool
+ZenCacheDiskLayer::DropBucket(std::string_view InBucket)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ auto It = m_Buckets.find(std::string(InBucket));
+
+ if (It != m_Buckets.end())
+ {
+ CacheBucket& Bucket = *It->second;
+ m_DroppedBuckets.push_back(std::move(It->second));
+ m_Buckets.erase(It);
+
+ return Bucket.Drop();
+ }
+
+ // Make sure we remove the folder even if we don't know about the bucket
+ std::filesystem::path BucketPath = m_RootDir;
+ BucketPath /= std::string(InBucket);
+ return MoveAndDeleteDirectory(BucketPath);
+}
+
+bool
+ZenCacheDiskLayer::Drop()
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+
+ std::vector<std::unique_ptr<CacheBucket>> Buckets;
+ Buckets.reserve(m_Buckets.size());
+ while (!m_Buckets.empty())
+ {
+ const auto& It = m_Buckets.begin();
+ CacheBucket& Bucket = *It->second;
+ m_DroppedBuckets.push_back(std::move(It->second));
+ m_Buckets.erase(It->first);
+ if (!Bucket.Drop())
+ {
+ return false;
+ }
+ }
+ return MoveAndDeleteDirectory(m_RootDir);
+}
+
+void
+ZenCacheDiskLayer::Flush()
+{
+ std::vector<CacheBucket*> Buckets;
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ Buckets.reserve(m_Buckets.size());
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket* Bucket = Kv.second.get();
+ Buckets.push_back(Bucket);
+ }
+ }
+
+ for (auto& Bucket : Buckets)
+ {
+ Bucket->Flush();
+ }
+}
+
+void
+ZenCacheDiskLayer::Scrub(ScrubContext& Ctx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket& Bucket = *Kv.second;
+ Bucket.Scrub(Ctx);
+ }
+}
+
+void
+ZenCacheDiskLayer::GatherReferences(GcContext& GcCtx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ CacheBucket& Bucket = *Kv.second;
+ Bucket.GatherReferences(GcCtx);
+ }
+}
+
+uint64_t
+ZenCacheDiskLayer::TotalSize() const
+{
+ uint64_t TotalSize{};
+ RwLock::SharedLockScope _(m_Lock);
+
+ for (auto& Kv : m_Buckets)
+ {
+ TotalSize += Kv.second->TotalSize();
+ }
+
+ return TotalSize;
+}
+
+ZenCacheDiskLayer::Info
+ZenCacheDiskLayer::GetInfo() const
+{
+ ZenCacheDiskLayer::Info Info = {.Config = {.RootDir = m_RootDir}, .TotalSize = TotalSize()};
+
+ RwLock::SharedLockScope _(m_Lock);
+ Info.BucketNames.reserve(m_Buckets.size());
+ for (auto& Kv : m_Buckets)
+ {
+ Info.BucketNames.push_back(Kv.first);
+ Info.EntryCount += Kv.second->EntryCount();
+ }
+ return Info;
+}
+
+std::optional<ZenCacheDiskLayer::BucketInfo>
+ZenCacheDiskLayer::GetBucketInfo(std::string_view Bucket) const
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ if (auto It = m_Buckets.find(std::string(Bucket)); It != m_Buckets.end())
+ {
+ return ZenCacheDiskLayer::BucketInfo{.EntryCount = It->second->EntryCount(), .TotalSize = It->second->TotalSize()};
+ }
+ return {};
+}
+
+CacheValueDetails::NamespaceDetails
+ZenCacheDiskLayer::GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const
+{
+ RwLock::SharedLockScope _(m_Lock);
+ CacheValueDetails::NamespaceDetails Details;
+ if (BucketFilter.empty())
+ {
+ Details.Buckets.reserve(BucketFilter.empty() ? m_Buckets.size() : 1);
+ for (auto& Kv : m_Buckets)
+ {
+ Details.Buckets[Kv.first] = Kv.second->GetValueDetails(ValueFilter);
+ }
+ }
+ else if (auto It = m_Buckets.find(std::string(BucketFilter)); It != m_Buckets.end())
+ {
+ Details.Buckets[It->first] = It->second->GetValueDetails(ValueFilter);
+ }
+ return Details;
+}
+
+//////////////////////////// ZenCacheStore
+
+static constexpr std::string_view UE4DDCNamespaceName = "ue4.ddc";
+
+ZenCacheStore::ZenCacheStore(GcManager& Gc, const Configuration& Configuration) : m_Gc(Gc), m_Configuration(Configuration)
+{
+ CreateDirectories(m_Configuration.BasePath);
+
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_Configuration.BasePath, DirectoryContent::IncludeDirsFlag, DirContent);
+
+ std::vector<std::string> Namespaces;
+ for (const std::filesystem::path& DirPath : DirContent.Directories)
+ {
+ std::string DirName = PathToUtf8(DirPath.filename());
+ if (DirName.starts_with(NamespaceDiskPrefix))
+ {
+ Namespaces.push_back(DirName.substr(NamespaceDiskPrefix.length()));
+ continue;
+ }
+ }
+
+ ZEN_INFO("Found {} namespaces in '{}'", Namespaces.size(), m_Configuration.BasePath);
+
+ if (std::find(Namespaces.begin(), Namespaces.end(), UE4DDCNamespaceName) == Namespaces.end())
+ {
+ // default (unspecified) and ue4-ddc namespace points to the same namespace instance
+
+ std::filesystem::path DefaultNamespaceFolder =
+ m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, UE4DDCNamespaceName);
+ CreateDirectories(DefaultNamespaceFolder);
+ Namespaces.push_back(std::string(UE4DDCNamespaceName));
+ }
+
+ for (const std::string& NamespaceName : Namespaces)
+ {
+ m_Namespaces[NamespaceName] =
+ std::make_unique<ZenCacheNamespace>(Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, NamespaceName));
+ }
+}
+
+ZenCacheStore::~ZenCacheStore()
+{
+ m_Namespaces.clear();
+}
+
+bool
+ZenCacheStore::Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue)
+{
+ if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store)
+ {
+ return Store->Get(Bucket, HashKey, OutValue);
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Get, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString());
+
+ return false;
+}
+
+void
+ZenCacheStore::Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value)
+{
+ if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store)
+ {
+ return Store->Put(Bucket, HashKey, Value);
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::Put, bucket '{}', key '{}'", Namespace, Bucket, HashKey.ToHexString());
+}
+
+bool
+ZenCacheStore::DropBucket(std::string_view Namespace, std::string_view Bucket)
+{
+ if (ZenCacheNamespace* Store = GetNamespace(Namespace); Store)
+ {
+ return Store->DropBucket(Bucket);
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropBucket, bucket '{}'", Namespace, Bucket);
+ return false;
+}
+
+bool
+ZenCacheStore::DropNamespace(std::string_view InNamespace)
+{
+ RwLock::SharedLockScope _(m_NamespacesLock);
+ if (auto It = m_Namespaces.find(std::string(InNamespace)); It != m_Namespaces.end())
+ {
+ ZenCacheNamespace& Namespace = *It->second;
+ m_DroppedNamespaces.push_back(std::move(It->second));
+ m_Namespaces.erase(It);
+ return Namespace.Drop();
+ }
+ ZEN_WARN("request for unknown namespace '{}' in ZenCacheStore::DropNamespace", InNamespace);
+ return false;
+}
+
+void
+ZenCacheStore::Flush()
+{
+ IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Flush(); });
+}
+
+void
+ZenCacheStore::Scrub(ScrubContext& Ctx)
+{
+ IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) { Store.Scrub(Ctx); });
+}
+
+CacheValueDetails
+ZenCacheStore::GetValueDetails(const std::string_view NamespaceFilter,
+ const std::string_view BucketFilter,
+ const std::string_view ValueFilter) const
+{
+ CacheValueDetails Details;
+ if (NamespaceFilter.empty())
+ {
+ IterateNamespaces([&](std::string_view Namespace, ZenCacheNamespace& Store) {
+ Details.Namespaces[std::string(Namespace)] = Store.GetValueDetails(BucketFilter, ValueFilter);
+ });
+ }
+ else if (const ZenCacheNamespace* Store = FindNamespace(NamespaceFilter); Store != nullptr)
+ {
+ Details.Namespaces[std::string(NamespaceFilter)] = Store->GetValueDetails(BucketFilter, ValueFilter);
+ }
+ return Details;
+}
+
+ZenCacheNamespace*
+ZenCacheStore::GetNamespace(std::string_view Namespace)
+{
+ RwLock::SharedLockScope _(m_NamespacesLock);
+ if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end())
+ {
+ return It->second.get();
+ }
+ if (Namespace == DefaultNamespace)
+ {
+ if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end())
+ {
+ return It->second.get();
+ }
+ }
+ _.ReleaseNow();
+
+ if (!m_Configuration.AllowAutomaticCreationOfNamespaces)
+ {
+ return nullptr;
+ }
+
+ RwLock::ExclusiveLockScope __(m_NamespacesLock);
+ if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end())
+ {
+ return It->second.get();
+ }
+
+ auto NewNamespace = m_Namespaces.insert_or_assign(
+ std::string(Namespace),
+ std::make_unique<ZenCacheNamespace>(m_Gc, m_Configuration.BasePath / fmt::format("{}{}", NamespaceDiskPrefix, Namespace)));
+ return NewNamespace.first->second.get();
+}
+
+const ZenCacheNamespace*
+ZenCacheStore::FindNamespace(std::string_view Namespace) const
+{
+ RwLock::SharedLockScope _(m_NamespacesLock);
+ if (auto It = m_Namespaces.find(std::string(Namespace)); It != m_Namespaces.end())
+ {
+ return It->second.get();
+ }
+ if (Namespace == DefaultNamespace)
+ {
+ if (auto It = m_Namespaces.find(std::string(UE4DDCNamespaceName)); It != m_Namespaces.end())
+ {
+ return It->second.get();
+ }
+ }
+ return nullptr;
+}
+
+void
+ZenCacheStore::IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const
+{
+ std::vector<std::pair<std::string, ZenCacheNamespace&>> Namespaces;
+ {
+ RwLock::SharedLockScope _(m_NamespacesLock);
+ Namespaces.reserve(m_Namespaces.size());
+ for (const auto& Entry : m_Namespaces)
+ {
+ if (Entry.first == DefaultNamespace)
+ {
+ continue;
+ }
+ Namespaces.push_back({Entry.first, *Entry.second});
+ }
+ }
+ for (auto& Entry : Namespaces)
+ {
+ Callback(Entry.first, Entry.second);
+ }
+}
+
+GcStorageSize
+ZenCacheStore::StorageSize() const
+{
+ GcStorageSize Size;
+ IterateNamespaces([&](std::string_view, ZenCacheNamespace& Store) {
+ GcStorageSize StoreSize = Store.StorageSize();
+ Size.MemorySize += StoreSize.MemorySize;
+ Size.DiskSize += StoreSize.DiskSize;
+ });
+ return Size;
+}
+
+ZenCacheStore::Info
+ZenCacheStore::GetInfo() const
+{
+ ZenCacheStore::Info Info = {.Config = m_Configuration, .StorageSize = StorageSize()};
+
+ IterateNamespaces([&Info](std::string_view NamespaceName, ZenCacheNamespace& Namespace) {
+ Info.NamespaceNames.push_back(std::string(NamespaceName));
+ ZenCacheNamespace::Info NamespaceInfo = Namespace.GetInfo();
+ Info.DiskEntryCount += NamespaceInfo.DiskLayerInfo.EntryCount;
+ Info.MemoryEntryCount += NamespaceInfo.MemoryLayerInfo.EntryCount;
+ });
+
+ return Info;
+}
+
+std::optional<ZenCacheNamespace::Info>
+ZenCacheStore::GetNamespaceInfo(std::string_view NamespaceName)
+{
+ if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace)
+ {
+ return Namespace->GetInfo();
+ }
+ return {};
+}
+
+std::optional<ZenCacheNamespace::BucketInfo>
+ZenCacheStore::GetBucketInfo(std::string_view NamespaceName, std::string_view BucketName)
+{
+ if (const ZenCacheNamespace* Namespace = FindNamespace(NamespaceName); Namespace)
+ {
+ return Namespace->GetBucketInfo(BucketName);
+ }
+ return {};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+using namespace std::literals;
+
+namespace testutils {
+ IoHash CreateKey(size_t KeyValue) { return IoHash::HashBuffer(&KeyValue, sizeof(size_t)); }
+
+ IoBuffer CreateBinaryCacheValue(uint64_t Size)
+ {
+ static std::random_device rd;
+ static std::mt19937 g(rd());
+
+ std::vector<uint8_t> Values;
+ Values.resize(Size);
+ for (size_t Idx = 0; Idx < Size; ++Idx)
+ {
+ Values[Idx] = static_cast<uint8_t>(Idx);
+ }
+ std::shuffle(Values.begin(), Values.end(), g);
+
+ IoBuffer Buf(IoBuffer::Clone, Values.data(), Values.size());
+ Buf.SetContentType(ZenContentType::kBinary);
+ return Buf;
+ };
+
+} // namespace testutils
+
+TEST_CASE("z$.store")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ const int kIterationCount = 100;
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ const IoHash Key = IoHash::HashBuffer(&i, sizeof i);
+
+ CbObjectWriter Cbo;
+ Cbo << "hey" << i;
+ CbObject Obj = Cbo.Save();
+
+ ZenCacheValue Value;
+ Value.Value = Obj.GetBuffer().AsIoBuffer();
+ Value.Value.SetContentType(ZenContentType::kCbObject);
+
+ Zcs.Put("test_bucket"sv, Key, Value);
+ }
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ const IoHash Key = IoHash::HashBuffer(&i, sizeof i);
+
+ ZenCacheValue Value;
+ Zcs.Get("test_bucket"sv, Key, /* out */ Value);
+
+ REQUIRE(Value.Value);
+ CHECK(Value.Value.GetContentType() == ZenContentType::kCbObject);
+ CHECK_EQ(ValidateCompactBinary(Value.Value, CbValidateMode::All), CbValidateError::None);
+ CbObject Obj = LoadCompactBinaryObject(Value.Value);
+ CHECK_EQ(Obj["hey"].AsInt32(), i);
+ }
+}
+
+TEST_CASE("z$.size")
+{
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ SUBCASE("mem/disklayer")
+ {
+ const size_t Count = 16;
+ ScopedTemporaryDirectory TempDir;
+
+ GcStorageSize CacheSize;
+
+ {
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ CbObject CacheValue = CreateCacheValue(Zcs.DiskLayerThreshold() - 256);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ for (size_t Key = 0; Key < Count; ++Key)
+ {
+ const size_t Bucket = Key % 4;
+ Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), ZenCacheValue{.Value = Buffer});
+ }
+
+ CacheSize = Zcs.StorageSize();
+ CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize);
+ CHECK_LE(CacheValue.GetSize() * Count, CacheSize.MemorySize);
+ }
+
+ {
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ const GcStorageSize SerializedSize = Zcs.StorageSize();
+ CHECK_EQ(SerializedSize.MemorySize, 0);
+ CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize);
+
+ for (size_t Bucket = 0; Bucket < 4; ++Bucket)
+ {
+ Zcs.DropBucket(fmt::format("test_bucket-{}", Bucket));
+ }
+ CHECK_EQ(0, Zcs.StorageSize().DiskSize);
+ }
+ }
+
+ SUBCASE("disklayer")
+ {
+ const size_t Count = 16;
+ ScopedTemporaryDirectory TempDir;
+
+ GcStorageSize CacheSize;
+
+ {
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ CbObject CacheValue = CreateCacheValue(Zcs.DiskLayerThreshold() + 64);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ for (size_t Key = 0; Key < Count; ++Key)
+ {
+ const size_t Bucket = Key % 4;
+ Zcs.Put(fmt::format("test_bucket-{}", Bucket), IoHash::HashBuffer(&Key, sizeof(uint32_t)), {.Value = Buffer});
+ }
+
+ CacheSize = Zcs.StorageSize();
+ CHECK_LE(CacheValue.GetSize() * Count, CacheSize.DiskSize);
+ CHECK_EQ(0, CacheSize.MemorySize);
+ }
+
+ {
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ const GcStorageSize SerializedSize = Zcs.StorageSize();
+ CHECK_EQ(SerializedSize.MemorySize, 0);
+ CHECK_LE(SerializedSize.DiskSize, CacheSize.DiskSize);
+
+ for (size_t Bucket = 0; Bucket < 4; ++Bucket)
+ {
+ Zcs.DropBucket(fmt::format("test_bucket-{}", Bucket));
+ }
+ CHECK_EQ(0, Zcs.StorageSize().DiskSize);
+ }
+ }
+}
+
+TEST_CASE("z$.gc")
+{
+ using namespace testutils;
+
+ SUBCASE("gather references does NOT add references for expired cache entries")
+ {
+ ScopedTemporaryDirectory TempDir;
+ std::vector<IoHash> Cids{CreateKey(1), CreateKey(2), CreateKey(3)};
+
+ const auto CollectAndFilter = [](GcManager& Gc,
+ GcClock::TimePoint Time,
+ GcClock::Duration MaxDuration,
+ std::span<const IoHash> Cids,
+ std::vector<IoHash>& OutKeep) {
+ GcContext GcCtx(Time - MaxDuration);
+ Gc.CollectGarbage(GcCtx);
+ OutKeep.clear();
+ GcCtx.FilterCids(Cids, [&OutKeep](const IoHash& Hash) { OutKeep.push_back(Hash); });
+ };
+
+ {
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+ const auto Bucket = "teardrinker"sv;
+
+ // Create a cache record
+ const IoHash Key = CreateKey(42);
+ CbObjectWriter Record;
+ Record << "Key"sv
+ << "SomeRecord"sv;
+
+ for (size_t Idx = 0; auto& Cid : Cids)
+ {
+ Record.AddBinaryAttachment(fmt::format("attachment-{}", Idx++), Cid);
+ }
+
+ IoBuffer Buffer = Record.Save().GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ Zcs.Put(Bucket, Key, {.Value = Buffer});
+
+ std::vector<IoHash> Keep;
+
+ // Collect garbage with 1 hour max cache duration
+ {
+ CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep);
+ CHECK_EQ(Cids.size(), Keep.size());
+ }
+
+ // Move forward in time
+ {
+ CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), Cids, Keep);
+ CHECK_EQ(0, Keep.size());
+ }
+ }
+
+ // Expect timestamps to be serialized
+ {
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+ std::vector<IoHash> Keep;
+
+ // Collect garbage with 1 hour max cache duration
+ {
+ CollectAndFilter(Gc, GcClock::Now(), std::chrono::hours(1), Cids, Keep);
+ CHECK_EQ(3, Keep.size());
+ }
+
+ // Move forward in time
+ {
+ CollectAndFilter(Gc, GcClock::Now() + std::chrono::hours(2), std::chrono::hours(1), Cids, Keep);
+ CHECK_EQ(0, Keep.size());
+ }
+ }
+ }
+
+ SUBCASE("gc removes standalone values")
+ {
+ ScopedTemporaryDirectory TempDir;
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+ const auto Bucket = "fortysixandtwo"sv;
+ const GcClock::TimePoint CurrentTime = GcClock::Now();
+
+ std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)};
+
+ for (const auto& Key : Keys)
+ {
+ IoBuffer Value = testutils::CreateBinaryCacheValue(128 << 10);
+ Zcs.Put(Bucket, Key, {.Value = Value});
+ }
+
+ {
+ GcContext GcCtx(CurrentTime - std::chrono::hours(46));
+
+ Gc.CollectGarbage(GcCtx);
+
+ for (const auto& Key : Keys)
+ {
+ ZenCacheValue CacheValue;
+ const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
+ CHECK(Exists);
+ }
+ }
+
+ // Move forward in time and collect again
+ {
+ GcContext GcCtx(CurrentTime + std::chrono::minutes(2));
+ Gc.CollectGarbage(GcCtx);
+
+ for (const auto& Key : Keys)
+ {
+ ZenCacheValue CacheValue;
+ const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
+ CHECK(!Exists);
+ }
+
+ CHECK_EQ(0, Zcs.StorageSize().DiskSize);
+ }
+ }
+
+ SUBCASE("gc removes small objects")
+ {
+ ScopedTemporaryDirectory TempDir;
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+ const auto Bucket = "rightintwo"sv;
+
+ std::vector<IoHash> Keys{CreateKey(1), CreateKey(2), CreateKey(3)};
+
+ for (const auto& Key : Keys)
+ {
+ IoBuffer Value = testutils::CreateBinaryCacheValue(128);
+ Zcs.Put(Bucket, Key, {.Value = Value});
+ }
+
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(2));
+ GcCtx.CollectSmallObjects(true);
+
+ Gc.CollectGarbage(GcCtx);
+
+ for (const auto& Key : Keys)
+ {
+ ZenCacheValue CacheValue;
+ const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
+ CHECK(Exists);
+ }
+ }
+
+ // Move forward in time and collect again
+ {
+ GcContext GcCtx(GcClock::Now() + std::chrono::minutes(2));
+ GcCtx.CollectSmallObjects(true);
+
+ Zcs.Flush();
+ Gc.CollectGarbage(GcCtx);
+
+ for (const auto& Key : Keys)
+ {
+ ZenCacheValue CacheValue;
+ const bool Exists = Zcs.Get(Bucket, Key, CacheValue);
+ CHECK(!Exists);
+ }
+
+ CHECK_EQ(0, Zcs.StorageSize().DiskSize);
+ }
+ }
+}
+
+TEST_CASE("z$.threadedinsert") // * doctest::skip(true))
+{
+ // for (uint32_t i = 0; i < 100; ++i)
+ {
+ ScopedTemporaryDirectory TempDir;
+
+ const uint64_t kChunkSize = 1048;
+ const int32_t kChunkCount = 8192;
+
+ struct Chunk
+ {
+ std::string Bucket;
+ IoBuffer Buffer;
+ };
+ std::unordered_map<IoHash, Chunk, IoHash::Hasher> Chunks;
+ Chunks.reserve(kChunkCount);
+
+ const std::string Bucket1 = "rightinone";
+ const std::string Bucket2 = "rightintwo";
+
+ for (int32_t Idx = 0; Idx < kChunkCount; ++Idx)
+ {
+ while (true)
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ if (Chunks.contains(Hash))
+ {
+ continue;
+ }
+ Chunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk};
+ break;
+ }
+ while (true)
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ if (Chunks.contains(Hash))
+ {
+ continue;
+ }
+ Chunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk};
+ break;
+ }
+ }
+
+ CreateDirectories(TempDir.Path());
+
+ WorkerThreadPool ThreadPool(4);
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path());
+
+ {
+ std::atomic<size_t> WorkCompleted = 0;
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() {
+ Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer});
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+ }
+
+ const uint64_t TotalSize = Zcs.StorageSize().DiskSize;
+ CHECK_LE(kChunkSize * Chunks.size(), TotalSize);
+
+ {
+ std::atomic<size_t> WorkCompleted = 0;
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, &Chunk]() {
+ std::string Bucket = Chunk.second.Bucket;
+ IoHash ChunkHash = Chunk.first;
+ ZenCacheValue CacheValue;
+
+ CHECK(Zcs.Get(Bucket, ChunkHash, CacheValue));
+ IoHash Hash = IoHash::HashBuffer(CacheValue.Value);
+ CHECK(ChunkHash == Hash);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+ }
+ std::unordered_map<IoHash, std::string, IoHash::Hasher> GcChunkHashes;
+ GcChunkHashes.reserve(Chunks.size());
+ for (const auto& Chunk : Chunks)
+ {
+ GcChunkHashes[Chunk.first] = Chunk.second.Bucket;
+ }
+ {
+ std::unordered_map<IoHash, Chunk, IoHash::Hasher> NewChunks;
+
+ for (int32_t Idx = 0; Idx < kChunkCount; ++Idx)
+ {
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ NewChunks[Hash] = {.Bucket = Bucket1, .Buffer = Chunk};
+ }
+ {
+ IoBuffer Chunk = testutils::CreateBinaryCacheValue(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ NewChunks[Hash] = {.Bucket = Bucket2, .Buffer = Chunk};
+ }
+ }
+
+ std::atomic<size_t> WorkCompleted = 0;
+ std::atomic_uint32_t AddedChunkCount = 0;
+ for (const auto& Chunk : NewChunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk, &AddedChunkCount]() {
+ Zcs.Put(Chunk.second.Bucket, Chunk.first, {.Value = Chunk.second.Buffer});
+ AddedChunkCount.fetch_add(1);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() {
+ ZenCacheValue CacheValue;
+ if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue))
+ {
+ CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value));
+ }
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (AddedChunkCount.load() < NewChunks.size())
+ {
+ // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope
+ for (const auto& Chunk : NewChunks)
+ {
+ ZenCacheValue CacheValue;
+ if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue))
+ {
+ GcChunkHashes[Chunk.first] = Chunk.second.Bucket;
+ }
+ }
+ std::vector<IoHash> KeepHashes;
+ KeepHashes.reserve(GcChunkHashes.size());
+ for (const auto& Entry : GcChunkHashes)
+ {
+ KeepHashes.push_back(Entry.first);
+ }
+ size_t C = 0;
+ while (C < KeepHashes.size())
+ {
+ if (C % 155 == 0)
+ {
+ if (C < KeepHashes.size() - 1)
+ {
+ KeepHashes[C] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ if (C + 3 < KeepHashes.size() - 1)
+ {
+ KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ }
+ C++;
+ }
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ GcCtx.AddRetainedCids(KeepHashes);
+ Zcs.CollectGarbage(GcCtx);
+ const HashKeySet& Deleted = GcCtx.DeletedCids();
+ Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); });
+ }
+
+ while (WorkCompleted < NewChunks.size() + Chunks.size())
+ {
+ Sleep(1);
+ }
+
+ {
+ // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope
+ for (const auto& Chunk : NewChunks)
+ {
+ ZenCacheValue CacheValue;
+ if (Zcs.Get(Chunk.second.Bucket, Chunk.first, CacheValue))
+ {
+ GcChunkHashes[Chunk.first] = Chunk.second.Bucket;
+ }
+ }
+ std::vector<IoHash> KeepHashes;
+ KeepHashes.reserve(GcChunkHashes.size());
+ for (const auto& Entry : GcChunkHashes)
+ {
+ KeepHashes.push_back(Entry.first);
+ }
+ size_t C = 0;
+ while (C < KeepHashes.size())
+ {
+ if (C % 155 == 0)
+ {
+ if (C < KeepHashes.size() - 1)
+ {
+ KeepHashes[C] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ if (C + 3 < KeepHashes.size() - 1)
+ {
+ KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ }
+ C++;
+ }
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ GcCtx.AddRetainedCids(KeepHashes);
+ Zcs.CollectGarbage(GcCtx);
+ const HashKeySet& Deleted = GcCtx.DeletedCids();
+ Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); });
+ }
+ }
+ {
+ std::atomic<size_t> WorkCompleted = 0;
+ for (const auto& Chunk : GcChunkHashes)
+ {
+ ThreadPool.ScheduleWork([&Zcs, &WorkCompleted, Chunk]() {
+ ZenCacheValue CacheValue;
+ CHECK(Zcs.Get(Chunk.second, Chunk.first, CacheValue));
+ CHECK(Chunk.first == IoHash::HashBuffer(CacheValue.Value));
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < GcChunkHashes.size())
+ {
+ Sleep(1);
+ }
+ }
+ }
+}
+
+TEST_CASE("z$.namespaces")
+{
+ using namespace testutils;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ ScopedTemporaryDirectory TempDir;
+ CreateDirectories(TempDir.Path());
+
+ IoHash Key1;
+ IoHash Key2;
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = false});
+ const auto Bucket = "teardrinker"sv;
+ const auto CustomNamespace = "mynamespace"sv;
+
+ // Create a cache record
+ Key1 = CreateKey(42);
+ CbObject CacheValue = CreateCacheValue(4096);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ ZenCacheValue PutValue = {.Value = Buffer};
+ Zcs.Put(ZenCacheStore::DefaultNamespace, Bucket, Key1, PutValue);
+
+ ZenCacheValue GetValue;
+ CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue));
+ CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue));
+
+ // This should just be dropped as we don't allow creating of namespaces on the fly
+ Zcs.Put(CustomNamespace, Bucket, Key1, PutValue);
+ CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue));
+ }
+
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true});
+ const auto Bucket = "teardrinker"sv;
+ const auto CustomNamespace = "mynamespace"sv;
+
+ Key2 = CreateKey(43);
+ CbObject CacheValue2 = CreateCacheValue(4096);
+
+ IoBuffer Buffer2 = CacheValue2.GetBuffer().AsIoBuffer();
+ Buffer2.SetContentType(ZenContentType::kCbObject);
+ ZenCacheValue PutValue2 = {.Value = Buffer2};
+ Zcs.Put(CustomNamespace, Bucket, Key2, PutValue2);
+
+ ZenCacheValue GetValue;
+ CHECK(!Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key2, GetValue));
+ CHECK(Zcs.Get(ZenCacheStore::DefaultNamespace, Bucket, Key1, GetValue));
+ CHECK(!Zcs.Get(CustomNamespace, Bucket, Key1, GetValue));
+ CHECK(Zcs.Get(CustomNamespace, Bucket, Key2, GetValue));
+ }
+}
+
+TEST_CASE("z$.drop.bucket")
+{
+ using namespace testutils;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ ScopedTemporaryDirectory TempDir;
+ CreateDirectories(TempDir.Path());
+
+ IoHash Key1;
+ IoHash Key2;
+
+ auto PutValue =
+ [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) {
+ // Create a cache record
+ IoHash Key = CreateKey(KeyIndex);
+ CbObject CacheValue = CreateCacheValue(Size);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ ZenCacheValue PutValue = {.Value = Buffer};
+ Zcs.Put(Namespace, Bucket, Key, PutValue);
+ return Key;
+ };
+ auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) {
+ ZenCacheValue GetValue;
+ Zcs.Get(Namespace, Bucket, Key, GetValue);
+ return GetValue;
+ };
+ WorkerThreadPool Workers(1);
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true});
+ const auto Bucket = "teardrinker"sv;
+ const auto Namespace = "mynamespace"sv;
+
+ Key1 = PutValue(Zcs, Namespace, Bucket, 42, 4096);
+ Key2 = PutValue(Zcs, Namespace, Bucket, 43, 2048);
+
+ ZenCacheValue Value1 = GetValue(Zcs, Namespace, Bucket, Key1);
+ CHECK(Value1.Value);
+
+ std::atomic_bool WorkComplete = false;
+ Workers.ScheduleWork([&]() {
+ zen::Sleep(100);
+ Value1.Value = IoBuffer{};
+ WorkComplete = true;
+ });
+ // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket
+ // Our DropBucket execution blocks any incoming request from completing until we are done with the drop
+ CHECK(Zcs.DropBucket(Namespace, Bucket));
+ while (!WorkComplete)
+ {
+ zen::Sleep(1);
+ }
+
+ // Entire bucket should be dropped, but doing a request should will re-create the namespace but it must still be empty
+ Value1 = GetValue(Zcs, Namespace, Bucket, Key1);
+ CHECK(!Value1.Value);
+ ZenCacheValue Value2 = GetValue(Zcs, Namespace, Bucket, Key2);
+ CHECK(!Value2.Value);
+ }
+}
+
+TEST_CASE("z$.drop.namespace")
+{
+ using namespace testutils;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ ScopedTemporaryDirectory TempDir;
+ CreateDirectories(TempDir.Path());
+
+ auto PutValue =
+ [&CreateCacheValue](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, size_t KeyIndex, size_t Size) {
+ // Create a cache record
+ IoHash Key = CreateKey(KeyIndex);
+ CbObject CacheValue = CreateCacheValue(Size);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ ZenCacheValue PutValue = {.Value = Buffer};
+ Zcs.Put(Namespace, Bucket, Key, PutValue);
+ return Key;
+ };
+ auto GetValue = [](ZenCacheStore& Zcs, std::string_view Namespace, std::string_view Bucket, const IoHash& Key) {
+ ZenCacheValue GetValue;
+ Zcs.Get(Namespace, Bucket, Key, GetValue);
+ return GetValue;
+ };
+ WorkerThreadPool Workers(1);
+ {
+ GcManager Gc;
+ ZenCacheStore Zcs(Gc, {.BasePath = TempDir.Path() / "cache", .AllowAutomaticCreationOfNamespaces = true});
+ const auto Bucket1 = "teardrinker1"sv;
+ const auto Bucket2 = "teardrinker2"sv;
+ const auto Namespace1 = "mynamespace1"sv;
+ const auto Namespace2 = "mynamespace2"sv;
+
+ IoHash Key1 = PutValue(Zcs, Namespace1, Bucket1, 42, 4096);
+ IoHash Key2 = PutValue(Zcs, Namespace1, Bucket2, 43, 2048);
+ IoHash Key3 = PutValue(Zcs, Namespace2, Bucket1, 44, 4096);
+ IoHash Key4 = PutValue(Zcs, Namespace2, Bucket2, 45, 2048);
+
+ ZenCacheValue Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1);
+ CHECK(Value1.Value);
+ ZenCacheValue Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2);
+ CHECK(Value2.Value);
+ ZenCacheValue Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3);
+ CHECK(Value3.Value);
+ ZenCacheValue Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4);
+ CHECK(Value4.Value);
+
+ std::atomic_bool WorkComplete = false;
+ Workers.ScheduleWork([&]() {
+ zen::Sleep(100);
+ Value1.Value = IoBuffer{};
+ Value2.Value = IoBuffer{};
+ Value3.Value = IoBuffer{};
+ Value4.Value = IoBuffer{};
+ WorkComplete = true;
+ });
+ // On Windows, DropBucket() will be blocked as long as we hold a reference to a buffer in the bucket
+ // Our DropBucket execution blocks any incoming request from completing until we are done with the drop
+ CHECK(Zcs.DropNamespace(Namespace1));
+ while (!WorkComplete)
+ {
+ zen::Sleep(1);
+ }
+
+ // Entire namespace should be dropped, but doing a request should will re-create the namespace but it must still be empty
+ Value1 = GetValue(Zcs, Namespace1, Bucket1, Key1);
+ CHECK(!Value1.Value);
+ Value2 = GetValue(Zcs, Namespace1, Bucket2, Key2);
+ CHECK(!Value2.Value);
+ Value3 = GetValue(Zcs, Namespace2, Bucket1, Key3);
+ CHECK(Value3.Value);
+ Value4 = GetValue(Zcs, Namespace2, Bucket2, Key4);
+ CHECK(Value4.Value);
+ }
+}
+
+TEST_CASE("z$.blocked.disklayer.put")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ GcStorageSize CacheSize;
+
+ const auto CreateCacheValue = [](size_t Size) -> CbObject {
+ std::vector<uint8_t> Buf;
+ Buf.resize(Size, Size & 0xff);
+
+ CbObjectWriter Writer;
+ Writer.AddBinary("Binary"sv, Buf.data(), Buf.size());
+ return Writer.Save();
+ };
+
+ GcManager Gc;
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+
+ CbObject CacheValue = CreateCacheValue(64 * 1024 + 64);
+
+ IoBuffer Buffer = CacheValue.GetBuffer().AsIoBuffer();
+ Buffer.SetContentType(ZenContentType::kCbObject);
+
+ size_t Key = Buffer.Size();
+ IoHash HashKey = IoHash::HashBuffer(&Key, sizeof(uint32_t));
+ Zcs.Put("test_bucket", HashKey, {.Value = Buffer});
+
+ ZenCacheValue BufferGet;
+ CHECK(Zcs.Get("test_bucket", HashKey, BufferGet));
+
+ CbObject CacheValue2 = CreateCacheValue(64 * 1024 + 64 + 1);
+ IoBuffer Buffer2 = CacheValue2.GetBuffer().AsIoBuffer();
+ Buffer2.SetContentType(ZenContentType::kCbObject);
+
+ // We should be able to overwrite even if the file is open for read
+ Zcs.Put("test_bucket", HashKey, {.Value = Buffer2});
+
+ MemoryView OldView = BufferGet.Value.GetView();
+
+ ZenCacheValue BufferGet2;
+ CHECK(Zcs.Get("test_bucket", HashKey, BufferGet2));
+ MemoryView NewView = BufferGet2.Value.GetView();
+
+ // Make sure file openend for read before we wrote it still have old data
+ CHECK(OldView.GetSize() == Buffer.GetSize());
+ CHECK(memcmp(OldView.GetData(), Buffer.GetData(), OldView.GetSize()) == 0);
+
+ // Make sure we get the new data when reading after we write new data
+ CHECK(NewView.GetSize() == Buffer2.GetSize());
+ CHECK(memcmp(NewView.GetData(), Buffer2.GetData(), NewView.GetSize()) == 0);
+}
+
+TEST_CASE("z$.scrub")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ using namespace testutils;
+
+ struct CacheRecord
+ {
+ IoBuffer Record;
+ std::vector<CompressedBuffer> Attachments;
+ };
+
+ auto CreateCacheRecord = [](bool Structured, std::string_view Bucket, const IoHash& Key, const std::vector<size_t>& AttachmentSizes) {
+ CacheRecord Result;
+ if (Structured)
+ {
+ Result.Attachments.resize(AttachmentSizes.size());
+ CbObjectWriter Record;
+ Record.BeginObject("Key"sv);
+ {
+ Record << "Bucket"sv << Bucket;
+ Record << "Hash"sv << Key;
+ }
+ Record.EndObject();
+ for (size_t Index = 0; Index < AttachmentSizes.size(); Index++)
+ {
+ IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSizes[Index]);
+ CompressedBuffer CompressedAttachmentData = CompressedBuffer::Compress(SharedBuffer(AttachmentData));
+ Record.AddBinaryAttachment(fmt::format("attachment-{}", Index), CompressedAttachmentData.DecodeRawHash());
+ Result.Attachments[Index] = CompressedAttachmentData;
+ }
+ Result.Record = Record.Save().GetBuffer().AsIoBuffer();
+ Result.Record.SetContentType(ZenContentType::kCbObject);
+ }
+ else
+ {
+ std::string RecordData = fmt::format("{}:{}", Bucket, Key.ToHexString());
+ size_t TotalSize = RecordData.length() + 1;
+ for (size_t AttachmentSize : AttachmentSizes)
+ {
+ TotalSize += AttachmentSize;
+ }
+ Result.Record = IoBuffer(TotalSize);
+ char* DataPtr = (char*)Result.Record.MutableData();
+ memcpy(DataPtr, RecordData.c_str(), RecordData.length() + 1);
+ DataPtr += RecordData.length() + 1;
+ for (size_t AttachmentSize : AttachmentSizes)
+ {
+ IoBuffer AttachmentData = CreateBinaryCacheValue(AttachmentSize);
+ memcpy(DataPtr, AttachmentData.GetData(), AttachmentData.GetSize());
+ DataPtr += AttachmentData.GetSize();
+ }
+ }
+ return Result;
+ };
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ ZenCacheNamespace Zcs(Gc, TempDir.Path() / "cache");
+ CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+ CidStore.Initialize(CidConfig);
+
+ auto CreateRecords =
+ [&](bool IsStructured, std::string_view BucketName, const std::vector<IoHash>& Cids, const std::vector<size_t>& AttachmentSizes) {
+ for (const IoHash& Cid : Cids)
+ {
+ CacheRecord Record = CreateCacheRecord(IsStructured, BucketName, Cid, AttachmentSizes);
+ Zcs.Put("mybucket", Cid, {.Value = Record.Record});
+ for (const CompressedBuffer& Attachment : Record.Attachments)
+ {
+ CidStore.AddChunk(Attachment.GetCompressed().Flatten().AsIoBuffer(), Attachment.DecodeRawHash());
+ }
+ }
+ };
+
+ std::vector<size_t> AttachmentSizes = {16, 1000, 2000, 4000, 8000, 64000, 80000};
+
+ std::vector<IoHash> UnstructuredCids{CreateKey(4), CreateKey(5), CreateKey(6)};
+ CreateRecords(false, "mybucket"sv, UnstructuredCids, AttachmentSizes);
+
+ std::vector<IoHash> StructuredCids{CreateKey(1), CreateKey(2), CreateKey(3)};
+ CreateRecords(true, "mybucket"sv, StructuredCids, AttachmentSizes);
+
+ ScrubContext ScrubCtx;
+ Zcs.Scrub(ScrubCtx);
+ CidStore.Scrub(ScrubCtx);
+ CHECK(ScrubCtx.ScrubbedChunks() == (StructuredCids.size() + StructuredCids.size() * AttachmentSizes.size()) + UnstructuredCids.size());
+ CHECK(ScrubCtx.BadCids().GetSize() == 0);
+}
+
+#endif
+
+void
+z$_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenserver/cache/structuredcachestore.h b/src/zenserver/cache/structuredcachestore.h
new file mode 100644
index 000000000..3fb4f035d
--- /dev/null
+++ b/src/zenserver/cache/structuredcachestore.h
@@ -0,0 +1,535 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+#include <zenstore/blockstore.h>
+#include <zenstore/caslog.h>
+#include <zenstore/gc.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <atomic>
+#include <compare>
+#include <filesystem>
+#include <unordered_map>
+
+#define ZEN_USE_CACHE_TRACKER 0
+
+namespace zen {
+
+class PathBuilderBase;
+class GcManager;
+class ZenCacheTracker;
+class ScrubContext;
+
+/******************************************************************************
+
+ /$$$$$$$$ /$$$$$$ /$$
+ |_____ $$ /$$__ $$ | $$
+ /$$/ /$$$$$$ /$$$$$$$ | $$ \__/ /$$$$$$ /$$$$$$| $$$$$$$ /$$$$$$
+ /$$/ /$$__ $| $$__ $$ | $$ |____ $$/$$_____| $$__ $$/$$__ $$
+ /$$/ | $$$$$$$| $$ \ $$ | $$ /$$$$$$| $$ | $$ \ $| $$$$$$$$
+ /$$/ | $$_____| $$ | $$ | $$ $$/$$__ $| $$ | $$ | $| $$_____/
+ /$$$$$$$| $$$$$$| $$ | $$ | $$$$$$| $$$$$$| $$$$$$| $$ | $| $$$$$$$
+ |________/\_______|__/ |__/ \______/ \_______/\_______|__/ |__/\_______/
+
+ Cache store for UE5. Restricts keys to "{bucket}/{hash}" pairs where the hash
+ is 40 (hex) chars in size. Values may be opaque blobs or structured objects
+ which can in turn contain references to other objects (or blobs).
+
+******************************************************************************/
+
+namespace access_tracking {
+
+ struct KeyAccessTime
+ {
+ IoHash Key;
+ GcClock::Tick LastAccess{};
+ };
+
+ struct AccessTimes
+ {
+ std::unordered_map<std::string, std::vector<KeyAccessTime>> Buckets;
+ };
+}; // namespace access_tracking
+
+struct ZenCacheValue
+{
+ IoBuffer Value;
+ uint64_t RawSize = 0;
+ IoHash RawHash = IoHash::Zero;
+};
+
+struct CacheValueDetails
+{
+ struct ValueDetails
+ {
+ uint64_t Size;
+ uint64_t RawSize;
+ IoHash RawHash;
+ GcClock::Tick LastAccess{};
+ std::vector<IoHash> Attachments;
+ ZenContentType ContentType;
+ };
+
+ struct BucketDetails
+ {
+ std::unordered_map<IoHash, ValueDetails, IoHash::Hasher> Values;
+ };
+
+ struct NamespaceDetails
+ {
+ std::unordered_map<std::string, BucketDetails> Buckets;
+ };
+
+ std::unordered_map<std::string, NamespaceDetails> Namespaces;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+#pragma pack(push)
+#pragma pack(1)
+
+struct DiskLocation
+{
+ inline DiskLocation() = default;
+
+ inline DiskLocation(uint64_t ValueSize, uint8_t Flags) : Flags(Flags | kStandaloneFile) { Location.StandaloneSize = ValueSize; }
+
+ inline DiskLocation(const BlockStoreLocation& Location, uint64_t PayloadAlignment, uint8_t Flags) : Flags(Flags & ~kStandaloneFile)
+ {
+ this->Location.BlockLocation = BlockStoreDiskLocation(Location, PayloadAlignment);
+ }
+
+ inline BlockStoreLocation GetBlockLocation(uint64_t PayloadAlignment) const
+ {
+ ZEN_ASSERT(!(Flags & kStandaloneFile));
+ return Location.BlockLocation.Get(PayloadAlignment);
+ }
+
+ inline uint64_t Size() const { return (Flags & kStandaloneFile) ? Location.StandaloneSize : Location.BlockLocation.GetSize(); }
+ inline uint8_t IsFlagSet(uint64_t Flag) const { return Flags & Flag; }
+ inline uint8_t GetFlags() const { return Flags; }
+ inline ZenContentType GetContentType() const
+ {
+ ZenContentType ContentType = ZenContentType::kBinary;
+
+ if (IsFlagSet(kStructured))
+ {
+ ContentType = ZenContentType::kCbObject;
+ }
+
+ if (IsFlagSet(kCompressed))
+ {
+ ContentType = ZenContentType::kCompressedBinary;
+ }
+
+ return ContentType;
+ }
+
+ union
+ {
+ BlockStoreDiskLocation BlockLocation; // 10 bytes
+ uint64_t StandaloneSize = 0; // 8 bytes
+ } Location;
+
+ static const uint8_t kStandaloneFile = 0x80u; // Stored as a separate file
+ static const uint8_t kStructured = 0x40u; // Serialized as compact binary
+ static const uint8_t kTombStone = 0x20u; // Represents a deleted key/value
+ static const uint8_t kCompressed = 0x10u; // Stored in compressed buffer format
+
+ uint8_t Flags = 0;
+ uint8_t Reserved = 0;
+};
+
+struct DiskIndexEntry
+{
+ IoHash Key; // 20 bytes
+ DiskLocation Location; // 12 bytes
+};
+
+#pragma pack(pop)
+
+static_assert(sizeof(DiskIndexEntry) == 32);
+
+// This store the access time as seconds since epoch internally in a 32-bit value giving is a range of 136 years since epoch
+struct AccessTime
+{
+ explicit AccessTime(GcClock::Tick Tick) noexcept : SecondsSinceEpoch(ToSeconds(Tick)) {}
+ AccessTime& operator=(GcClock::Tick Tick) noexcept
+ {
+ SecondsSinceEpoch.store(ToSeconds(Tick), std::memory_order_relaxed);
+ return *this;
+ }
+ operator GcClock::Tick() const noexcept
+ {
+ return std::chrono::duration_cast<GcClock::Duration>(std::chrono::seconds(SecondsSinceEpoch.load(std::memory_order_relaxed)))
+ .count();
+ }
+
+ AccessTime(AccessTime&& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {}
+ AccessTime(const AccessTime& Rhs) noexcept : SecondsSinceEpoch(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed)) {}
+ AccessTime& operator=(AccessTime&& Rhs) noexcept
+ {
+ SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed);
+ return *this;
+ }
+ AccessTime& operator=(const AccessTime& Rhs) noexcept
+ {
+ SecondsSinceEpoch.store(Rhs.SecondsSinceEpoch.load(std::memory_order_relaxed), std::memory_order_relaxed);
+ return *this;
+ }
+
+private:
+ static uint32_t ToSeconds(GcClock::Tick Tick)
+ {
+ return gsl::narrow<uint32_t>(std::chrono::duration_cast<std::chrono::seconds>(GcClock::Duration(Tick)).count());
+ }
+ std::atomic_uint32_t SecondsSinceEpoch;
+};
+
+/** In-memory cache storage
+
+ Intended for small values which are frequently accessed
+
+ This should have a better memory management policy to maintain reasonable
+ footprint.
+ */
+class ZenCacheMemoryLayer
+{
+public:
+ struct Configuration
+ {
+ uint64_t TargetFootprintBytes = 16 * 1024 * 1024;
+ uint64_t ScavengeThreshold = 4 * 1024 * 1024;
+ };
+
+ struct BucketInfo
+ {
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ struct Info
+ {
+ Configuration Config;
+ std::vector<std::string> BucketNames;
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ ZenCacheMemoryLayer();
+ ~ZenCacheMemoryLayer();
+
+ bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
+ void Drop();
+ bool DropBucket(std::string_view Bucket);
+ void Scrub(ScrubContext& Ctx);
+ void GatherAccessTimes(zen::access_tracking::AccessTimes& AccessTimes);
+ void Reset();
+ uint64_t TotalSize() const;
+
+ Info GetInfo() const;
+ std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const;
+
+ const Configuration& GetConfiguration() const { return m_Configuration; }
+ void SetConfiguration(const Configuration& NewConfig) { m_Configuration = NewConfig; }
+
+private:
+ struct CacheBucket
+ {
+#pragma pack(push)
+#pragma pack(1)
+ struct BucketPayload
+ {
+ IoBuffer Payload; // 8
+ uint32_t RawSize; // 4
+ IoHash RawHash; // 20
+ };
+#pragma pack(pop)
+ static_assert(sizeof(BucketPayload) == 32u);
+ static_assert(sizeof(AccessTime) == 4u);
+
+ mutable RwLock m_BucketLock;
+ std::vector<AccessTime> m_AccessTimes;
+ std::vector<BucketPayload> m_Payloads;
+ tsl::robin_map<IoHash, uint32_t> m_CacheMap;
+
+ std::atomic_uint64_t m_TotalSize{};
+
+ bool Get(const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(const IoHash& HashKey, const ZenCacheValue& Value);
+ void Drop();
+ void Scrub(ScrubContext& Ctx);
+ void GatherAccessTimes(std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes);
+ inline uint64_t TotalSize() const { return m_TotalSize; }
+ uint64_t EntryCount() const;
+ };
+
+ mutable RwLock m_Lock;
+ std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets;
+ std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets;
+ Configuration m_Configuration;
+
+ ZenCacheMemoryLayer(const ZenCacheMemoryLayer&) = delete;
+ ZenCacheMemoryLayer& operator=(const ZenCacheMemoryLayer&) = delete;
+};
+
+class ZenCacheDiskLayer
+{
+public:
+ struct Configuration
+ {
+ std::filesystem::path RootDir;
+ };
+
+ struct BucketInfo
+ {
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ struct Info
+ {
+ Configuration Config;
+ std::vector<std::string> BucketNames;
+ uint64_t EntryCount = 0;
+ uint64_t TotalSize = 0;
+ };
+
+ explicit ZenCacheDiskLayer(const std::filesystem::path& RootDir);
+ ~ZenCacheDiskLayer();
+
+ bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
+ bool Drop();
+ bool DropBucket(std::string_view Bucket);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ void GatherReferences(GcContext& GcCtx);
+ void CollectGarbage(GcContext& GcCtx);
+ void UpdateAccessTimes(const zen::access_tracking::AccessTimes& AccessTimes);
+
+ void DiscoverBuckets();
+ uint64_t TotalSize() const;
+
+ Info GetInfo() const;
+ std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const;
+
+ CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const;
+
+private:
+ /** A cache bucket manages a single directory containing
+ metadata and data for that bucket
+ */
+ struct CacheBucket
+ {
+ CacheBucket(std::string BucketName);
+ ~CacheBucket();
+
+ bool OpenOrCreate(std::filesystem::path BucketDir, bool AllowCreate = true);
+ bool Get(const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(const IoHash& HashKey, const ZenCacheValue& Value);
+ bool Drop();
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ void GatherReferences(GcContext& GcCtx);
+ void CollectGarbage(GcContext& GcCtx);
+ void UpdateAccessTimes(const std::vector<zen::access_tracking::KeyAccessTime>& AccessTimes);
+
+ inline uint64_t TotalSize() const { return m_TotalStandaloneSize.load(std::memory_order::relaxed) + m_BlockStore.TotalSize(); }
+ uint64_t EntryCount() const;
+
+ CacheValueDetails::BucketDetails GetValueDetails(const std::string_view ValueFilter) const;
+
+ private:
+ const uint64_t MaxBlockSize = 1ull << 30;
+ uint64_t m_PayloadAlignment = 1ull << 4;
+
+ std::string m_BucketName;
+ std::filesystem::path m_BucketDir;
+ std::filesystem::path m_BlocksBasePath;
+ BlockStore m_BlockStore;
+ Oid m_BucketId;
+ uint64_t m_LargeObjectThreshold = 128 * 1024;
+
+ // These files are used to manage storage of small objects for this bucket
+
+ TCasLogFile<DiskIndexEntry> m_SlogFile;
+ uint64_t m_LogFlushPosition = 0;
+
+#pragma pack(push)
+#pragma pack(1)
+ struct BucketPayload
+ {
+ DiskLocation Location; // 12
+ uint64_t RawSize; // 8
+ IoHash RawHash; // 20
+ };
+#pragma pack(pop)
+ static_assert(sizeof(BucketPayload) == 40u);
+ static_assert(sizeof(AccessTime) == 4u);
+
+ using IndexMap = tsl::robin_map<IoHash, size_t, IoHash::Hasher>;
+
+ mutable RwLock m_IndexLock;
+ std::vector<AccessTime> m_AccessTimes;
+ std::vector<BucketPayload> m_Payloads;
+ IndexMap m_Index;
+
+ std::atomic_uint64_t m_TotalStandaloneSize{};
+
+ void BuildPath(PathBuilderBase& Path, const IoHash& HashKey) const;
+ void PutStandaloneCacheValue(const IoHash& HashKey, const ZenCacheValue& Value);
+ IoBuffer GetStandaloneCacheValue(const DiskLocation& Loc, const IoHash& HashKey) const;
+ void PutInlineCacheValue(const IoHash& HashKey, const ZenCacheValue& Value);
+ IoBuffer GetInlineCacheValue(const DiskLocation& Loc) const;
+ void MakeIndexSnapshot();
+ uint64_t ReadIndexFile(const std::filesystem::path& IndexPath, uint32_t& OutVersion);
+ uint64_t ReadLog(const std::filesystem::path& LogPath, uint64_t LogPosition);
+ void OpenLog(const bool IsNew);
+ void SaveManifest();
+ CacheValueDetails::ValueDetails GetValueDetails(const IoHash& Key, size_t Index) const;
+ // These locks are here to avoid contention on file creation, therefore it's sufficient
+ // that we take the same lock for the same hash
+ //
+ // These locks are small and should really be spaced out so they don't share cache lines,
+ // but we don't currently access them at particularly high frequency so it should not be
+ // an issue in practice
+
+ mutable RwLock m_ShardedLocks[256];
+ inline RwLock& LockForHash(const IoHash& Hash) const { return m_ShardedLocks[Hash.Hash[19]]; }
+ };
+
+ std::filesystem::path m_RootDir;
+ mutable RwLock m_Lock;
+ std::unordered_map<std::string, std::unique_ptr<CacheBucket>> m_Buckets; // TODO: make this case insensitive
+ std::vector<std::unique_ptr<CacheBucket>> m_DroppedBuckets;
+
+ ZenCacheDiskLayer(const ZenCacheDiskLayer&) = delete;
+ ZenCacheDiskLayer& operator=(const ZenCacheDiskLayer&) = delete;
+};
+
+class ZenCacheNamespace final : public RefCounted, public GcStorage, public GcContributor
+{
+public:
+ struct Configuration
+ {
+ std::filesystem::path RootDir;
+ uint64_t DiskLayerThreshold = 0;
+ };
+ struct BucketInfo
+ {
+ ZenCacheDiskLayer::BucketInfo DiskLayerInfo;
+ ZenCacheMemoryLayer::BucketInfo MemoryLayerInfo;
+ };
+ struct Info
+ {
+ Configuration Config;
+ std::vector<std::string> BucketNames;
+ ZenCacheDiskLayer::Info DiskLayerInfo;
+ ZenCacheMemoryLayer::Info MemoryLayerInfo;
+ };
+
+ ZenCacheNamespace(GcManager& Gc, const std::filesystem::path& RootDir);
+ ~ZenCacheNamespace();
+
+ bool Get(std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
+ bool Drop();
+ bool DropBucket(std::string_view Bucket);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ uint64_t DiskLayerThreshold() const { return m_DiskLayerSizeThreshold; }
+ virtual void GatherReferences(GcContext& GcCtx) override;
+ virtual void CollectGarbage(GcContext& GcCtx) override;
+ virtual GcStorageSize StorageSize() const override;
+ Info GetInfo() const;
+ std::optional<BucketInfo> GetBucketInfo(std::string_view Bucket) const;
+
+ CacheValueDetails::NamespaceDetails GetValueDetails(const std::string_view BucketFilter, const std::string_view ValueFilter) const;
+
+private:
+ std::filesystem::path m_RootDir;
+ ZenCacheMemoryLayer m_MemLayer;
+ ZenCacheDiskLayer m_DiskLayer;
+ uint64_t m_DiskLayerSizeThreshold = 1 * 1024;
+ uint64_t m_LastScrubTime = 0;
+
+#if ZEN_USE_CACHE_TRACKER
+ std::unique_ptr<ZenCacheTracker> m_AccessTracker;
+#endif
+
+ ZenCacheNamespace(const ZenCacheNamespace&) = delete;
+ ZenCacheNamespace& operator=(const ZenCacheNamespace&) = delete;
+};
+
+class ZenCacheStore final
+{
+public:
+ static constexpr std::string_view DefaultNamespace =
+ "!default!"; // This is intentionally not a valid namespace name and will only be used for mapping when no namespace is given
+ static constexpr std::string_view NamespaceDiskPrefix = "ns_";
+
+ struct Configuration
+ {
+ std::filesystem::path BasePath;
+ bool AllowAutomaticCreationOfNamespaces = false;
+ };
+
+ struct Info
+ {
+ Configuration Config;
+ std::vector<std::string> NamespaceNames;
+ uint64_t DiskEntryCount = 0;
+ uint64_t MemoryEntryCount = 0;
+ GcStorageSize StorageSize;
+ };
+
+ ZenCacheStore(GcManager& Gc, const Configuration& Configuration);
+ ~ZenCacheStore();
+
+ bool Get(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, ZenCacheValue& OutValue);
+ void Put(std::string_view Namespace, std::string_view Bucket, const IoHash& HashKey, const ZenCacheValue& Value);
+ bool DropBucket(std::string_view Namespace, std::string_view Bucket);
+ bool DropNamespace(std::string_view Namespace);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+
+ CacheValueDetails GetValueDetails(const std::string_view NamespaceFilter,
+ const std::string_view BucketFilter,
+ const std::string_view ValueFilter) const;
+
+ GcStorageSize StorageSize() const;
+ // const Configuration& GetConfiguration() const { return m_Configuration; }
+
+ Info GetInfo() const;
+ std::optional<ZenCacheNamespace::Info> GetNamespaceInfo(std::string_view Namespace);
+ std::optional<ZenCacheNamespace::BucketInfo> GetBucketInfo(std::string_view Namespace, std::string_view Bucket);
+
+private:
+ const ZenCacheNamespace* FindNamespace(std::string_view Namespace) const;
+ ZenCacheNamespace* GetNamespace(std::string_view Namespace);
+ void IterateNamespaces(const std::function<void(std::string_view Namespace, ZenCacheNamespace& Store)>& Callback) const;
+
+ typedef std::unordered_map<std::string, std::unique_ptr<ZenCacheNamespace>> NamespaceMap;
+
+ mutable RwLock m_NamespacesLock;
+ NamespaceMap m_Namespaces;
+ std::vector<std::unique_ptr<ZenCacheNamespace>> m_DroppedNamespaces;
+
+ GcManager& m_Gc;
+ Configuration m_Configuration;
+};
+
+void z$_forcelink();
+
+} // namespace zen
diff --git a/src/zenserver/cidstore.cpp b/src/zenserver/cidstore.cpp
new file mode 100644
index 000000000..bce4f1dfb
--- /dev/null
+++ b/src/zenserver/cidstore.cpp
@@ -0,0 +1,124 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "cidstore.h"
+
+#include <zencore/compress.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zenstore/cidstore.h>
+
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+HttpCidService::HttpCidService(CidStore& Store) : m_CidStore(Store)
+{
+ m_Router.AddPattern("cid", "([0-9A-Fa-f]{40})");
+
+ m_Router.RegisterRoute(
+ "{cid}",
+ [this](HttpRouterRequest& Req) {
+ IoHash Hash = IoHash::FromHexString(Req.GetCapture(1));
+ ZEN_DEBUG("CID request for {}", Hash);
+
+ HttpServerRequest& ServerRequest = Req.ServerRequest();
+
+ switch (ServerRequest.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ case HttpVerb::kHead:
+ {
+ if (IoBuffer Value = m_CidStore.FindChunkByCid(Hash))
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Value);
+ }
+
+ return ServerRequest.WriteResponse(HttpResponseCode::NotFound);
+ }
+ break;
+
+ case HttpVerb::kPut:
+ {
+ IoBuffer Payload = ServerRequest.ReadPayload();
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize))
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::UnsupportedMediaType);
+ }
+
+ // URI hash must match content hash
+ if (RawHash != Hash)
+ {
+ return ServerRequest.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ m_CidStore.AddChunk(Payload, RawHash);
+
+ return ServerRequest.WriteResponse(HttpResponseCode::OK);
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPut | HttpVerb::kHead);
+}
+
+const char*
+HttpCidService::BaseUri() const
+{
+ return "/cid/";
+}
+
+void
+HttpCidService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ if (Request.RelativeUri().empty())
+ {
+ // Root URI request
+
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kPut:
+ case HttpVerb::kPost:
+ {
+ IoBuffer Payload = Request.ReadPayload();
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize))
+ {
+ return Request.WriteResponse(HttpResponseCode::UnsupportedMediaType);
+ }
+
+ ZEN_DEBUG("CID POST request for {} ({} bytes)", RawHash, Payload.Size());
+
+ auto InsertResult = m_CidStore.AddChunk(Payload, RawHash);
+
+ if (InsertResult.New)
+ {
+ return Request.WriteResponse(HttpResponseCode::Created);
+ }
+ else
+ {
+ return Request.WriteResponse(HttpResponseCode::OK);
+ }
+ }
+ break;
+
+ case HttpVerb::kGet:
+ case HttpVerb::kHead:
+ break;
+
+ default:
+ break;
+ }
+ }
+ else
+ {
+ m_Router.HandleRequest(Request);
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/cidstore.h b/src/zenserver/cidstore.h
new file mode 100644
index 000000000..8e7832b35
--- /dev/null
+++ b/src/zenserver/cidstore.h
@@ -0,0 +1,35 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+/**
+ * Simple CID store HTTP endpoint
+ *
+ * Note that since this does not end up pinning any of the chunks it's only really useful for a small subset of use cases where you know a
+ * chunk exists in the underlying CID store. Thus it's mainly useful for internal use when communicating between Zen store instances
+ *
+ * Using this interface for adding CID chunks makes little sense except for testing purposes as garbage collection may reap anything you add
+ * before anything ever gets to access it
+ */
+
+class CidStore;
+
+class HttpCidService : public HttpService
+{
+public:
+ explicit HttpCidService(CidStore& Store);
+ ~HttpCidService() = default;
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ CidStore& m_CidStore;
+ HttpRequestRouter m_Router;
+};
+
+} // namespace zen
diff --git a/src/zenserver/compute/function.cpp b/src/zenserver/compute/function.cpp
new file mode 100644
index 000000000..493e2666e
--- /dev/null
+++ b/src/zenserver/compute/function.cpp
@@ -0,0 +1,629 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "function.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <upstream/jupiter.h>
+# include <upstream/upstreamapply.h>
+# include <upstream/upstreamcache.h>
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compress.h>
+# include <zencore/except.h>
+# include <zencore/filesystem.h>
+# include <zencore/fmtutils.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/scopeguard.h>
+# include <zenstore/cidstore.h>
+
+# include <span>
+
+using namespace std::literals;
+
+namespace zen {
+
+HttpFunctionService::HttpFunctionService(CidStore& InCidStore,
+ const CloudCacheClientOptions& ComputeOptions,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ AuthMgr& Mgr)
+: m_Log(logging::Get("apply"))
+, m_CidStore(InCidStore)
+{
+ m_UpstreamApply = UpstreamApply::Create({}, m_CidStore);
+
+ InitializeThread = std::thread{[this, ComputeOptions, StorageOptions, ComputeAuthConfig, StorageAuthConfig, &Mgr] {
+ auto HordeUpstreamEndpoint = UpstreamApplyEndpoint::CreateHordeEndpoint(ComputeOptions,
+ ComputeAuthConfig,
+ StorageOptions,
+ StorageAuthConfig,
+ m_CidStore,
+ Mgr);
+ m_UpstreamApply->RegisterEndpoint(std::move(HordeUpstreamEndpoint));
+ m_UpstreamApply->Initialize();
+ }};
+
+ m_Router.AddPattern("job", "([[:digit:]]+)");
+ m_Router.AddPattern("worker", "([[:xdigit:]]{40})");
+ m_Router.AddPattern("action", "([[:xdigit:]]{40})");
+
+ m_Router.RegisterRoute(
+ "ready",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ return HttpReq.WriteResponse(m_UpstreamApply->IsHealthy() ? HttpResponseCode::OK : HttpResponseCode::ServiceUnavailable);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "workers/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+
+ if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ else
+ {
+ const WorkerDesc& Desc = It->second;
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Desc.Descriptor);
+ }
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ CbObject FunctionSpec = HttpReq.ReadPayloadObject();
+
+ // Determine which pieces are missing and need to be transmitted to populate CAS
+
+ HashKeySet ChunkSet;
+
+ FunctionSpec.IterateAttachments([&](CbFieldView Field) {
+ const IoHash Hash = Field.AsHash();
+ ChunkSet.AddHashToSet(Hash);
+ });
+
+ // Note that we store executables uncompressed to make it
+ // more straightforward and efficient to materialize them, hence
+ // the CAS lookup here instead of CID for the input payloads
+
+ m_CidStore.FilterChunks(ChunkSet);
+
+ if (ChunkSet.IsEmpty())
+ {
+ RwLock::ExclusiveLockScope _(m_WorkerLock);
+
+ m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{FunctionSpec});
+
+ ZEN_DEBUG("worker {}: all attachments already available", WorkerId);
+
+ return HttpReq.WriteResponse(HttpResponseCode::NoContent);
+ }
+ else
+ {
+ CbObjectWriter ResponseWriter;
+ ResponseWriter.BeginArray("need");
+
+ ChunkSet.IterateHashes([&](const IoHash& Hash) {
+ ZEN_DEBUG("worker {}: need chunk {}", WorkerId, Hash);
+
+ ResponseWriter.AddHash(Hash);
+ });
+
+ ResponseWriter.EndArray();
+
+ ZEN_DEBUG("worker {}: need {} attachments", WorkerId, ChunkSet.GetSize());
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, ResponseWriter.Save());
+ }
+ }
+ break;
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage FunctionSpec = HttpReq.ReadPayloadPackage();
+
+ CbObject Obj = FunctionSpec.GetObject();
+
+ std::span<const CbAttachment> Attachments = FunctionSpec.GetAttachments();
+
+ int AttachmentCount = 0;
+ int NewAttachmentCount = 0;
+ uint64_t TotalAttachmentBytes = 0;
+ uint64_t TotalNewBytes = 0;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ ZEN_ASSERT(Attachment.IsCompressedBinary());
+
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer Buffer = Attachment.AsCompressedBinary();
+
+ ZEN_UNUSED(DataHash);
+ TotalAttachmentBytes += Buffer.GetCompressedSize();
+ ++AttachmentCount;
+
+ const CidStore::InsertResult InsertResult =
+ m_CidStore.AddChunk(Buffer.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+
+ if (InsertResult.New)
+ {
+ TotalNewBytes += Buffer.GetCompressedSize();
+ ++NewAttachmentCount;
+ }
+ }
+
+ ZEN_DEBUG("worker {}: {} in {} attachments, {} in {} new attachments",
+ WorkerId,
+ zen::NiceBytes(TotalAttachmentBytes),
+ AttachmentCount,
+ zen::NiceBytes(TotalNewBytes),
+ NewAttachmentCount);
+
+ RwLock::ExclusiveLockScope _(m_WorkerLock);
+
+ m_WorkerMap.insert_or_assign(WorkerId, WorkerDesc{.Descriptor = Obj});
+
+ return HttpReq.WriteResponse(HttpResponseCode::NoContent);
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/{job}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ break;
+
+ case HttpVerb::kPost:
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/{worker}/{action}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+ const IoHash ActionId = IoHash::FromHexString(Req.GetCapture(2));
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ CbPackage Output;
+ HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, ActionId, Output);
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "simple/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+
+ WorkerDesc Worker;
+
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+
+ if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ else
+ {
+ Worker = It->second;
+ }
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstreamResult(WorkerId, Output);
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+ m_WorkerMap.erase(WorkerId);
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstream(Worker, Output);
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "jobs/{worker}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const IoHash WorkerId = IoHash::FromHexString(Req.GetCapture(1));
+
+ WorkerDesc Worker;
+
+ {
+ RwLock::SharedLockScope _(m_WorkerLock);
+
+ if (auto It = m_WorkerMap.find(WorkerId); It == m_WorkerMap.end())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ else
+ {
+ Worker = It->second;
+ }
+ }
+
+ switch (HttpReq.RequestVerb())
+ {
+ case HttpVerb::kGet:
+ // TODO: return status of all pending or executing jobs
+ break;
+
+ case HttpVerb::kPost:
+ switch (HttpReq.RequestContentType())
+ {
+ case HttpContentType::kCbObject:
+ {
+ // This operation takes the proposed job spec and identifies which
+ // chunks are not present on this server. This list is then returned in
+ // the "need" list in the response
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject RequestObject = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ RequestObject.IterateAttachments([&](CbFieldView Field) {
+ const IoHash FileHash = Field.AsHash();
+
+ if (!m_CidStore.ContainsChunk(FileHash))
+ {
+ NeedList.push_back(FileHash);
+ }
+ });
+
+ if (NeedList.empty())
+ {
+ // We already have everything
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstream(Worker, RequestObject, Output);
+
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+ CbObject Response = Cbo.Save();
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, Response);
+ }
+ break;
+
+ case HttpContentType::kCbPackage:
+ {
+ CbPackage Action = HttpReq.ReadPayloadPackage();
+ CbObject ActionObj = Action.GetObject();
+
+ std::span<const CbAttachment> Attachments = Action.GetAttachments();
+
+ int AttachmentCount = 0;
+ int NewAttachmentCount = 0;
+ uint64_t TotalAttachmentBytes = 0;
+ uint64_t TotalNewBytes = 0;
+
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ ZEN_ASSERT(Attachment.IsCompressedBinary());
+
+ const IoHash DataHash = Attachment.GetHash();
+ CompressedBuffer DataView = Attachment.AsCompressedBinary();
+
+ ZEN_UNUSED(DataHash);
+
+ const uint64_t CompressedSize = DataView.GetCompressedSize();
+
+ TotalAttachmentBytes += CompressedSize;
+ ++AttachmentCount;
+
+ const CidStore::InsertResult InsertResult =
+ m_CidStore.AddChunk(DataView.GetCompressed().Flatten().AsIoBuffer(), DataHash);
+
+ if (InsertResult.New)
+ {
+ TotalNewBytes += CompressedSize;
+ ++NewAttachmentCount;
+ }
+ }
+
+ ZEN_DEBUG("new action: {} in {} attachments. {} new ({} attachments)",
+ zen::NiceBytes(TotalAttachmentBytes),
+ AttachmentCount,
+ zen::NiceBytes(TotalNewBytes),
+ NewAttachmentCount);
+
+ CbObject Output;
+ HttpResponseCode ResponseCode = ExecActionUpstream(Worker, ActionObj, Output);
+
+ if (ResponseCode != HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(ResponseCode);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Output);
+ }
+ break;
+
+ default:
+ break;
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kPost);
+}
+
+HttpFunctionService::~HttpFunctionService()
+{
+}
+
+const char*
+HttpFunctionService::BaseUri() const
+{
+ return "/apply/";
+}
+
+void
+HttpFunctionService::HandleRequest(HttpServerRequest& Request)
+{
+ if (m_Router.HandleRequest(Request) == false)
+ {
+ ZEN_WARN("No route found for {0}", Request.RelativeUri());
+ }
+}
+
+HttpResponseCode
+HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object)
+{
+ const IoHash WorkerId = Worker.Descriptor.GetHash();
+
+ ZEN_INFO("Action {} being processed...", WorkerId.ToHexString());
+
+ auto EnqueueResult = m_UpstreamApply->EnqueueUpstream({.WorkerDescriptor = Worker.Descriptor, .Type = UpstreamApplyType::Simple});
+ if (!EnqueueResult.Success)
+ {
+ ZEN_ERROR("Error enqueuing upstream Action {}", WorkerId.ToHexString());
+ return HttpResponseCode::InternalServerError;
+ }
+
+ CbObjectWriter Writer;
+ Writer.AddHash("worker", WorkerId);
+
+ Object = Writer.Save();
+ return HttpResponseCode::OK;
+}
+
+HttpResponseCode
+HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object)
+{
+ const static IoHash Empty = CbObject().GetHash();
+ auto Status = m_UpstreamApply->GetStatus(WorkerId, Empty);
+ if (!Status.Success)
+ {
+ return HttpResponseCode::NotFound;
+ }
+
+ if (Status.Status.State != UpstreamApplyState::Complete)
+ {
+ return HttpResponseCode::Accepted;
+ }
+
+ GetUpstreamApplyResult& Completed = Status.Status.Result;
+
+ if (!Completed.Success)
+ {
+ ZEN_ERROR("Action {} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}",
+ WorkerId.ToHexString(),
+ Completed.StdOut,
+ Completed.StdErr,
+ Completed.Error.Reason,
+ Completed.Error.ErrorCode);
+
+ if (Completed.Error.ErrorCode == 0)
+ {
+ Completed.Error.ErrorCode = -1;
+ }
+ if (Completed.StdErr.empty() && !Completed.Error.Reason.empty())
+ {
+ Completed.StdErr = Completed.Error.Reason;
+ }
+ }
+ else
+ {
+ ZEN_INFO("Action {} completed with {} files ExitCode={}",
+ WorkerId.ToHexString(),
+ Completed.OutputFiles.size(),
+ Completed.Error.ErrorCode);
+ }
+
+ CbObjectWriter ResultObject;
+
+ ResultObject.AddString("agent"sv, Completed.Agent);
+ ResultObject.AddString("detail"sv, Completed.Detail);
+ ResultObject.AddString("stdout"sv, Completed.StdOut);
+ ResultObject.AddString("stderr"sv, Completed.StdErr);
+ ResultObject.AddInteger("exitcode"sv, Completed.Error.ErrorCode);
+ ResultObject.BeginArray("stats"sv);
+ for (const auto& Timepoint : Completed.Timepoints)
+ {
+ ResultObject.BeginObject();
+ ResultObject.AddString("name"sv, Timepoint.first);
+ ResultObject.AddDateTimeTicks("time"sv, Timepoint.second);
+ ResultObject.EndObject();
+ }
+ ResultObject.EndArray();
+
+ ResultObject.BeginArray("files"sv);
+ for (const auto& File : Completed.OutputFiles)
+ {
+ ResultObject.BeginObject();
+ ResultObject.AddString("name"sv, File.first.string());
+ ResultObject.AddBinary("data"sv, Completed.FileData[File.second]);
+ ResultObject.EndObject();
+ }
+ ResultObject.EndArray();
+
+ Object = ResultObject.Save();
+ return HttpResponseCode::OK;
+}
+
+HttpResponseCode
+HttpFunctionService::ExecActionUpstream(const WorkerDesc& Worker, CbObject Action, CbObject& Object)
+{
+ const IoHash WorkerId = Worker.Descriptor.GetHash();
+ const IoHash ActionId = Action.GetHash();
+
+ Action.MakeOwned();
+
+ ZEN_INFO("Action {}/{} being processed...", WorkerId.ToHexString(), ActionId.ToHexString());
+
+ auto EnqueueResult = m_UpstreamApply->EnqueueUpstream(
+ {.WorkerDescriptor = Worker.Descriptor, .Action = std::move(Action), .Type = UpstreamApplyType::Asset});
+
+ if (!EnqueueResult.Success)
+ {
+ ZEN_ERROR("Error enqueuing upstream Action {}/{}", WorkerId.ToHexString(), ActionId.ToHexString());
+ return HttpResponseCode::InternalServerError;
+ }
+
+ CbObjectWriter Writer;
+ Writer.AddHash("worker", WorkerId);
+ Writer.AddHash("action", ActionId);
+
+ Object = Writer.Save();
+ return HttpResponseCode::OK;
+}
+
+HttpResponseCode
+HttpFunctionService::ExecActionUpstreamResult(const IoHash& WorkerId, const IoHash& ActionId, CbPackage& Package)
+{
+ auto Status = m_UpstreamApply->GetStatus(WorkerId, ActionId);
+ if (!Status.Success)
+ {
+ return HttpResponseCode::NotFound;
+ }
+
+ if (Status.Status.State != UpstreamApplyState::Complete)
+ {
+ return HttpResponseCode::Accepted;
+ }
+
+ GetUpstreamApplyResult& Completed = Status.Status.Result;
+ if (!Completed.Success || Completed.Error.ErrorCode != 0)
+ {
+ ZEN_ERROR("Action {}/{} failed:\n stdout: {}\n stderr: {}\n reason: {}\n errorcode: {}",
+ WorkerId.ToHexString(),
+ ActionId.ToHexString(),
+ Completed.StdOut,
+ Completed.StdErr,
+ Completed.Error.Reason,
+ Completed.Error.ErrorCode);
+
+ return HttpResponseCode::InternalServerError;
+ }
+
+ ZEN_INFO("Action {}/{} completed with {} attachments ({} compressed, {} uncompressed)",
+ WorkerId.ToHexString(),
+ ActionId.ToHexString(),
+ Completed.OutputPackage.GetAttachments().size(),
+ NiceBytes(Completed.TotalAttachmentBytes),
+ NiceBytes(Completed.TotalRawAttachmentBytes));
+
+ Package = std::move(Completed.OutputPackage);
+ return HttpResponseCode::OK;
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/compute/function.h b/src/zenserver/compute/function.h
new file mode 100644
index 000000000..650cee757
--- /dev/null
+++ b/src/zenserver/compute/function.h
@@ -0,0 +1,73 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#if !defined(ZEN_WITH_COMPUTE_SERVICES)
+# define ZEN_WITH_COMPUTE_SERVICES 1
+#endif
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/iohash.h>
+# include <zencore/logging.h>
+# include <zenhttp/httpserver.h>
+
+# include <filesystem>
+# include <unordered_map>
+
+namespace zen {
+
+class CidStore;
+class UpstreamApply;
+class CloudCacheClient;
+class AuthMgr;
+
+struct UpstreamAuthConfig;
+struct CloudCacheClientOptions;
+
+/**
+ * Lambda style compute function service
+ */
+class HttpFunctionService : public HttpService
+{
+public:
+ HttpFunctionService(CidStore& InCidStore,
+ const CloudCacheClientOptions& ComputeOptions,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ AuthMgr& Mgr);
+ ~HttpFunctionService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+
+private:
+ std::thread InitializeThread;
+ spdlog::logger& Log() { return m_Log; }
+ spdlog::logger& m_Log;
+ HttpRequestRouter m_Router;
+ CidStore& m_CidStore;
+ std::unique_ptr<UpstreamApply> m_UpstreamApply;
+
+ struct WorkerDesc
+ {
+ CbObject Descriptor;
+ };
+
+ [[nodiscard]] HttpResponseCode ExecActionUpstream(const WorkerDesc& Worker, CbObject& Object);
+ [[nodiscard]] HttpResponseCode ExecActionUpstreamResult(const IoHash& WorkerId, CbObject& Object);
+
+ [[nodiscard]] HttpResponseCode ExecActionUpstream(const WorkerDesc& Worker, CbObject Action, CbObject& Object);
+ [[nodiscard]] HttpResponseCode ExecActionUpstreamResult(const IoHash& WorkerId, const IoHash& ActionId, CbPackage& Package);
+
+ RwLock m_WorkerLock;
+ std::unordered_map<IoHash, WorkerDesc> m_WorkerMap;
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/config.cpp b/src/zenserver/config.cpp
new file mode 100644
index 000000000..cff93d67b
--- /dev/null
+++ b/src/zenserver/config.cpp
@@ -0,0 +1,902 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "config.h"
+
+#include "diag/logging.h"
+
+#include <zencore/crypto.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+#include <zenhttp/zenhttp.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <zencore/logging.h>
+#include <cxxopts.hpp>
+#include <sol/sol.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# include <conio.h>
+#else
+# include <pwd.h>
+#endif
+
+#if ZEN_PLATFORM_WINDOWS
+
+// Used for getting My Documents for default data directory
+# include <ShlObj.h>
+# pragma comment(lib, "shell32.lib")
+
+std::filesystem::path
+PickDefaultStateDirectory()
+{
+ // Pick sensible default
+ PWSTR programDataDir = nullptr;
+ HRESULT hRes = SHGetKnownFolderPath(FOLDERID_ProgramData, 0, NULL, &programDataDir);
+
+ if (SUCCEEDED(hRes))
+ {
+ std::filesystem::path finalPath(programDataDir);
+ finalPath /= L"Epic\\Zen\\Data";
+ ::CoTaskMemFree(programDataDir);
+
+ return finalPath;
+ }
+
+ return L"";
+}
+
+#else
+
+std::filesystem::path
+PickDefaultStateDirectory()
+{
+ int UserId = getuid();
+ const passwd* Passwd = getpwuid(UserId);
+ return std::filesystem::path(Passwd->pw_dir) / ".zen";
+}
+
+#endif
+
+void
+ValidateOptions(ZenServerOptions& ServerOptions)
+{
+ if (ServerOptions.EncryptionKey.empty() == false)
+ {
+ const auto Key = zen::AesKey256Bit::FromString(ServerOptions.EncryptionKey);
+
+ if (Key.IsValid() == false)
+ {
+ throw cxxopts::OptionParseException("Invalid AES encryption key");
+ }
+ }
+
+ if (ServerOptions.EncryptionIV.empty() == false)
+ {
+ const auto IV = zen::AesIV128Bit::FromString(ServerOptions.EncryptionIV);
+
+ if (IV.IsValid() == false)
+ {
+ throw cxxopts::OptionParseException("Invalid AES initialization vector");
+ }
+ }
+}
+
+UpstreamCachePolicy
+ParseUpstreamCachePolicy(std::string_view Options)
+{
+ if (Options == "readonly")
+ {
+ return UpstreamCachePolicy::Read;
+ }
+ else if (Options == "writeonly")
+ {
+ return UpstreamCachePolicy::Write;
+ }
+ else if (Options == "disabled")
+ {
+ return UpstreamCachePolicy::Disabled;
+ }
+ else
+ {
+ return UpstreamCachePolicy::ReadWrite;
+ }
+}
+
+ZenObjectStoreConfig
+ParseBucketConfigs(std::span<std::string> Buckets)
+{
+ using namespace std::literals;
+
+ ZenObjectStoreConfig Cfg;
+
+ // split bucket args in the form of "{BucketName};{LocalPath}"
+ for (std::string_view Bucket : Buckets)
+ {
+ ZenObjectStoreConfig::BucketConfig NewBucket;
+
+ if (auto Idx = Bucket.find_first_of(";"); Idx != std::string_view::npos)
+ {
+ NewBucket.Name = Bucket.substr(0, Idx);
+ NewBucket.Directory = Bucket.substr(Idx + 1);
+ }
+ else
+ {
+ NewBucket.Name = Bucket;
+ }
+
+ Cfg.Buckets.push_back(std::move(NewBucket));
+ }
+
+ return Cfg;
+}
+
+void ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions);
+
+void
+ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions)
+{
+#if ZEN_WITH_HTTPSYS
+ const char* DefaultHttp = "httpsys";
+#else
+ const char* DefaultHttp = "asio";
+#endif
+
+ // Note to those adding future options; std::filesystem::path-type options
+ // must be read into a std::string first. As of cxxopts-3.0.0 it uses a >>
+ // stream operator to convert argv value into the options type. std::fs::path
+ // expects paths in streams to be quoted but argv paths are unquoted. By
+ // going into a std::string first, paths with whitespace parse correctly.
+ std::string DataDir;
+ std::string ContentDir;
+ std::string AbsLogFile;
+ std::string ConfigFile;
+
+ cxxopts::Options options("zenserver", "Zen Server");
+ options.add_options()("dedicated",
+ "Enable dedicated server mode",
+ cxxopts::value<bool>(ServerOptions.IsDedicated)->default_value("false"));
+ options.add_options()("d, debug", "Enable debugging", cxxopts::value<bool>(ServerOptions.IsDebug)->default_value("false"));
+ options.add_options()("help", "Show command line help");
+ options.add_options()("t, test", "Enable test mode", cxxopts::value<bool>(ServerOptions.IsTest)->default_value("false"));
+ options.add_options()("log-id", "Specify id for adding context to log output", cxxopts::value<std::string>(ServerOptions.LogId));
+ options.add_options()("data-dir", "Specify persistence root", cxxopts::value<std::string>(DataDir));
+ options.add_options()("content-dir", "Frontend content directory", cxxopts::value<std::string>(ContentDir));
+ options.add_options()("abslog", "Path to log file", cxxopts::value<std::string>(AbsLogFile));
+ options.add_options()("config", "Path to Lua config file", cxxopts::value<std::string>(ConfigFile));
+ options.add_options()("no-sentry",
+ "Disable Sentry crash handler",
+ cxxopts::value<bool>(ServerOptions.NoSentry)->default_value("false"));
+
+ options.add_option("security",
+ "",
+ "encryption-aes-key",
+ "256 bit AES encryption key",
+ cxxopts::value<std::string>(ServerOptions.EncryptionKey),
+ "");
+
+ options.add_option("security",
+ "",
+ "encryption-aes-iv",
+ "128 bit AES encryption initialization vector",
+ cxxopts::value<std::string>(ServerOptions.EncryptionIV),
+ "");
+
+ std::string OpenIdProviderName;
+ options.add_option("security",
+ "",
+ "openid-provider-name",
+ "Open ID provider name",
+ cxxopts::value<std::string>(OpenIdProviderName),
+ "Default");
+
+ std::string OpenIdProviderUrl;
+ options.add_option("security", "", "openid-provider-url", "Open ID provider URL", cxxopts::value<std::string>(OpenIdProviderUrl), "");
+
+ std::string OpenIdClientId;
+ options.add_option("security", "", "openid-client-id", "Open ID client ID", cxxopts::value<std::string>(OpenIdClientId), "");
+
+ options
+ .add_option("lifetime", "", "owner-pid", "Specify owning process id", cxxopts::value<int>(ServerOptions.OwnerPid), "<identifier>");
+ options.add_option("lifetime",
+ "",
+ "child-id",
+ "Specify id which can be used to signal parent",
+ cxxopts::value<std::string>(ServerOptions.ChildId),
+ "<identifier>");
+
+#if ZEN_PLATFORM_WINDOWS
+ options.add_option("lifetime",
+ "",
+ "install",
+ "Install zenserver as a Windows service",
+ cxxopts::value<bool>(ServerOptions.InstallService),
+ "");
+ options.add_option("lifetime",
+ "",
+ "uninstall",
+ "Uninstall zenserver as a Windows service",
+ cxxopts::value<bool>(ServerOptions.UninstallService),
+ "");
+#endif
+
+ options.add_option("network",
+ "",
+ "http",
+ "Select HTTP server implementation (asio|httpsys|null)",
+ cxxopts::value<std::string>(ServerOptions.HttpServerClass)->default_value(DefaultHttp),
+ "<http class>");
+
+ options.add_option("network",
+ "p",
+ "port",
+ "Select HTTP port",
+ cxxopts::value<int>(ServerOptions.BasePort)->default_value("1337"),
+ "<port number>");
+
+ options.add_option("network",
+ "",
+ "websocket-port",
+ "Websocket server port",
+ cxxopts::value<int>(ServerOptions.WebSocketPort)->default_value("0"),
+ "<port number>");
+
+ options.add_option("network",
+ "",
+ "websocket-threads",
+ "Number of websocket I/O thread(s) (0 == hardware concurrency)",
+ cxxopts::value<int>(ServerOptions.WebSocketThreads)->default_value("0"),
+ "");
+
+#if ZEN_WITH_TRACE
+ options.add_option("ue-trace",
+ "",
+ "tracehost",
+ "Hostname to send the trace to",
+ cxxopts::value<std::string>(ServerOptions.TraceHost)->default_value(""),
+ "");
+
+ options.add_option("ue-trace",
+ "",
+ "tracefile",
+ "Path to write a trace to",
+ cxxopts::value<std::string>(ServerOptions.TraceFile)->default_value(""),
+ "");
+#endif // ZEN_WITH_TRACE
+
+ options.add_option("diagnostics",
+ "",
+ "crash",
+ "Simulate a crash",
+ cxxopts::value<bool>(ServerOptions.ShouldCrash)->default_value("false"),
+ "");
+
+ std::string UpstreamCachePolicyOptions;
+ options.add_option("cache",
+ "",
+ "upstream-cache-policy",
+ "",
+ cxxopts::value<std::string>(UpstreamCachePolicyOptions)->default_value(""),
+ "Upstream cache policy (readwrite|readonly|writeonly|disabled)");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-url",
+ "URL to a Jupiter instance",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Url)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-oauth-url",
+ "URL to the OAuth provier",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-oauth-clientid",
+ "The OAuth client ID",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-oauth-clientsecret",
+ "The OAuth client secret",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-openid-provider",
+ "Name of a registered Open ID provider",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-token",
+ "A static authentication token",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-namespace",
+ "The Common Blob Store API namespace",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-jupiter-namespace-ddc",
+ "The lecacy DDC namespace",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace)->default_value(""),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-zen-url",
+ "URL to remote Zen server. Use a comma separated list to choose the one with the best latency.",
+ cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Urls),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-zen-dns",
+ "DNS that resolves to one or more Zen server instance(s)",
+ cxxopts::value<std::vector<std::string>>(ServerOptions.UpstreamCacheConfig.ZenConfig.Dns),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-thread-count",
+ "Number of threads used for upstream procsssing",
+ cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.UpstreamThreadCount)->default_value("4"),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-connect-timeout-ms",
+ "Connect timeout in millisecond(s). Default 5000 ms.",
+ cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.ConnectTimeoutMilliseconds)->default_value("5000"),
+ "");
+
+ options.add_option("cache",
+ "",
+ "upstream-timeout-ms",
+ "Timeout in millisecond(s). Default 0 ms",
+ cxxopts::value<int32_t>(ServerOptions.UpstreamCacheConfig.TimeoutMilliseconds)->default_value("0"),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-url",
+ "URL to a Horde instance.",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Url)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-oauth-url",
+ "URL to the OAuth provier",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-oauth-clientid",
+ "The OAuth client ID",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-oauth-clientsecret",
+ "The OAuth client secret",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-openid-provider",
+ "Name of a registered Open ID provider",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-token",
+ "A static authentication token",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-url",
+ "URL to a Horde Storage instance.",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-oauth-url",
+ "URL to the OAuth provier",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-oauth-clientid",
+ "The OAuth client ID",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId)->default_value(""),
+ "");
+
+ options.add_option(
+ "compute",
+ "",
+ "upstream-horde-storage-oauth-clientsecret",
+ "The OAuth client secret",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-openid-provider",
+ "Name of a registered Open ID provider",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-storage-token",
+ "A static authentication token",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-cluster",
+ "The Horde compute cluster id",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster)->default_value(""),
+ "");
+
+ options.add_option("compute",
+ "",
+ "upstream-horde-namespace",
+ "The Jupiter namespace to use with Horde compute",
+ cxxopts::value<std::string>(ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace)->default_value(""),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-enabled",
+ "Whether garbage collection is enabled or not.",
+ cxxopts::value<bool>(ServerOptions.GcConfig.Enabled)->default_value("true"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-small-objects",
+ "Whether garbage collection of small objects is enabled or not.",
+ cxxopts::value<bool>(ServerOptions.GcConfig.CollectSmallObjects)->default_value("true"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-interval-seconds",
+ "Garbage collection interval in seconds. Default set to 3600 (1 hour).",
+ cxxopts::value<int32_t>(ServerOptions.GcConfig.IntervalSeconds)->default_value("3600"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-cache-duration-seconds",
+ "Max duration in seconds before Z$ entries get evicted. Default set to 1209600 (2 weeks)",
+ cxxopts::value<int32_t>(ServerOptions.GcConfig.Cache.MaxDurationSeconds)->default_value("1209600"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "disk-reserve-size",
+ "Size of gc disk reserve in bytes. Default set to 268435456 (256 Mb).",
+ cxxopts::value<uint64_t>(ServerOptions.GcConfig.DiskReserveSize)->default_value("268435456"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-monitor-interval-seconds",
+ "Garbage collection monitoring interval in seconds. Default set to 30 (30 seconds)",
+ cxxopts::value<int32_t>(ServerOptions.GcConfig.MonitorIntervalSeconds)->default_value("30"),
+ "");
+
+ options.add_option("gc",
+ "",
+ "gc-disksize-softlimit",
+ "Garbage collection disk usage soft limit. Default set to 0 (Off).",
+ cxxopts::value<uint64_t>(ServerOptions.GcConfig.Cache.DiskSizeSoftLimit)->default_value("0"),
+ "");
+
+ options.add_option("objectstore",
+ "",
+ "objectstore-enabled",
+ "Whether the object store is enabled or not.",
+ cxxopts::value<bool>(ServerOptions.ObjectStoreEnabled)->default_value("false"),
+ "");
+
+ std::vector<std::string> BucketConfigs;
+ options.add_option("objectstore",
+ "",
+ "objectstore-bucket",
+ "Object store bucket mappings.",
+ cxxopts::value<std::vector<std::string>>(BucketConfigs),
+ "");
+
+ try
+ {
+ auto result = options.parse(argc, argv);
+
+ if (result.count("help"))
+ {
+ zen::logging::ConsoleLog().info("{}", options.help());
+#if ZEN_PLATFORM_WINDOWS
+ zen::logging::ConsoleLog().info("Press any key to exit!");
+ _getch();
+#else
+ // Assume the user's in a terminal on all other platforms and that
+ // they'll use less/more/etc. if need be.
+#endif
+ exit(0);
+ }
+
+ auto MakeSafePath = [](const std::string& Path) {
+#if ZEN_PLATFORM_WINDOWS
+ if (Path.empty())
+ {
+ return Path;
+ }
+
+ std::string FixedPath = Path;
+ std::replace(FixedPath.begin(), FixedPath.end(), '/', '\\');
+ if (!FixedPath.starts_with("\\\\?\\"))
+ {
+ FixedPath.insert(0, "\\\\?\\");
+ }
+ return FixedPath;
+#else
+ return Path;
+#endif
+ };
+
+ ServerOptions.DataDir = MakeSafePath(DataDir);
+ ServerOptions.ContentDir = MakeSafePath(ContentDir);
+ ServerOptions.AbsLogFile = MakeSafePath(AbsLogFile);
+ ServerOptions.ConfigFile = MakeSafePath(ConfigFile);
+ ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(UpstreamCachePolicyOptions);
+
+ if (OpenIdProviderUrl.empty() == false)
+ {
+ if (OpenIdClientId.empty())
+ {
+ throw cxxopts::OptionParseException("Invalid OpenID client ID");
+ }
+
+ ServerOptions.AuthConfig.OpenIdProviders.push_back(
+ {.Name = OpenIdProviderName, .Url = OpenIdProviderUrl, .ClientId = OpenIdClientId});
+ }
+
+ ServerOptions.ObjectStoreConfig = ParseBucketConfigs(BucketConfigs);
+
+ if (!ServerOptions.ConfigFile.empty())
+ {
+ ParseConfigFile(ServerOptions.ConfigFile, ServerOptions);
+ }
+ else
+ {
+ ParseConfigFile(ServerOptions.DataDir / "zen_cfg.lua", ServerOptions);
+ }
+
+ ValidateOptions(ServerOptions);
+ }
+ catch (cxxopts::OptionParseException& e)
+ {
+ zen::logging::ConsoleLog().error("Error parsing zenserver arguments: {}\n\n{}", e.what(), options.help());
+
+ throw;
+ }
+
+ if (ServerOptions.DataDir.empty())
+ {
+ ServerOptions.DataDir = PickDefaultStateDirectory();
+ }
+
+ if (ServerOptions.AbsLogFile.empty())
+ {
+ ServerOptions.AbsLogFile = ServerOptions.DataDir / "logs" / "zenserver.log";
+ }
+}
+
+void
+ParseConfigFile(const std::filesystem::path& Path, ZenServerOptions& ServerOptions)
+{
+ zen::IoBuffer LuaScript = zen::IoBufferBuilder::MakeFromFile(Path);
+
+ if (LuaScript)
+ {
+ sol::state lua;
+
+ lua.open_libraries(sol::lib::base);
+
+ lua.set_function("getenv", [&](const std::string env) -> sol::object {
+#if ZEN_PLATFORM_WINDOWS
+ std::wstring EnvVarValue;
+ size_t RequiredSize = 0;
+ std::wstring EnvWide = zen::Utf8ToWide(env);
+ _wgetenv_s(&RequiredSize, nullptr, 0, EnvWide.c_str());
+
+ if (RequiredSize == 0)
+ return sol::make_object(lua, sol::lua_nil);
+
+ EnvVarValue.resize(RequiredSize);
+ _wgetenv_s(&RequiredSize, EnvVarValue.data(), RequiredSize, EnvWide.c_str());
+ return sol::make_object(lua, zen::WideToUtf8(EnvVarValue.c_str()));
+#else
+ ZEN_UNUSED(env);
+ return sol::make_object(lua, sol::lua_nil);
+#endif
+ });
+
+ try
+ {
+ sol::load_result config = lua.load(std::string_view((const char*)LuaScript.Data(), LuaScript.Size()), "zen_cfg");
+
+ if (!config.valid())
+ {
+ sol::error err = config;
+
+ std::string ErrorString = sol::to_string(config.status());
+
+ throw std::runtime_error(fmt::format("{} error: {}", ErrorString, err.what()));
+ }
+
+ config();
+ }
+ catch (std::exception& e)
+ {
+ throw std::runtime_error(fmt::format("failed to load config script ('{}'): {}", Path, e.what()).c_str());
+ }
+
+ if (sol::optional<sol::table> ServerConfig = lua["server"])
+ {
+ if (ServerOptions.DataDir.empty())
+ {
+ if (sol::optional<std::string> Opt = ServerConfig.value()["datadir"])
+ {
+ ServerOptions.DataDir = Opt.value();
+ }
+ }
+
+ if (ServerOptions.ContentDir.empty())
+ {
+ if (sol::optional<std::string> Opt = ServerConfig.value()["contentdir"])
+ {
+ ServerOptions.ContentDir = Opt.value();
+ }
+ }
+
+ if (ServerOptions.AbsLogFile.empty())
+ {
+ if (sol::optional<std::string> Opt = ServerConfig.value()["abslog"])
+ {
+ ServerOptions.AbsLogFile = Opt.value();
+ }
+ }
+
+ ServerOptions.IsDebug = ServerConfig->get_or("debug", ServerOptions.IsDebug);
+ }
+
+ if (sol::optional<sol::table> NetworkConfig = lua["network"])
+ {
+ if (sol::optional<std::string> Opt = NetworkConfig.value()["httpserverclass"])
+ {
+ ServerOptions.HttpServerClass = Opt.value();
+ }
+
+ ServerOptions.BasePort = NetworkConfig->get_or<int>("port", ServerOptions.BasePort);
+ }
+
+ auto UpdateStringValueFromConfig = [](const sol::table& Table, std::string_view Key, std::string& OutValue) {
+ // Update the specified config value unless it has been set, i.e. from command line
+ if (auto MaybeValue = Table.get<sol::optional<std::string>>(Key); MaybeValue.has_value() && OutValue.empty())
+ {
+ OutValue = MaybeValue.value();
+ }
+ };
+
+ if (sol::optional<sol::table> StructuredCacheConfig = lua["cache"])
+ {
+ ServerOptions.StructuredCacheEnabled = StructuredCacheConfig->get_or("enable", ServerOptions.StructuredCacheEnabled);
+
+ if (auto UpstreamConfig = StructuredCacheConfig->get<sol::optional<sol::table>>("upstream"))
+ {
+ std::string Policy = UpstreamConfig->get_or("policy", std::string());
+ ServerOptions.UpstreamCacheConfig.CachePolicy = ParseUpstreamCachePolicy(Policy);
+ ServerOptions.UpstreamCacheConfig.UpstreamThreadCount =
+ UpstreamConfig->get_or("upstreamthreadcount", ServerOptions.UpstreamCacheConfig.UpstreamThreadCount);
+
+ if (auto JupiterConfig = UpstreamConfig->get<sol::optional<sol::table>>("jupiter"))
+ {
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("name"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.Name);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("url"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.Url);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("oauthprovider"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthUrl);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("oauthclientid"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientId);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("oauthclientsecret"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OAuthClientSecret);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("openidprovider"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.OpenIdProvider);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("token"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.AccessToken);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("namespace"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.Namespace);
+ UpdateStringValueFromConfig(JupiterConfig.value(),
+ std::string_view("ddcnamespace"),
+ ServerOptions.UpstreamCacheConfig.JupiterConfig.DdcNamespace);
+ };
+
+ if (auto ZenConfig = UpstreamConfig->get<sol::optional<sol::table>>("zen"))
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Name = ZenConfig.value().get_or("name", std::string("Zen"));
+
+ if (auto Url = ZenConfig.value().get<sol::optional<std::string>>("url"))
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Url.value());
+ }
+ else if (auto Urls = ZenConfig.value().get<sol::optional<sol::table>>("url"))
+ {
+ for (const auto& Kv : Urls.value())
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Urls.push_back(Kv.second.as<std::string>());
+ }
+ }
+
+ if (auto Dns = ZenConfig.value().get<sol::optional<std::string>>("dns"))
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Dns.value());
+ }
+ else if (auto DnsArray = ZenConfig.value().get<sol::optional<sol::table>>("dns"))
+ {
+ for (const auto& Kv : DnsArray.value())
+ {
+ ServerOptions.UpstreamCacheConfig.ZenConfig.Dns.push_back(Kv.second.as<std::string>());
+ }
+ }
+ }
+ }
+ }
+
+ if (sol::optional<sol::table> ExecConfig = lua["exec"])
+ {
+ ServerOptions.ExecServiceEnabled = ExecConfig->get_or("enable", ServerOptions.ExecServiceEnabled);
+ }
+
+ if (sol::optional<sol::table> ComputeConfig = lua["compute"])
+ {
+ ServerOptions.ComputeServiceEnabled = ComputeConfig->get_or("enable", ServerOptions.ComputeServiceEnabled);
+
+ if (auto UpstreamConfig = ComputeConfig->get<sol::optional<sol::table>>("upstream"))
+ {
+ if (auto HordeConfig = UpstreamConfig->get<sol::optional<sol::table>>("horde"))
+ {
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("name"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Name);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("url"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Url);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("oauthprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthUrl);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("oauthclientid"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientId);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("oauthclientsecret"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OAuthClientSecret);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("openidprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.OpenIdProvider);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("token"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.AccessToken);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("cluster"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Cluster);
+ UpdateStringValueFromConfig(HordeConfig.value(),
+ std::string_view("namespace"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.Namespace);
+ };
+
+ if (auto StorageConfig = UpstreamConfig->get<sol::optional<sol::table>>("storage"))
+ {
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("url"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageUrl);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("oauthprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthUrl);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("oauthclientid"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientId);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("oauthclientsecret"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOAuthClientSecret);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("openidprovider"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageOpenIdProvider);
+ UpdateStringValueFromConfig(StorageConfig.value(),
+ std::string_view("token"),
+ ServerOptions.UpstreamCacheConfig.HordeConfig.StorageAccessToken);
+ };
+ }
+ }
+
+ if (sol::optional<sol::table> GcConfig = lua["gc"])
+ {
+ ServerOptions.GcConfig.MonitorIntervalSeconds = GcConfig.value().get_or("monitorintervalseconds", 30);
+ ServerOptions.GcConfig.IntervalSeconds = GcConfig.value().get_or("intervalseconds", 0);
+ ServerOptions.GcConfig.DiskReserveSize = GcConfig.value().get_or("diskreservesize", uint64_t(1u << 28));
+
+ if (sol::optional<sol::table> CacheGcConfig = GcConfig.value()["cache"])
+ {
+ ServerOptions.GcConfig.Cache.MaxDurationSeconds = CacheGcConfig.value().get_or("maxdurationseconds", int32_t(0));
+ ServerOptions.GcConfig.Cache.DiskSizeLimit = CacheGcConfig.value().get_or("disksizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cache.MemorySizeLimit = CacheGcConfig.value().get_or("memorysizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cache.DiskSizeSoftLimit = CacheGcConfig.value().get_or("disksizesoftlimit", 0);
+ }
+
+ if (sol::optional<sol::table> CasGcConfig = GcConfig.value()["cas"])
+ {
+ ServerOptions.GcConfig.Cas.LargeStrategySizeLimit = CasGcConfig.value().get_or("largestrategysizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cas.SmallStrategySizeLimit = CasGcConfig.value().get_or("smallstrategysizelimit", ~uint64_t(0));
+ ServerOptions.GcConfig.Cas.TinyStrategySizeLimit = CasGcConfig.value().get_or("tinystrategysizelimit", ~uint64_t(0));
+ }
+ }
+
+ if (sol::optional<sol::table> SecurityConfig = lua["security"])
+ {
+ if (sol::optional<sol::table> OpenIdProviders = SecurityConfig.value()["openidproviders"])
+ {
+ for (const auto& Kv : OpenIdProviders.value())
+ {
+ if (sol::optional<sol::table> OpenIdProvider = Kv.second.as<sol::table>())
+ {
+ std::string Name = OpenIdProvider.value().get_or("name", std::string("Default"));
+ std::string Url = OpenIdProvider.value().get_or("url", std::string());
+ std::string ClientId = OpenIdProvider.value().get_or("clientid", std::string());
+
+ ServerOptions.AuthConfig.OpenIdProviders.push_back(
+ {.Name = std::move(Name), .Url = std::move(Url), .ClientId = std::move(ClientId)});
+ }
+ }
+ }
+
+ ServerOptions.EncryptionKey = SecurityConfig.value().get_or("encryptionaeskey", std::string());
+ ServerOptions.EncryptionIV = SecurityConfig.value().get_or("encryptionaesiv", std::string());
+ }
+ }
+}
diff --git a/src/zenserver/config.h b/src/zenserver/config.h
new file mode 100644
index 000000000..8a5c6de4e
--- /dev/null
+++ b/src/zenserver/config.h
@@ -0,0 +1,158 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+#include <filesystem>
+#include <string>
+#include <vector>
+
+struct ZenUpstreamJupiterConfig
+{
+ std::string Name;
+ std::string Url;
+ std::string OAuthUrl;
+ std::string OAuthClientId;
+ std::string OAuthClientSecret;
+ std::string OpenIdProvider;
+ std::string AccessToken;
+ std::string Namespace;
+ std::string DdcNamespace;
+};
+
+struct ZenUpstreamHordeConfig
+{
+ std::string Name;
+ std::string Url;
+ std::string OAuthUrl;
+ std::string OAuthClientId;
+ std::string OAuthClientSecret;
+ std::string OpenIdProvider;
+ std::string AccessToken;
+
+ std::string StorageUrl;
+ std::string StorageOAuthUrl;
+ std::string StorageOAuthClientId;
+ std::string StorageOAuthClientSecret;
+ std::string StorageOpenIdProvider;
+ std::string StorageAccessToken;
+
+ std::string Cluster;
+ std::string Namespace;
+};
+
+struct ZenUpstreamZenConfig
+{
+ std::string Name;
+ std::vector<std::string> Urls;
+ std::vector<std::string> Dns;
+};
+
+enum class UpstreamCachePolicy : uint8_t
+{
+ Disabled = 0,
+ Read = 1 << 0,
+ Write = 1 << 1,
+ ReadWrite = Read | Write
+};
+
+struct ZenUpstreamCacheConfig
+{
+ ZenUpstreamJupiterConfig JupiterConfig;
+ ZenUpstreamHordeConfig HordeConfig;
+ ZenUpstreamZenConfig ZenConfig;
+ int32_t UpstreamThreadCount = 4;
+ int32_t ConnectTimeoutMilliseconds = 5000;
+ int32_t TimeoutMilliseconds = 0;
+ UpstreamCachePolicy CachePolicy = UpstreamCachePolicy::ReadWrite;
+};
+
+struct ZenCacheEvictionPolicy
+{
+ uint64_t DiskSizeLimit = ~uint64_t(0);
+ uint64_t MemorySizeLimit = 1024 * 1024 * 1024;
+ int32_t MaxDurationSeconds = 24 * 60 * 60;
+ uint64_t DiskSizeSoftLimit = 0;
+ bool Enabled = true;
+};
+
+struct ZenCasEvictionPolicy
+{
+ uint64_t LargeStrategySizeLimit = ~uint64_t(0);
+ uint64_t SmallStrategySizeLimit = ~uint64_t(0);
+ uint64_t TinyStrategySizeLimit = ~uint64_t(0);
+ bool Enabled = true;
+};
+
+struct ZenGcConfig
+{
+ ZenCasEvictionPolicy Cas;
+ ZenCacheEvictionPolicy Cache;
+ int32_t MonitorIntervalSeconds = 30;
+ int32_t IntervalSeconds = 0;
+ bool CollectSmallObjects = true;
+ bool Enabled = true;
+ uint64_t DiskReserveSize = 1ul << 28;
+};
+
+struct ZenOpenIdProviderConfig
+{
+ std::string Name;
+ std::string Url;
+ std::string ClientId;
+};
+
+struct ZenAuthConfig
+{
+ std::vector<ZenOpenIdProviderConfig> OpenIdProviders;
+};
+
+struct ZenObjectStoreConfig
+{
+ struct BucketConfig
+ {
+ std::string Name;
+ std::filesystem::path Directory;
+ };
+
+ std::vector<BucketConfig> Buckets;
+};
+
+struct ZenServerOptions
+{
+ ZenUpstreamCacheConfig UpstreamCacheConfig;
+ ZenGcConfig GcConfig;
+ ZenAuthConfig AuthConfig;
+ ZenObjectStoreConfig ObjectStoreConfig;
+ std::filesystem::path DataDir; // Root directory for state (used for testing)
+ std::filesystem::path ContentDir; // Root directory for serving frontend content (experimental)
+ std::filesystem::path AbsLogFile; // Absolute path to main log file
+ std::filesystem::path ConfigFile; // Path to Lua config file
+ std::string ChildId; // Id assigned by parent process (used for lifetime management)
+ std::string LogId; // Id for tagging log output
+ std::string HttpServerClass; // Choice of HTTP server implementation
+ std::string EncryptionKey; // 256 bit AES encryption key
+ std::string EncryptionIV; // 128 bit AES initialization vector
+ int BasePort = 1337; // Service listen port (used for both UDP and TCP)
+ int OwnerPid = 0; // Parent process id (zero for standalone)
+ int WebSocketPort = 0; // Web socket port (Zero = disabled)
+ int WebSocketThreads = 0;
+ bool InstallService = false; // Flag used to initiate service install (temporary)
+ bool UninstallService = false; // Flag used to initiate service uninstall (temporary)
+ bool IsDebug = false;
+ bool IsTest = false;
+ bool IsDedicated = false; // Indicates a dedicated/shared instance, with larger resource requirements
+ bool StructuredCacheEnabled = true;
+ bool ExecServiceEnabled = true;
+ bool ComputeServiceEnabled = true;
+ bool ShouldCrash = false; // Option for testing crash handling
+ bool IsFirstRun = false;
+ bool NoSentry = false;
+ bool ObjectStoreEnabled = false;
+#if ZEN_WITH_TRACE
+ std::string TraceHost; // Host name or IP address to send trace data to
+ std::string TraceFile; // Path of a file to write a trace
+#endif
+};
+
+void ParseCliOptions(int argc, char* argv[], ZenServerOptions& ServerOptions);
diff --git a/src/zenserver/diag/diagsvcs.cpp b/src/zenserver/diag/diagsvcs.cpp
new file mode 100644
index 000000000..29ad5c3dd
--- /dev/null
+++ b/src/zenserver/diag/diagsvcs.cpp
@@ -0,0 +1,127 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "diagsvcs.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/config.h>
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <fstream>
+#include <sstream>
+
+#include <json11.hpp>
+
+namespace zen {
+
+using namespace std::literals;
+
+bool
+ReadFile(const std::string& Path, StringBuilderBase& Out)
+{
+ try
+ {
+ constexpr auto ReadSize = std::size_t{4096};
+ auto FileStream = std::ifstream{Path};
+
+ std::string Buf(ReadSize, '\0');
+ while (FileStream.read(&Buf[0], ReadSize))
+ {
+ Out.Append(std::string_view(&Buf[0], FileStream.gcount()));
+ }
+ Out.Append(std::string_view(&Buf[0], FileStream.gcount()));
+
+ return true;
+ }
+ catch (std::exception&)
+ {
+ Out.Reset();
+ return false;
+ }
+}
+
+HttpHealthService::HttpHealthService()
+{
+ m_Router.RegisterRoute(
+ "",
+ [](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "info",
+ [this](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+
+ CbObjectWriter Writer;
+ Writer << "DataRoot"sv << m_HealthInfo.DataRoot.string();
+ Writer << "AbsLogPath"sv << m_HealthInfo.AbsLogPath.string();
+ Writer << "BuildVersion"sv << m_HealthInfo.BuildVersion;
+ Writer << "HttpServerClass"sv << m_HealthInfo.HttpServerClass;
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, Writer.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "log",
+ [this](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+
+ zen::Log().flush();
+
+ std::filesystem::path Path =
+ m_HealthInfo.AbsLogPath.empty() ? m_HealthInfo.DataRoot / "logs/zenserver.log" : m_HealthInfo.AbsLogPath;
+
+ ExtendableStringBuilder<4096> Sb;
+ if (ReadFile(Path.string(), Sb) && Sb.Size() > 0)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, Sb.ToView());
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ },
+ HttpVerb::kGet);
+ m_Router.RegisterRoute(
+ "version",
+ [this](HttpRouterRequest& RoutedReq) {
+ HttpServerRequest& HttpReq = RoutedReq.ServerRequest();
+ if (HttpReq.GetQueryParams().GetValue("detailed") == "true")
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION_BUILD_STRING_FULL);
+ }
+ else
+ {
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, ZEN_CFG_VERSION);
+ }
+ },
+ HttpVerb::kGet);
+}
+
+void
+HttpHealthService::SetHealthInfo(HealthServiceInfo&& Info)
+{
+ m_HealthInfo = std::move(Info);
+}
+
+const char*
+HttpHealthService::BaseUri() const
+{
+ return "/health/";
+}
+
+void
+HttpHealthService::HandleRequest(HttpServerRequest& Request)
+{
+ if (!m_Router.HandleRequest(Request))
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"OK!"sv);
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/diag/diagsvcs.h b/src/zenserver/diag/diagsvcs.h
new file mode 100644
index 000000000..bd03f8023
--- /dev/null
+++ b/src/zenserver/diag/diagsvcs.h
@@ -0,0 +1,111 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+#include <zenhttp/httpserver.h>
+
+#include <filesystem>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+class HttpTestService : public HttpService
+{
+ uint32_t LogPoint = 0;
+
+public:
+ HttpTestService() {}
+ ~HttpTestService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+
+ virtual void HandleRequest(HttpServerRequest& Request) override
+ {
+ using namespace std::literals;
+
+ auto Uri = Request.RelativeUri();
+
+ if (Uri == "hello"sv)
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, u8"hello world!"sv);
+
+ // OutputLogMessageInternal(&LogPoint, 0, 0);
+ }
+ else if (Uri == "1K"sv)
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1k);
+ }
+ else if (Uri == "1M"sv)
+ {
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, m_1m);
+ }
+ else if (Uri == "1M_1k"sv)
+ {
+ std::vector<IoBuffer> Buffers;
+ Buffers.reserve(1024);
+
+ for (int i = 0; i < 1024; ++i)
+ {
+ Buffers.push_back(m_1k);
+ }
+
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers);
+ }
+ else if (Uri == "1G"sv)
+ {
+ std::vector<IoBuffer> Buffers;
+ Buffers.reserve(1024);
+
+ for (int i = 0; i < 1024; ++i)
+ {
+ Buffers.push_back(m_1m);
+ }
+
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers);
+ }
+ else if (Uri == "1G_1k"sv)
+ {
+ std::vector<IoBuffer> Buffers;
+ Buffers.reserve(1024 * 1024);
+
+ for (int i = 0; i < 1024 * 1024; ++i)
+ {
+ Buffers.push_back(m_1k);
+ }
+
+ Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Buffers);
+ }
+ }
+
+private:
+ IoBuffer m_1m{1024 * 1024};
+ IoBuffer m_1k{m_1m, 0u, 1024};
+};
+
+struct HealthServiceInfo
+{
+ std::filesystem::path DataRoot;
+ std::filesystem::path AbsLogPath;
+ std::string HttpServerClass;
+ std::string BuildVersion;
+};
+
+class HttpHealthService : public HttpService
+{
+public:
+ HttpHealthService();
+ ~HttpHealthService() = default;
+
+ void SetHealthInfo(HealthServiceInfo&& Info);
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override final;
+
+private:
+ HttpRequestRouter m_Router;
+ HealthServiceInfo m_HealthInfo;
+};
+
+} // namespace zen
diff --git a/src/zenserver/diag/formatters.h b/src/zenserver/diag/formatters.h
new file mode 100644
index 000000000..759df58d3
--- /dev/null
+++ b/src/zenserver/diag/formatters.h
@@ -0,0 +1,71 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/iobuffer.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+template<>
+struct fmt::formatter<cpr::Response>
+{
+ constexpr auto parse(format_parse_context& Ctx) -> decltype(Ctx.begin()) { return Ctx.end(); }
+
+ template<typename FormatContext>
+ auto format(const cpr::Response& Response, FormatContext& Ctx) -> decltype(Ctx.out())
+ {
+ using namespace std::literals;
+
+ if (Response.status_code == 200 || Response.status_code == 201)
+ {
+ return fmt::format_to(Ctx.out(),
+ "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s",
+ Response.url.str(),
+ Response.status_code,
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ Response.elapsed);
+ }
+ else
+ {
+ const auto It = Response.header.find("Content-Type");
+ const std::string_view ContentType = It != Response.header.end() ? It->second : "<None>"sv;
+
+ if (ContentType == "application/x-ue-cb"sv)
+ {
+ zen::IoBuffer Body(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ zen::CbObjectView Obj(Body.Data());
+ zen::ExtendableStringBuilder<256> Sb;
+ std::string_view Json = Obj.ToJson(Sb).ToView();
+
+ return fmt::format_to(Ctx.out(),
+ "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Response: '{}', Reason: '{}'",
+ Response.url.str(),
+ Response.status_code,
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ Response.elapsed,
+ Json,
+ Response.reason);
+ }
+ else
+ {
+ return fmt::format_to(Ctx.out(),
+ "Url: {}, Status: {}, Bytes: {}/{} (Up/Down), Elapsed: {}s, Reponse: '{}', Reason: '{}'",
+ Response.url.str(),
+ Response.status_code,
+ Response.uploaded_bytes,
+ Response.downloaded_bytes,
+ Response.elapsed,
+ Response.text,
+ Response.reason);
+ }
+ }
+ }
+};
diff --git a/src/zenserver/diag/logging.cpp b/src/zenserver/diag/logging.cpp
new file mode 100644
index 000000000..24c7572f4
--- /dev/null
+++ b/src/zenserver/diag/logging.cpp
@@ -0,0 +1,467 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "logging.h"
+
+#include "config.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <spdlog/async.h>
+#include <spdlog/async_logger.h>
+#include <spdlog/pattern_formatter.h>
+#include <spdlog/sinks/ansicolor_sink.h>
+#include <spdlog/sinks/basic_file_sink.h>
+#include <spdlog/sinks/daily_file_sink.h>
+#include <spdlog/sinks/msvc_sink.h>
+#include <spdlog/sinks/rotating_file_sink.h>
+#include <spdlog/sinks/stdout_color_sinks.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <zencore/compactbinary.h>
+#include <zencore/filesystem.h>
+#include <zencore/string.h>
+
+#include <chrono>
+#include <memory>
+
+// Custom logging -- test code, this should be tweaked
+
+namespace logging {
+
+using namespace spdlog;
+using namespace spdlog::details;
+using namespace std::literals;
+
+class full_formatter final : public spdlog::formatter
+{
+public:
+ full_formatter(std::string_view LogId, std::chrono::time_point<std::chrono::system_clock> Epoch) : m_Epoch(Epoch), m_LogId(LogId) {}
+
+ virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<full_formatter>(m_LogId, m_Epoch); }
+
+ static constexpr bool UseDate = false;
+
+ virtual void format(const details::log_msg& msg, memory_buf_t& dest) override
+ {
+ using std::chrono::duration_cast;
+ using std::chrono::milliseconds;
+ using std::chrono::seconds;
+
+ if constexpr (UseDate)
+ {
+ auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch());
+ if (secs != m_LastLogSecs)
+ {
+ m_CachedTm = os::localtime(log_clock::to_time_t(msg.time));
+ m_LastLogSecs = secs;
+ }
+ }
+
+ const auto& tm_time = m_CachedTm;
+
+ // cache the date/time part for the next second.
+ auto duration = msg.time - m_Epoch;
+ auto secs = duration_cast<seconds>(duration);
+
+ if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0)
+ {
+ m_CachedDatetime.clear();
+ m_CachedDatetime.push_back('[');
+
+ if constexpr (UseDate)
+ {
+ fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime);
+ m_CachedDatetime.push_back(' ');
+
+ fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime);
+ }
+ else
+ {
+ int Count = int(secs.count());
+
+ const int LogSecs = Count % 60;
+ Count /= 60;
+
+ const int LogMins = Count % 60;
+ Count /= 60;
+
+ const int LogHours = Count;
+
+ fmt_helper::pad2(LogHours, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+ fmt_helper::pad2(LogMins, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+ fmt_helper::pad2(LogSecs, m_CachedDatetime);
+ }
+
+ m_CachedDatetime.push_back('.');
+
+ m_CacheTimestamp = secs;
+ }
+
+ dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
+
+ auto millis = fmt_helper::time_fraction<milliseconds>(msg.time);
+ fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+
+ if (!m_LogId.empty())
+ {
+ dest.push_back('[');
+ fmt_helper::append_string_view(m_LogId, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ // append logger name if exists
+ if (msg.logger_name.size() > 0)
+ {
+ dest.push_back('[');
+ fmt_helper::append_string_view(msg.logger_name, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ dest.push_back('[');
+ // wrap the level name with color
+ msg.color_range_start = dest.size();
+ fmt_helper::append_string_view(level::to_string_view(msg.level), dest);
+ msg.color_range_end = dest.size();
+ dest.push_back(']');
+ dest.push_back(' ');
+
+ // add source location if present
+ if (!msg.source.empty())
+ {
+ dest.push_back('[');
+ const char* filename = details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename);
+ fmt_helper::append_string_view(filename, dest);
+ dest.push_back(':');
+ fmt_helper::append_int(msg.source.line, dest);
+ dest.push_back(']');
+ dest.push_back(' ');
+ }
+
+ fmt_helper::append_string_view(msg.payload, dest);
+ fmt_helper::append_string_view("\n"sv, dest);
+ }
+
+private:
+ std::chrono::time_point<std::chrono::system_clock> m_Epoch;
+ std::tm m_CachedTm;
+ std::chrono::seconds m_LastLogSecs;
+ std::chrono::seconds m_CacheTimestamp{0};
+ memory_buf_t m_CachedDatetime;
+ std::string m_LogId;
+};
+
+class json_formatter final : public spdlog::formatter
+{
+public:
+ json_formatter(std::string_view LogId) : m_LogId(LogId) {}
+
+ virtual std::unique_ptr<formatter> clone() const override { return std::make_unique<json_formatter>(m_LogId); }
+
+ virtual void format(const details::log_msg& msg, memory_buf_t& dest) override
+ {
+ using std::chrono::duration_cast;
+ using std::chrono::milliseconds;
+ using std::chrono::seconds;
+
+ auto secs = std::chrono::duration_cast<seconds>(msg.time.time_since_epoch());
+ if (secs != m_LastLogSecs)
+ {
+ m_CachedTm = os::localtime(log_clock::to_time_t(msg.time));
+ m_LastLogSecs = secs;
+ }
+
+ const auto& tm_time = m_CachedTm;
+
+ // cache the date/time part for the next second.
+
+ if (m_CacheTimestamp != secs || m_CachedDatetime.size() == 0)
+ {
+ m_CachedDatetime.clear();
+
+ fmt_helper::append_int(tm_time.tm_year + 1900, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mon + 1, m_CachedDatetime);
+ m_CachedDatetime.push_back('-');
+
+ fmt_helper::pad2(tm_time.tm_mday, m_CachedDatetime);
+ m_CachedDatetime.push_back(' ');
+
+ fmt_helper::pad2(tm_time.tm_hour, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_min, m_CachedDatetime);
+ m_CachedDatetime.push_back(':');
+
+ fmt_helper::pad2(tm_time.tm_sec, m_CachedDatetime);
+
+ m_CachedDatetime.push_back('.');
+
+ m_CacheTimestamp = secs;
+ }
+ dest.append("{"sv);
+ dest.append("\"time\": \""sv);
+ dest.append(m_CachedDatetime.begin(), m_CachedDatetime.end());
+ auto millis = fmt_helper::time_fraction<milliseconds>(msg.time);
+ fmt_helper::pad3(static_cast<uint32_t>(millis.count()), dest);
+ dest.append("\", "sv);
+
+ dest.append("\"status\": \""sv);
+ dest.append(level::to_string_view(msg.level));
+ dest.append("\", "sv);
+
+ dest.append("\"source\": \""sv);
+ dest.append("zenserver"sv);
+ dest.append("\", "sv);
+
+ dest.append("\"service\": \""sv);
+ dest.append("zencache"sv);
+ dest.append("\", "sv);
+
+ if (!m_LogId.empty())
+ {
+ dest.append("\"id\": \""sv);
+ dest.append(m_LogId);
+ dest.append("\", "sv);
+ }
+
+ if (msg.logger_name.size() > 0)
+ {
+ dest.append("\"logger.name\": \""sv);
+ dest.append(msg.logger_name);
+ dest.append("\", "sv);
+ }
+
+ if (msg.thread_id != 0)
+ {
+ dest.append("\"logger.thread_name\": \""sv);
+ fmt_helper::pad_uint(msg.thread_id, 0, dest);
+ dest.append("\", "sv);
+ }
+
+ if (!msg.source.empty())
+ {
+ dest.append("\"file\": \""sv);
+ WriteEscapedString(dest, details::short_filename_formatter<details::null_scoped_padder>::basename(msg.source.filename));
+ dest.append("\","sv);
+
+ dest.append("\"line\": \""sv);
+ dest.append(fmt::format("{}", msg.source.line));
+ dest.append("\","sv);
+
+ dest.append("\"logger.method_name\": \""sv);
+ WriteEscapedString(dest, msg.source.funcname);
+ dest.append("\", "sv);
+ }
+
+ dest.append("\"message\": \""sv);
+ WriteEscapedString(dest, msg.payload);
+ dest.append("\""sv);
+
+ dest.append("}\n"sv);
+ }
+
+private:
+ static inline const std::unordered_map<char, std::string_view> SpecialCharacterMap{{'\b', "\\b"sv},
+ {'\f', "\\f"sv},
+ {'\n', "\\n"sv},
+ {'\r', "\\r"sv},
+ {'\t', "\\t"sv},
+ {'"', "\\\""sv},
+ {'\\', "\\\\"sv}};
+
+ static void WriteEscapedString(memory_buf_t& dest, const spdlog::string_view_t& payload)
+ {
+ const char* RangeStart = payload.begin();
+ for (const char* It = RangeStart; It != payload.end(); ++It)
+ {
+ if (auto SpecialIt = SpecialCharacterMap.find(*It); SpecialIt != SpecialCharacterMap.end())
+ {
+ if (RangeStart != It)
+ {
+ dest.append(RangeStart, It);
+ }
+ dest.append(SpecialIt->second);
+ RangeStart = It + 1;
+ }
+ }
+ if (RangeStart != payload.end())
+ {
+ dest.append(RangeStart, payload.end());
+ }
+ };
+
+ std::tm m_CachedTm{0, 0, 0, 0, 0, 0, 0, 0, 0};
+ std::chrono::seconds m_LastLogSecs{0};
+ std::chrono::seconds m_CacheTimestamp{0};
+ memory_buf_t m_CachedDatetime;
+ std::string m_LogId;
+};
+
+bool
+EnableVTMode()
+{
+#if ZEN_PLATFORM_WINDOWS
+ // Set output mode to handle virtual terminal sequences
+ HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE);
+ if (hOut == INVALID_HANDLE_VALUE)
+ {
+ return false;
+ }
+
+ DWORD dwMode = 0;
+ if (!GetConsoleMode(hOut, &dwMode))
+ {
+ return false;
+ }
+
+ dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING;
+ if (!SetConsoleMode(hOut, dwMode))
+ {
+ return false;
+ }
+#endif
+
+ return true;
+}
+
+} // namespace logging
+
+void
+InitializeLogging(const ZenServerOptions& GlobalOptions)
+{
+ zen::logging::InitializeLogging();
+ logging::EnableVTMode();
+
+ bool IsAsync = true;
+ spdlog::level::level_enum LogLevel = spdlog::level::info;
+
+ if (GlobalOptions.IsDebug)
+ {
+ LogLevel = spdlog::level::debug;
+ IsAsync = false;
+ }
+
+ if (GlobalOptions.IsTest)
+ {
+ LogLevel = spdlog::level::trace;
+ IsAsync = false;
+ }
+
+ if (IsAsync)
+ {
+ const int QueueSize = 8192;
+ const int ThreadCount = 1;
+ spdlog::init_thread_pool(QueueSize, ThreadCount);
+
+ auto AsyncLogger = spdlog::create_async<spdlog::sinks::ansicolor_stdout_sink_mt>("main");
+ zen::logging::SetDefault(AsyncLogger);
+ }
+
+ // Sinks
+
+ auto ConsoleSink = std::make_shared<spdlog::sinks::ansicolor_stdout_sink_mt>();
+
+ // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance
+ zen::CreateDirectories(GlobalOptions.AbsLogFile.parent_path());
+
+#if 0
+ auto FileSink = std::make_shared<spdlog::sinks::daily_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile),
+ 0,
+ 0,
+ /* truncate */ false,
+ uint16_t(/* max files */ 14));
+#else
+ auto FileSink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(GlobalOptions.AbsLogFile),
+ /* max size */ 128 * 1024 * 1024,
+ /* max files */ 16,
+ /* rotate on open */ true);
+#endif
+
+ std::set_terminate([]() { ZEN_CRITICAL("Program exited abnormally via std::terminate()"); });
+
+ // Default
+
+ auto& DefaultLogger = zen::logging::Default();
+ auto& Sinks = DefaultLogger.sinks();
+
+ Sinks.clear();
+ Sinks.push_back(ConsoleSink);
+ Sinks.push_back(FileSink);
+
+#if ZEN_PLATFORM_WINDOWS
+ if (zen::IsDebuggerPresent() && GlobalOptions.IsDebug)
+ {
+ auto DebugSink = std::make_shared<spdlog::sinks::msvc_sink_mt>();
+ DebugSink->set_level(spdlog::level::debug);
+ Sinks.push_back(DebugSink);
+ }
+#endif
+
+ // HTTP server request logging
+
+ std::filesystem::path HttpLogPath = GlobalOptions.DataDir / "logs" / "http.log";
+
+ // spdlog can't create directories that starts with `\\?\` so we make sure the folder exists before creating the logger instance
+ zen::CreateDirectories(HttpLogPath.parent_path());
+
+ auto HttpSink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(zen::PathToUtf8(HttpLogPath),
+ /* max size */ 128 * 1024 * 1024,
+ /* max files */ 16,
+ /* rotate on open */ true);
+
+ auto HttpLogger = std::make_shared<spdlog::logger>("http_requests", HttpSink);
+ spdlog::register_logger(HttpLogger);
+
+ // Jupiter - only log upstream HTTP traffic to file
+
+ auto JupiterLogger = std::make_shared<spdlog::logger>("jupiter", FileSink);
+ spdlog::register_logger(JupiterLogger);
+
+ // Zen - only log upstream HTTP traffic to file
+
+ auto ZenClientLogger = std::make_shared<spdlog::logger>("zenclient", FileSink);
+ spdlog::register_logger(ZenClientLogger);
+
+ // Configure all registered loggers according to settings
+
+ spdlog::set_level(LogLevel);
+ spdlog::flush_on(spdlog::level::err);
+ spdlog::flush_every(std::chrono::seconds{2});
+ spdlog::set_formatter(std::make_unique<logging::full_formatter>(GlobalOptions.LogId, std::chrono::system_clock::now()));
+
+ if (GlobalOptions.AbsLogFile.extension() == ".json")
+ {
+ FileSink->set_formatter(std::make_unique<logging::json_formatter>(GlobalOptions.LogId));
+ }
+ else
+ {
+ FileSink->set_pattern("[%C-%m-%d.%e %T] [%n] [%l] %v");
+ }
+ DefaultLogger.info("log starting at {}", zen::DateTime::Now().ToIso8601());
+}
+
+void
+ShutdownLogging()
+{
+ auto& DefaultLogger = zen::logging::Default();
+ DefaultLogger.info("log ending at {}", zen::DateTime::Now().ToIso8601());
+ zen::logging::ShutdownLogging();
+}
diff --git a/src/zenserver/diag/logging.h b/src/zenserver/diag/logging.h
new file mode 100644
index 000000000..8df49f842
--- /dev/null
+++ b/src/zenserver/diag/logging.h
@@ -0,0 +1,10 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+struct ZenServerOptions;
+
+void InitializeLogging(const ZenServerOptions& GlobalOptions);
+
+void ShutdownLogging();
diff --git a/src/zenserver/frontend/frontend.cpp b/src/zenserver/frontend/frontend.cpp
new file mode 100644
index 000000000..149d97924
--- /dev/null
+++ b/src/zenserver/frontend/frontend.cpp
@@ -0,0 +1,128 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "frontend.h"
+
+#include <zencore/endian.h>
+#include <zencore/filesystem.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#if ZEN_PLATFORM_WINDOWS
+# include <Windows.h>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+////////////////////////////////////////////////////////////////////////////////
+HttpFrontendService::HttpFrontendService(std::filesystem::path Directory) : m_Directory(Directory)
+{
+ std::filesystem::path SelfPath = GetRunningExecutablePath();
+
+ // Locate a .zip file appended onto the end of this binary
+ IoBuffer SelfBuffer = IoBufferBuilder::MakeFromFile(SelfPath);
+ m_ZipFs = ZipFs(std::move(SelfBuffer));
+
+#if ZEN_BUILD_DEBUG
+ if (!Directory.empty())
+ {
+ return;
+ }
+
+ std::error_code ErrorCode;
+ auto Path = SelfPath;
+ while (Path.has_parent_path())
+ {
+ auto ParentPath = Path.parent_path();
+ if (ParentPath == Path)
+ {
+ break;
+ }
+ if (std::filesystem::is_regular_file(ParentPath / "xmake.lua", ErrorCode))
+ {
+ if (ErrorCode)
+ {
+ break;
+ }
+
+ auto HtmlDir = ParentPath / "zenserver" / "frontend" / "html";
+ if (std::filesystem::is_directory(HtmlDir, ErrorCode))
+ {
+ m_Directory = HtmlDir;
+ }
+ break;
+ }
+ Path = ParentPath;
+ };
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////
+HttpFrontendService::~HttpFrontendService()
+{
+}
+
+////////////////////////////////////////////////////////////////////////////////
+const char*
+HttpFrontendService::BaseUri() const
+{
+ return "/dashboard"; // in order to use the root path we need to remove HttpAddUrlToUrlGroup in HttpSys.cpp
+}
+
+////////////////////////////////////////////////////////////////////////////////
+void
+HttpFrontendService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ using namespace std::literals;
+
+ std::string_view Uri = Request.RelativeUriWithExtension();
+ for (; Uri[0] == '/'; Uri = Uri.substr(1))
+ ;
+ if (Uri.empty())
+ {
+ Uri = "index.html"sv;
+ }
+
+ // Dismiss if the URI contains .. anywhere to prevent arbitrary file reads
+ if (Uri.find("..") != Uri.npos)
+ {
+ return Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ // Map the file extension to a MIME type. To keep things constrained, only a
+ // small subset of file extensions is allowed
+
+ HttpContentType ContentType = HttpContentType::kUnknownContentType;
+
+ if (const size_t DotIndex = Uri.rfind("."); DotIndex != Uri.npos)
+ {
+ const std::string_view DotExt = Uri.substr(DotIndex + 1);
+
+ ContentType = ParseContentType(DotExt);
+ }
+
+ if (ContentType == HttpContentType::kUnknownContentType)
+ {
+ return Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ // The given content directory overrides any zip-fs discovered in the binary
+ if (!m_Directory.empty())
+ {
+ FileContents File = ReadFile(m_Directory / Uri);
+
+ if (!File.ErrorCode)
+ {
+ return Request.WriteResponse(HttpResponseCode::OK, ContentType, File.Data[0]);
+ }
+ }
+
+ if (IoBuffer FileBuffer = m_ZipFs.GetFile(Uri))
+ {
+ return Request.WriteResponse(HttpResponseCode::OK, ContentType, FileBuffer);
+ }
+
+ Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv);
+}
+
+} // namespace zen
diff --git a/src/zenserver/frontend/frontend.h b/src/zenserver/frontend/frontend.h
new file mode 100644
index 000000000..6eac20620
--- /dev/null
+++ b/src/zenserver/frontend/frontend.h
@@ -0,0 +1,25 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include "zipfs.h"
+
+#include <filesystem>
+
+namespace zen {
+
+class HttpFrontendService final : public zen::HttpService
+{
+public:
+ HttpFrontendService(std::filesystem::path Directory);
+ virtual ~HttpFrontendService();
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ ZipFs m_ZipFs;
+ std::filesystem::path m_Directory;
+};
+
+} // namespace zen
diff --git a/src/zenserver/frontend/html/index.html b/src/zenserver/frontend/html/index.html
new file mode 100644
index 000000000..252ee621e
--- /dev/null
+++ b/src/zenserver/frontend/html/index.html
@@ -0,0 +1,59 @@
+<!DOCTYPE html>
+<html>
+<head>
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-F3w7mX95PdgyTmZZMECAngseQB83DfGTowi0iMjiWaeVhAn4FJkqJByhZMI3AhiU" crossorigin="anonymous">
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.min.js" integrity="sha384-skAcpIdS7UcVUC05LJ9Dxay8AXcDYfBJqt1CJ85S/CFujBsIzCIv+l9liuYLaMQ/" crossorigin="anonymous"></script>
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/font/bootstrap-icons.css">
+ <style type="text/css">
+ body {
+ background-color: #fafafa;
+ }
+ </style>
+ <script type="text/javascript">
+ const getCacheStats = () => {
+ const opts = { headers: { "Accept": "application/json" } };
+ fetch("/stats/z$", opts)
+ .then(response => {
+ if (!response.ok) {
+ throw Error(response.statusText);
+ }
+ return response.json();
+ })
+ .then(json => {
+ document.getElementById("status").innerHTML = "connected"
+ document.getElementById("stats").innerHTML = JSON.stringify(json, null, 4);
+ })
+ .catch(error => {
+ document.getElementById("status").innerHTML = "disconnected"
+ document.getElementById("stats").innerHTML = ""
+ console.log(error);
+ })
+ .finally(() => {
+ window.setTimeout(getCacheStats, 1000);
+ });
+ };
+ getCacheStats();
+ </script>
+</head>
+<body>
+ <div class="container">
+ <div class="row">
+ <div class="text-center mt-5">
+ <pre>
+__________ _________ __
+\____ / ____ ____ / _____/_/ |_ ____ _______ ____
+ / / _/ __ \ / \ \_____ \ \ __\ / _ \ \_ __ \_/ __ \
+ / /_ \ ___/ | | \ / \ | | ( <_> ) | | \/\ ___/
+/_______ \ \___ >|___| //_______ / |__| \____/ |__| \___ >
+ \/ \/ \/ \/ \/
+ </pre>
+ <pre id="status"/>
+ </div>
+ </div>
+ <div class="row">
+ <pre class="mb-0">Z$:</pre>
+ <pre id="stats"></pre>
+ <div>
+ </div>
+</body>
+</html>
diff --git a/src/zenserver/frontend/zipfs.cpp b/src/zenserver/frontend/zipfs.cpp
new file mode 100644
index 000000000..f9c2bc8ff
--- /dev/null
+++ b/src/zenserver/frontend/zipfs.cpp
@@ -0,0 +1,169 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zipfs.h"
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+namespace {
+
+#if ZEN_COMPILER_MSC
+# pragma warning(push)
+# pragma warning(disable : 4200)
+#endif
+
+ using ZipInt16 = uint16_t;
+
+ struct ZipInt32
+ {
+ operator uint32_t() const { return *(uint32_t*)Parts; }
+ uint16_t Parts[2];
+ };
+
+ struct EocdRecord
+ {
+ enum : uint32_t
+ {
+ Magic = 0x0605'4b50,
+ };
+ ZipInt32 Signature;
+ ZipInt16 ThisDiskIndex;
+ ZipInt16 CdStartDiskIndex;
+ ZipInt16 CdRecordThisDiskCount;
+ ZipInt16 CdRecordCount;
+ ZipInt32 CdSize;
+ ZipInt32 CdOffset;
+ ZipInt16 CommentSize;
+ char Comment[];
+ };
+
+ struct CentralDirectoryRecord
+ {
+ enum : uint32_t
+ {
+ Magic = 0x0201'4b50,
+ };
+
+ ZipInt32 Signature;
+ ZipInt16 VersionMadeBy;
+ ZipInt16 VersionRequired;
+ ZipInt16 Flags;
+ ZipInt16 CompressionMethod;
+ ZipInt16 LastModTime;
+ ZipInt16 LastModDate;
+ ZipInt32 Crc32;
+ ZipInt32 CompressedSize;
+ ZipInt32 OriginalSize;
+ ZipInt16 FileNameLength;
+ ZipInt16 ExtraFieldLength;
+ ZipInt16 CommentLength;
+ ZipInt16 DiskIndex;
+ ZipInt16 InternalFileAttr;
+ ZipInt32 ExternalFileAttr;
+ ZipInt32 Offset;
+ char FileName[];
+ };
+
+ struct LocalFileHeader
+ {
+ enum : uint32_t
+ {
+ Magic = 0x0403'4b50,
+ };
+
+ ZipInt32 Signature;
+ ZipInt16 VersionRequired;
+ ZipInt16 Flags;
+ ZipInt16 CompressionMethod;
+ ZipInt16 LastModTime;
+ ZipInt16 LastModDate;
+ ZipInt32 Crc32;
+ ZipInt32 CompressedSize;
+ ZipInt32 OriginalSize;
+ ZipInt16 FileNameLength;
+ ZipInt16 ExtraFieldLength;
+ char FileName[];
+ };
+
+#if ZEN_COMPILER_MSC
+# pragma warning(pop)
+#endif
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+ZipFs::ZipFs(IoBuffer&& Buffer)
+{
+ MemoryView View = Buffer.GetView();
+
+ uint8_t* Cursor = (uint8_t*)(View.GetData()) + View.GetSize();
+ if (View.GetSize() < sizeof(EocdRecord))
+ {
+ return;
+ }
+
+ const auto* EocdCursor = (EocdRecord*)(Cursor - sizeof(EocdRecord));
+
+ // It is more correct to search backwards for EocdRecord::Magic as the
+ // comment can be of a variable length. But here we're not going to support
+ // zip files with comments.
+ if (EocdCursor->Signature != EocdRecord::Magic)
+ {
+ return;
+ }
+
+ // Zip64 isn't supported either
+ if (EocdCursor->ThisDiskIndex == 0xffff)
+ {
+ return;
+ }
+
+ Cursor = (uint8_t*)EocdCursor - uint32_t(EocdCursor->CdOffset) - uint32_t(EocdCursor->CdSize);
+
+ const auto* CdCursor = (CentralDirectoryRecord*)(Cursor + EocdCursor->CdOffset);
+ for (int i = 0, n = EocdCursor->CdRecordCount; i < n; ++i)
+ {
+ const CentralDirectoryRecord& Cd = *CdCursor;
+
+ bool Acceptable = true;
+ Acceptable &= (Cd.OriginalSize > 0); // has some content
+ Acceptable &= (Cd.CompressionMethod == 0); // is stored uncomrpessed
+ if (Acceptable)
+ {
+ const uint8_t* Lfh = Cursor + Cd.Offset;
+ if (uintptr_t(Lfh - Cursor) < View.GetSize())
+ {
+ std::string_view FileName(Cd.FileName, Cd.FileNameLength);
+ m_Files.insert(std::make_pair(FileName, FileItem{Lfh, size_t(0)}));
+ }
+ }
+
+ uint32_t ExtraBytes = Cd.FileNameLength + Cd.ExtraFieldLength + Cd.CommentLength;
+ CdCursor = (CentralDirectoryRecord*)(Cd.FileName + ExtraBytes);
+ }
+
+ m_Buffer = std::move(Buffer);
+}
+
+//////////////////////////////////////////////////////////////////////////
+IoBuffer
+ZipFs::GetFile(const std::string_view& FileName) const
+{
+ FileMap::iterator Iter = m_Files.find(FileName);
+ if (Iter == m_Files.end())
+ {
+ return {};
+ }
+
+ FileItem& Item = Iter->second;
+ if (Item.GetSize() > 0)
+ {
+ return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize());
+ }
+
+ const auto* Lfh = (LocalFileHeader*)(Item.GetData());
+ Item = MemoryView(Lfh->FileName + Lfh->FileNameLength + Lfh->ExtraFieldLength, Lfh->OriginalSize);
+ return IoBuffer(IoBuffer::Wrap, Item.GetData(), Item.GetSize());
+}
+
+} // namespace zen
diff --git a/src/zenserver/frontend/zipfs.h b/src/zenserver/frontend/zipfs.h
new file mode 100644
index 000000000..e1fa4457c
--- /dev/null
+++ b/src/zenserver/frontend/zipfs.h
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+
+#include <unordered_map>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+class ZipFs
+{
+public:
+ ZipFs() = default;
+ ZipFs(IoBuffer&& Buffer);
+ IoBuffer GetFile(const std::string_view& FileName) const;
+
+private:
+ using FileItem = MemoryView;
+ using FileMap = std::unordered_map<std::string_view, FileItem>;
+ FileMap mutable m_Files;
+ IoBuffer m_Buffer;
+};
+
+} // namespace zen
diff --git a/src/zenserver/monitoring/httpstats.cpp b/src/zenserver/monitoring/httpstats.cpp
new file mode 100644
index 000000000..4d985f8c2
--- /dev/null
+++ b/src/zenserver/monitoring/httpstats.cpp
@@ -0,0 +1,62 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpstats.h"
+
+namespace zen {
+
+HttpStatsService::HttpStatsService() : m_Log(logging::Get("stats"))
+{
+}
+
+HttpStatsService::~HttpStatsService()
+{
+}
+
+const char*
+HttpStatsService::BaseUri() const
+{
+ return "/stats/";
+}
+
+void
+HttpStatsService::RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_Providers.insert_or_assign(std::string(Id), &Provider);
+}
+
+void
+HttpStatsService::UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider)
+{
+ ZEN_UNUSED(Provider);
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_Providers.erase(std::string(Id));
+}
+
+void
+HttpStatsService::HandleRequest(HttpServerRequest& Request)
+{
+ using namespace std::literals;
+
+ std::string_view Key = Request.RelativeUri();
+
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers))
+ {
+ return It->second->HandleStatsRequest(Request);
+ }
+ }
+
+ [[fallthrough]];
+ default:
+ return;
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/monitoring/httpstats.h b/src/zenserver/monitoring/httpstats.h
new file mode 100644
index 000000000..732815a9a
--- /dev/null
+++ b/src/zenserver/monitoring/httpstats.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+#include <map>
+
+namespace zen {
+
+struct IHttpStatsProvider
+{
+ virtual void HandleStatsRequest(HttpServerRequest& Request) = 0;
+};
+
+class HttpStatsService : public HttpService
+{
+public:
+ HttpStatsService();
+ ~HttpStatsService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ void RegisterHandler(std::string_view Id, IHttpStatsProvider& Provider);
+ void UnregisterHandler(std::string_view Id, IHttpStatsProvider& Provider);
+
+private:
+ spdlog::logger& m_Log;
+ HttpRequestRouter m_Router;
+
+ inline spdlog::logger& Log() { return m_Log; }
+
+ RwLock m_Lock;
+ std::map<std::string, IHttpStatsProvider*> m_Providers;
+};
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zenserver/monitoring/httpstatus.cpp b/src/zenserver/monitoring/httpstatus.cpp
new file mode 100644
index 000000000..8b10601dd
--- /dev/null
+++ b/src/zenserver/monitoring/httpstatus.cpp
@@ -0,0 +1,62 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httpstatus.h"
+
+namespace zen {
+
+HttpStatusService::HttpStatusService() : m_Log(logging::Get("status"))
+{
+}
+
+HttpStatusService::~HttpStatusService()
+{
+}
+
+const char*
+HttpStatusService::BaseUri() const
+{
+ return "/status/";
+}
+
+void
+HttpStatusService::RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_Providers.insert_or_assign(std::string(Id), &Provider);
+}
+
+void
+HttpStatusService::UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider)
+{
+ ZEN_UNUSED(Provider);
+
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_Providers.erase(std::string(Id));
+}
+
+void
+HttpStatusService::HandleRequest(HttpServerRequest& Request)
+{
+ using namespace std::literals;
+
+ std::string_view Key = Request.RelativeUri();
+
+ switch (Request.RequestVerb())
+ {
+ case HttpVerb::kHead:
+ case HttpVerb::kGet:
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ if (auto It = m_Providers.find(std::string{Key}); It != end(m_Providers))
+ {
+ return It->second->HandleStatusRequest(Request);
+ }
+ }
+
+ [[fallthrough]];
+ default:
+ return;
+ }
+}
+
+} // namespace zen
diff --git a/src/zenserver/monitoring/httpstatus.h b/src/zenserver/monitoring/httpstatus.h
new file mode 100644
index 000000000..b04e45324
--- /dev/null
+++ b/src/zenserver/monitoring/httpstatus.h
@@ -0,0 +1,38 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+#include <zenhttp/httpserver.h>
+
+#include <map>
+
+namespace zen {
+
+struct IHttpStatusProvider
+{
+ virtual void HandleStatusRequest(HttpServerRequest& Request) = 0;
+};
+
+class HttpStatusService : public HttpService
+{
+public:
+ HttpStatusService();
+ ~HttpStatusService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ void RegisterHandler(std::string_view Id, IHttpStatusProvider& Provider);
+ void UnregisterHandler(std::string_view Id, IHttpStatusProvider& Provider);
+
+private:
+ spdlog::logger& m_Log;
+ HttpRequestRouter m_Router;
+
+ RwLock m_Lock;
+ std::map<std::string, IHttpStatusProvider*> m_Providers;
+
+ inline spdlog::logger& Log() { return m_Log; }
+};
+
+} // namespace zen \ No newline at end of file
diff --git a/src/zenserver/objectstore/objectstore.cpp b/src/zenserver/objectstore/objectstore.cpp
new file mode 100644
index 000000000..e5739418e
--- /dev/null
+++ b/src/zenserver/objectstore/objectstore.cpp
@@ -0,0 +1,232 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <objectstore/objectstore.h>
+
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include "zencore/compactbinarybuilder.h"
+#include "zenhttp/httpcommon.h"
+#include "zenhttp/httpserver.h"
+
+#include <thread>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <json11.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+ZEN_DEFINE_LOG_CATEGORY_STATIC(LogObj, "obj"sv);
+
+HttpObjectStoreService::HttpObjectStoreService(ObjectStoreConfig Cfg) : m_Cfg(std::move(Cfg))
+{
+ Inititalize();
+}
+
+HttpObjectStoreService::~HttpObjectStoreService()
+{
+}
+
+const char*
+HttpObjectStoreService::BaseUri() const
+{
+ return "/obj/";
+}
+
+void
+HttpObjectStoreService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ if (m_Router.HandleRequest(Request) == false)
+ {
+ ZEN_LOG_WARN(LogObj, "No route found for {0}", Request.RelativeUri());
+ return Request.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Not found"sv);
+ }
+}
+
+void
+HttpObjectStoreService::Inititalize()
+{
+ ZEN_LOG_INFO(LogObj, "Initialzing Object Store in '{}'", m_Cfg.RootDirectory);
+ for (const auto& Bucket : m_Cfg.Buckets)
+ {
+ ZEN_LOG_INFO(LogObj, " - bucket '{}' -> '{}'", Bucket.Name, Bucket.Directory);
+ }
+
+ m_Router.RegisterRoute(
+ "distributionpoints/{bucket}",
+ [this](zen::HttpRouterRequest& Request) {
+ const std::string BucketName = Request.GetCapture(1);
+
+ StringBuilder<1024> Json;
+ {
+ CbObjectWriter Writer;
+ Writer.BeginArray("distributions");
+ Writer << fmt::format("http://localhost:{}/obj/{}", m_Cfg.ServerPort, BucketName);
+ Writer.EndArray();
+ Writer.Save().ToJson(Json);
+ }
+
+ Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kJSON, Json.ToString());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{bucket}/{path}",
+ [this](zen::HttpRouterRequest& Request) { GetBlob(Request); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{bucket}/{path}",
+ [this](zen::HttpRouterRequest& Request) { PutBlob(Request); },
+ HttpVerb::kPost | HttpVerb::kPut);
+}
+
+std::filesystem::path
+HttpObjectStoreService::GetBucketDirectory(std::string_view BucketName)
+{
+ std::lock_guard _(BucketsMutex);
+
+ if (const auto It = std::find_if(std::begin(m_Cfg.Buckets),
+ std::end(m_Cfg.Buckets),
+ [&BucketName](const auto& Bucket) -> bool { return Bucket.Name == BucketName; });
+ It != std::end(m_Cfg.Buckets))
+ {
+ return It->Directory;
+ }
+
+ return std::filesystem::path();
+}
+
+void
+HttpObjectStoreService::GetBlob(zen::HttpRouterRequest& Request)
+{
+ namespace fs = std::filesystem;
+
+ const std::string& BucketName = Request.GetCapture(1);
+ const fs::path BucketDir = GetBucketDirectory(BucketName);
+
+ if (BucketDir.empty())
+ {
+ ZEN_LOG_DEBUG(LogObj, "GET - [FAILED], unknown bucket '{}'", BucketName);
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ const fs::path RelativeBucketPath = Request.GetCapture(2);
+
+ if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with(".."))
+ {
+ ZEN_LOG_DEBUG(LogObj, "GET - from bucket '{}' [FAILED], invalid file path", BucketName);
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ fs::path FilePath = BucketDir / RelativeBucketPath;
+ if (fs::exists(FilePath) == false)
+ {
+ ZEN_LOG_DEBUG(LogObj, "GET - '{}/{}' [FAILED], doesn't exist", BucketName, FilePath);
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ zen::HttpRanges Ranges;
+ if (Request.ServerRequest().TryGetRanges(Ranges); Ranges.size() > 1)
+ {
+ // Only a single range is supported
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ FileContents File = ReadFile(FilePath);
+ if (File.ErrorCode)
+ {
+ ZEN_LOG_WARN(LogObj,
+ "GET - '{}/{}' [FAILED] ('{}': {})",
+ BucketName,
+ FilePath,
+ File.ErrorCode.category().name(),
+ File.ErrorCode.value());
+
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ const IoBuffer& FileBuf = File.Data[0];
+
+ if (Ranges.empty())
+ {
+ const uint64_t TotalServed = TotalBytesServed.fetch_add(FileBuf.Size()) + FileBuf.Size();
+
+ ZEN_LOG_DEBUG(LogObj,
+ "GET - '{}/{}' ({}) [OK] (Served: {})",
+ BucketName,
+ RelativeBucketPath,
+ NiceBytes(FileBuf.Size()),
+ NiceBytes(TotalServed));
+
+ Request.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, FileBuf);
+ }
+ else
+ {
+ const auto Range = Ranges[0];
+ const uint64_t RangeSize = Range.End - Range.Start;
+ const uint64_t TotalServed = TotalBytesServed.fetch_add(RangeSize) + RangeSize;
+
+ ZEN_LOG_DEBUG(LogObj,
+ "GET - '{}/{}' (Range: {}-{}) ({}/{}) [OK] (Served: {})",
+ BucketName,
+ RelativeBucketPath,
+ Range.Start,
+ Range.End,
+ NiceBytes(RangeSize),
+ NiceBytes(FileBuf.Size()),
+ NiceBytes(TotalServed));
+
+ MemoryView RangeView = FileBuf.GetView().Mid(Range.Start, RangeSize);
+ if (RangeView.GetSize() != RangeSize)
+ {
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ IoBuffer RangeBuf = IoBuffer(IoBuffer::Wrap, RangeView.GetData(), RangeView.GetSize());
+ Request.ServerRequest().WriteResponse(HttpResponseCode::PartialContent, HttpContentType::kBinary, RangeBuf);
+ }
+}
+
+void
+HttpObjectStoreService::PutBlob(zen::HttpRouterRequest& Request)
+{
+ namespace fs = std::filesystem;
+
+ const std::string& BucketName = Request.GetCapture(1);
+ const fs::path BucketDir = GetBucketDirectory(BucketName);
+
+ if (BucketDir.empty())
+ {
+ ZEN_LOG_DEBUG(LogObj, "PUT - [FAILED], unknown bucket '{}'", BucketName);
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ const fs::path RelativeBucketPath = Request.GetCapture(2);
+
+ if (RelativeBucketPath.is_absolute() || RelativeBucketPath.string().starts_with(".."))
+ {
+ ZEN_LOG_DEBUG(LogObj, "PUT - bucket '{}' [FAILED], invalid file path", BucketName);
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::Forbidden);
+ }
+
+ fs::path FilePath = BucketDir / RelativeBucketPath;
+ const IoBuffer FileBuf = Request.ServerRequest().ReadPayload();
+
+ if (FileBuf.Size() == 0)
+ {
+ ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [FAILED], empty file", BucketName, FilePath);
+ return Request.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ WriteFile(FilePath, FileBuf);
+ ZEN_LOG_DEBUG(LogObj, "PUT - '{}/{}' [OK] ({})", BucketName, RelativeBucketPath, NiceBytes(FileBuf.Size()));
+ Request.ServerRequest().WriteResponse(HttpResponseCode::OK);
+}
+
+} // namespace zen
diff --git a/src/zenserver/objectstore/objectstore.h b/src/zenserver/objectstore/objectstore.h
new file mode 100644
index 000000000..eaab57794
--- /dev/null
+++ b/src/zenserver/objectstore/objectstore.h
@@ -0,0 +1,48 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+#include <atomic>
+#include <filesystem>
+#include <mutex>
+
+namespace zen {
+
+class HttpRouterRequest;
+
+struct ObjectStoreConfig
+{
+ struct BucketConfig
+ {
+ std::string Name;
+ std::filesystem::path Directory;
+ };
+
+ std::filesystem::path RootDirectory;
+ std::vector<BucketConfig> Buckets;
+ uint16_t ServerPort{1337};
+};
+
+class HttpObjectStoreService final : public zen::HttpService
+{
+public:
+ HttpObjectStoreService(ObjectStoreConfig Cfg);
+ virtual ~HttpObjectStoreService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ void Inititalize();
+ std::filesystem::path GetBucketDirectory(std::string_view BucketName);
+ void GetBlob(zen::HttpRouterRequest& Request);
+ void PutBlob(zen::HttpRouterRequest& Request);
+
+ ObjectStoreConfig m_Cfg;
+ std::mutex BucketsMutex;
+ HttpRequestRouter m_Router;
+ std::atomic_uint64_t TotalBytesServed{0};
+};
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/fileremoteprojectstore.cpp b/src/zenserver/projectstore/fileremoteprojectstore.cpp
new file mode 100644
index 000000000..d7a34a6c2
--- /dev/null
+++ b/src/zenserver/projectstore/fileremoteprojectstore.cpp
@@ -0,0 +1,235 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "fileremoteprojectstore.h"
+
+#include <zencore/compress.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/timer.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+class LocalExportProjectStore : public RemoteProjectStore
+{
+public:
+ LocalExportProjectStore(std::string_view Name,
+ const std::filesystem::path& FolderPath,
+ bool ForceDisableBlocks,
+ bool ForceEnableTempBlocks)
+ : m_Name(Name)
+ , m_OutputPath(FolderPath)
+ {
+ if (ForceDisableBlocks)
+ {
+ m_EnableBlocks = false;
+ }
+ if (ForceEnableTempBlocks)
+ {
+ m_UseTempBlocks = true;
+ }
+ }
+
+ virtual RemoteStoreInfo GetInfo() const override
+ {
+ return {.CreateBlocks = m_EnableBlocks,
+ .UseTempBlockFiles = m_UseTempBlocks,
+ .Description = fmt::format("[file] {}"sv, m_OutputPath)};
+ }
+
+ virtual SaveResult SaveContainer(const IoBuffer& Payload) override
+ {
+ Stopwatch Timer;
+ SaveResult Result;
+
+ {
+ CbObject ContainerObject = LoadCompactBinaryObject(Payload);
+
+ ContainerObject.IterateAttachments([&](CbFieldView FieldView) {
+ IoHash AttachmentHash = FieldView.AsBinaryAttachment();
+ std::filesystem::path AttachmentPath = GetAttachmentPath(AttachmentHash);
+ if (!std::filesystem::exists(AttachmentPath))
+ {
+ Result.Needs.insert(AttachmentHash);
+ }
+ });
+ }
+
+ std::filesystem::path ContainerPath = m_OutputPath;
+ ContainerPath.append(m_Name);
+
+ CreateDirectories(m_OutputPath);
+ BasicFile ContainerFile;
+ ContainerFile.Open(ContainerPath, BasicFile::Mode::kTruncate);
+ std::error_code Ec;
+ ContainerFile.WriteAll(Payload, Ec);
+ if (Ec)
+ {
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.Reason = Ec.message();
+ }
+ Result.RawHash = IoHash::HashBuffer(Payload);
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override
+ {
+ Stopwatch Timer;
+ SaveAttachmentResult Result;
+ std::filesystem::path ChunkPath = GetAttachmentPath(RawHash);
+ if (!std::filesystem::exists(ChunkPath))
+ {
+ try
+ {
+ CreateDirectories(ChunkPath.parent_path());
+
+ BasicFile ChunkFile;
+ ChunkFile.Open(ChunkPath, BasicFile::Mode::kTruncate);
+ size_t Offset = 0;
+ for (const SharedBuffer& Segment : Payload.GetSegments())
+ {
+ ChunkFile.Write(Segment.GetView(), Offset);
+ Offset += Segment.GetSize();
+ }
+ }
+ catch (std::exception& Ex)
+ {
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.Reason = Ex.what();
+ }
+ }
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override
+ {
+ Stopwatch Timer;
+
+ for (const SharedBuffer& Chunk : Chunks)
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer());
+ SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash());
+ if (ChunkResult.ErrorCode)
+ {
+ ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return SaveAttachmentsResult{ChunkResult};
+ }
+ }
+ SaveAttachmentsResult Result;
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual Result FinalizeContainer(const IoHash&) override { return {}; }
+
+ virtual LoadContainerResult LoadContainer() override
+ {
+ Stopwatch Timer;
+ LoadContainerResult Result;
+ std::filesystem::path ContainerPath = m_OutputPath;
+ ContainerPath.append(m_Name);
+ if (!std::filesystem::is_regular_file(ContainerPath))
+ {
+ Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound);
+ Result.Reason = fmt::format("The file {} does not exist"sv, ContainerPath.string());
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+ IoBuffer ContainerPayload;
+ {
+ BasicFile ContainerFile;
+ ContainerFile.Open(ContainerPath, BasicFile::Mode::kRead);
+ ContainerPayload = ContainerFile.ReadAll();
+ }
+ Result.ContainerObject = LoadCompactBinaryObject(ContainerPayload);
+ if (!Result.ContainerObject)
+ {
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.Reason = fmt::format("The file {} is not formatted as a compact binary object"sv, ContainerPath.string());
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+ virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
+ {
+ Stopwatch Timer;
+ LoadAttachmentResult Result;
+ std::filesystem::path ChunkPath = GetAttachmentPath(RawHash);
+ if (!std::filesystem::is_regular_file(ChunkPath))
+ {
+ Result.ErrorCode = gsl::narrow<int>(HttpResponseCode::NotFound);
+ Result.Reason = fmt::format("The file {} does not exist"sv, ChunkPath.string());
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+ {
+ BasicFile ChunkFile;
+ ChunkFile.Open(ChunkPath, BasicFile::Mode::kRead);
+ Result.Bytes = ChunkFile.ReadAll();
+ }
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override
+ {
+ Stopwatch Timer;
+ LoadAttachmentsResult Result;
+ for (const IoHash& Hash : RawHashes)
+ {
+ LoadAttachmentResult ChunkResult = LoadAttachment(Hash);
+ if (ChunkResult.ErrorCode)
+ {
+ ChunkResult.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return LoadAttachmentsResult{ChunkResult};
+ }
+ ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000)));
+ Result.Chunks.emplace_back(
+ std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))});
+ }
+ return Result;
+ }
+
+private:
+ std::filesystem::path GetAttachmentPath(const IoHash& RawHash) const
+ {
+ ExtendablePathBuilder<128> ShardedPath;
+ ShardedPath.Append(m_OutputPath.c_str());
+ ExtendableStringBuilder<64> HashString;
+ RawHash.ToHexString(HashString);
+ const char* str = HashString.c_str();
+ ShardedPath.AppendSeparator();
+ ShardedPath.AppendAsciiRange(str, str + 3);
+
+ ShardedPath.AppendSeparator();
+ ShardedPath.AppendAsciiRange(str + 3, str + 5);
+
+ ShardedPath.AppendSeparator();
+ ShardedPath.AppendAsciiRange(str + 5, str + 40);
+
+ return ShardedPath.ToPath();
+ }
+
+ const std::string m_Name;
+ const std::filesystem::path m_OutputPath;
+ bool m_EnableBlocks = true;
+ bool m_UseTempBlocks = false;
+};
+
+std::unique_ptr<RemoteProjectStore>
+CreateFileRemoteStore(const FileRemoteStoreOptions& Options)
+{
+ std::unique_ptr<RemoteProjectStore> RemoteStore = std::make_unique<LocalExportProjectStore>(Options.Name,
+ std::filesystem::path(Options.FolderPath),
+ Options.ForceDisableBlocks,
+ Options.ForceEnableTempBlocks);
+ return RemoteStore;
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/fileremoteprojectstore.h b/src/zenserver/projectstore/fileremoteprojectstore.h
new file mode 100644
index 000000000..68d1eb71e
--- /dev/null
+++ b/src/zenserver/projectstore/fileremoteprojectstore.h
@@ -0,0 +1,19 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "remoteprojectstore.h"
+
+namespace zen {
+
+struct FileRemoteStoreOptions : RemoteStoreOptions
+{
+ std::filesystem::path FolderPath;
+ std::string Name;
+ bool ForceDisableBlocks;
+ bool ForceEnableTempBlocks;
+};
+
+std::unique_ptr<RemoteProjectStore> CreateFileRemoteStore(const FileRemoteStoreOptions& Options);
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.cpp b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp
new file mode 100644
index 000000000..66cf3c4f8
--- /dev/null
+++ b/src/zenserver/projectstore/jupiterremoteprojectstore.cpp
@@ -0,0 +1,244 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "jupiterremoteprojectstore.h"
+
+#include <zencore/compress.h>
+#include <zencore/fmtutils.h>
+
+#include <auth/authmgr.h>
+#include <upstream/jupiter.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+class JupiterRemoteStore : public RemoteProjectStore
+{
+public:
+ JupiterRemoteStore(Ref<CloudCacheClient>&& CloudClient,
+ std::string_view Namespace,
+ std::string_view Bucket,
+ const IoHash& Key,
+ bool ForceDisableBlocks,
+ bool ForceDisableTempBlocks)
+ : m_CloudClient(CloudClient)
+ , m_Namespace(Namespace)
+ , m_Bucket(Bucket)
+ , m_Key(Key)
+ {
+ if (ForceDisableBlocks)
+ {
+ m_EnableBlocks = false;
+ }
+ if (ForceDisableTempBlocks)
+ {
+ m_UseTempBlocks = false;
+ }
+ }
+
+ virtual RemoteStoreInfo GetInfo() const override
+ {
+ return {.CreateBlocks = m_EnableBlocks,
+ .UseTempBlockFiles = m_UseTempBlocks,
+ .Description = fmt::format("[cloud] {} as {}/{}/{}"sv, m_CloudClient->ServiceUrl(), m_Namespace, m_Bucket, m_Key)};
+ }
+
+ virtual SaveResult SaveContainer(const IoBuffer& Payload) override
+ {
+ const int32_t MaxAttempts = 3;
+ PutRefResult Result;
+ {
+ CloudCacheSession Session(m_CloudClient.Get());
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutRef(m_Namespace, m_Bucket, m_Key, Payload, ZenContentType::kCbObject);
+ }
+ }
+
+ return SaveResult{ConvertResult(Result), {Result.Needs.begin(), Result.Needs.end()} /*, {}*/, IoHash::HashBuffer(Payload)};
+ }
+
+ virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override
+ {
+ const int32_t MaxAttempts = 3;
+ CloudCacheResult Result;
+ {
+ CloudCacheSession Session(m_CloudClient.Get());
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCompressedBlob(m_Namespace, RawHash, Payload);
+ }
+ }
+
+ return SaveAttachmentResult{ConvertResult(Result)};
+ }
+
+ virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override
+ {
+ SaveAttachmentsResult Result;
+ for (const SharedBuffer& Chunk : Chunks)
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(Chunk.AsIoBuffer());
+ SaveAttachmentResult ChunkResult = SaveAttachment(Compressed.GetCompressed(), Compressed.DecodeRawHash());
+ if (ChunkResult.ErrorCode)
+ {
+ return SaveAttachmentsResult{ChunkResult};
+ }
+ }
+ return Result;
+ }
+
+ virtual Result FinalizeContainer(const IoHash& RawHash) override
+ {
+ const int32_t MaxAttempts = 3;
+ CloudCacheResult Result;
+ {
+ CloudCacheSession Session(m_CloudClient.Get());
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.FinalizeRef(m_Namespace, m_Bucket, m_Key, RawHash);
+ }
+ }
+ return ConvertResult(Result);
+ }
+
+ virtual LoadContainerResult LoadContainer() override
+ {
+ const int32_t MaxAttempts = 3;
+ CloudCacheResult Result;
+ {
+ CloudCacheSession Session(m_CloudClient.Get());
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.GetRef(m_Namespace, m_Bucket, m_Key, ZenContentType::kCbObject);
+ }
+ }
+
+ if (Result.ErrorCode || !Result.Success)
+ {
+ return LoadContainerResult{ConvertResult(Result)};
+ }
+
+ CbObject ContainerObject = LoadCompactBinaryObject(Result.Response);
+ if (!ContainerObject)
+ {
+ return LoadContainerResult{
+ RemoteProjectStore::Result{
+ .ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Reason = fmt::format("The ref {}/{}/{} is not formatted as a compact binary object"sv, m_Namespace, m_Bucket, m_Key)},
+ std::move(ContainerObject)};
+ }
+
+ return LoadContainerResult{ConvertResult(Result), std::move(ContainerObject)};
+ }
+
+ virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
+ {
+ const int32_t MaxAttempts = 3;
+ CloudCacheResult Result;
+ {
+ CloudCacheSession Session(m_CloudClient.Get());
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.GetCompressedBlob(m_Namespace, RawHash);
+ }
+ }
+ return LoadAttachmentResult{ConvertResult(Result), std::move(Result.Response)};
+ }
+
+ virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override
+ {
+ LoadAttachmentsResult Result;
+ for (const IoHash& Hash : RawHashes)
+ {
+ LoadAttachmentResult ChunkResult = LoadAttachment(Hash);
+ if (ChunkResult.ErrorCode)
+ {
+ return LoadAttachmentsResult{ChunkResult};
+ }
+ ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(ChunkResult.ElapsedSeconds * 1000)));
+ Result.Chunks.emplace_back(
+ std::pair<IoHash, CompressedBuffer>{Hash, CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult.Bytes))});
+ }
+ return Result;
+ }
+
+private:
+ static Result ConvertResult(const CloudCacheResult& Response)
+ {
+ std::string Text;
+ int32_t ErrorCode = 0;
+ if (Response.ErrorCode != 0)
+ {
+ ErrorCode = Response.ErrorCode;
+ }
+ else if (!Response.Success)
+ {
+ ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ if (Response.Response.GetContentType() == ZenContentType::kText)
+ {
+ Text =
+ std::string(reinterpret_cast<const std::string::value_type*>(Response.Response.GetData()), Response.Response.GetSize());
+ }
+ }
+ return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.ElapsedSeconds, .Reason = Response.Reason, .Text = Text};
+ }
+
+ Ref<CloudCacheClient> m_CloudClient;
+ const std::string m_Namespace;
+ const std::string m_Bucket;
+ const IoHash m_Key;
+ bool m_EnableBlocks = true;
+ bool m_UseTempBlocks = true;
+};
+
+std::unique_ptr<RemoteProjectStore>
+CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options)
+{
+ std::string Url = Options.Url;
+ if (Url.find("://"sv) == std::string::npos)
+ {
+ // Assume https URL
+ Url = fmt::format("https://{}"sv, Url);
+ }
+ CloudCacheClientOptions ClientOptions{.Name = "Remote store"sv,
+ .ServiceUrl = Url,
+ .ConnectTimeout = std::chrono::milliseconds(2000),
+ .Timeout = std::chrono::milliseconds(60000)};
+ // 1) Access token as parameter in request
+ // 2) Environment variable (different win vs linux/mac)
+ // 3) openid-provider (assumes oidctoken.exe -Zen true has been run with matching Options.OpenIdProvider
+
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+ if (!Options.AccessToken.empty())
+ {
+ TokenProvider = CloudCacheTokenProvider::CreateFromCallback([AccessToken = Options.AccessToken]() {
+ return CloudCacheAccessToken{.Value = AccessToken, .ExpireTime = GcClock::TimePoint::max()};
+ });
+ }
+ else
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([&AuthManager = Options.AuthManager, OpenIdProvider = Options.OpenIdProvider]() {
+ AuthMgr::OpenIdAccessToken Token = AuthManager.GetOpenIdAccessToken(OpenIdProvider.empty() ? "Default" : OpenIdProvider);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+
+ Ref<CloudCacheClient> CloudClient(new CloudCacheClient(ClientOptions, std::move(TokenProvider)));
+
+ std::unique_ptr<RemoteProjectStore> RemoteStore = std::make_unique<JupiterRemoteStore>(std::move(CloudClient),
+ Options.Namespace,
+ Options.Bucket,
+ Options.Key,
+ Options.ForceDisableBlocks,
+ Options.ForceDisableTempBlocks);
+ return RemoteStore;
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/jupiterremoteprojectstore.h b/src/zenserver/projectstore/jupiterremoteprojectstore.h
new file mode 100644
index 000000000..31548af22
--- /dev/null
+++ b/src/zenserver/projectstore/jupiterremoteprojectstore.h
@@ -0,0 +1,26 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "remoteprojectstore.h"
+
+namespace zen {
+
+class AuthMgr;
+
+struct JupiterRemoteStoreOptions : RemoteStoreOptions
+{
+ std::string Url;
+ std::string Namespace;
+ std::string Bucket;
+ IoHash Key;
+ std::string OpenIdProvider;
+ std::string AccessToken;
+ AuthMgr& AuthManager;
+ bool ForceDisableBlocks;
+ bool ForceDisableTempBlocks;
+};
+
+std::unique_ptr<RemoteProjectStore> CreateJupiterRemoteStore(const JupiterRemoteStoreOptions& Options);
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/projectstore.cpp b/src/zenserver/projectstore/projectstore.cpp
new file mode 100644
index 000000000..847a79a1d
--- /dev/null
+++ b/src/zenserver/projectstore/projectstore.cpp
@@ -0,0 +1,4082 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "projectstore.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zenhttp/httpshared.h>
+#include <zenstore/caslog.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/scrubcontext.h>
+#include <zenutil/cache/rpcrecording.h>
+
+#include "fileremoteprojectstore.h"
+#include "jupiterremoteprojectstore.h"
+#include "remoteprojectstore.h"
+#include "zenremoteprojectstore.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <xxh3.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+#endif // ZEN_WITH_TESTS
+
+namespace zen {
+
+namespace {
+ bool PrepareDirectoryDelete(const std::filesystem::path& Dir, std::filesystem::path& OutDeleteDir)
+ {
+ int DropIndex = 0;
+ do
+ {
+ if (!std::filesystem::exists(Dir))
+ {
+ return true;
+ }
+
+ std::string DroppedName = fmt::format("[dropped]{}({})", Dir.filename().string(), DropIndex);
+ std::filesystem::path DroppedBucketPath = Dir.parent_path() / DroppedName;
+ if (std::filesystem::exists(DroppedBucketPath))
+ {
+ DropIndex++;
+ continue;
+ }
+
+ std::error_code Ec;
+ std::filesystem::rename(Dir, DroppedBucketPath, Ec);
+ if (!Ec)
+ {
+ OutDeleteDir = DroppedBucketPath;
+ return true;
+ }
+ if (Ec && !std::filesystem::exists(DroppedBucketPath))
+ {
+ // We can't move our folder, probably because it is busy, bail..
+ return false;
+ }
+ Sleep(100);
+ } while (true);
+ }
+
+ std::pair<std::unique_ptr<RemoteProjectStore>, std::string> CreateRemoteStore(CbObjectView Params,
+ AuthMgr& AuthManager,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize)
+ {
+ using namespace std::literals;
+
+ std::unique_ptr<RemoteProjectStore> RemoteStore;
+
+ if (CbObjectView File = Params["file"sv].AsObjectView(); File)
+ {
+ std::filesystem::path FolderPath(File["path"sv].AsString());
+ if (FolderPath.empty())
+ {
+ return {nullptr, "Missing file path"};
+ }
+ std::string_view Name(File["name"sv].AsString());
+ if (Name.empty())
+ {
+ return {nullptr, "Missing file name"};
+ }
+ bool ForceDisableBlocks = File["disableblocks"sv].AsBool(false);
+ bool ForceEnableTempBlocks = File["enabletempblocks"sv].AsBool(false);
+
+ FileRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize},
+ FolderPath,
+ std::string(Name),
+ ForceDisableBlocks,
+ ForceEnableTempBlocks};
+ RemoteStore = CreateFileRemoteStore(Options);
+ }
+
+ if (CbObjectView Cloud = Params["cloud"sv].AsObjectView(); Cloud)
+ {
+ std::string_view CloudServiceUrl = Cloud["url"sv].AsString();
+ if (CloudServiceUrl.empty())
+ {
+ return {nullptr, "Missing service url"};
+ }
+
+ std::string Url = cpr::util::urlDecode(std::string(CloudServiceUrl));
+ std::string_view Namespace = Cloud["namespace"sv].AsString();
+ if (Namespace.empty())
+ {
+ return {nullptr, "Missing namespace"};
+ }
+ std::string_view Bucket = Cloud["bucket"sv].AsString();
+ if (Bucket.empty())
+ {
+ return {nullptr, "Missing bucket"};
+ }
+ std::string_view OpenIdProvider = Cloud["openid-provider"sv].AsString();
+ std::string AccessToken = std::string(Cloud["access-token"sv].AsString());
+ if (AccessToken.empty())
+ {
+ std::string_view AccessTokenEnvVariable = Cloud["access-token-env"].AsString();
+ if (!AccessTokenEnvVariable.empty())
+ {
+ AccessToken = GetEnvVariable(AccessTokenEnvVariable);
+ }
+ }
+ std::string_view KeyParam = Cloud["key"sv].AsString();
+ if (KeyParam.empty())
+ {
+ return {nullptr, "Missing key"};
+ }
+ if (KeyParam.length() != IoHash::StringLength)
+ {
+ return {nullptr, "Invalid key"};
+ }
+ IoHash Key = IoHash::FromHexString(KeyParam);
+ if (Key == IoHash::Zero)
+ {
+ return {nullptr, "Invalid key string"};
+ }
+ bool ForceDisableBlocks = Cloud["disableblocks"sv].AsBool(false);
+ bool ForceDisableTempBlocks = Cloud["disabletempblocks"sv].AsBool(false);
+
+ JupiterRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize},
+ Url,
+ std::string(Namespace),
+ std::string(Bucket),
+ Key,
+ std::string(OpenIdProvider),
+ AccessToken,
+ AuthManager,
+ ForceDisableBlocks,
+ ForceDisableTempBlocks};
+ RemoteStore = CreateJupiterRemoteStore(Options);
+ }
+
+ if (CbObjectView Zen = Params["zen"sv].AsObjectView(); Zen)
+ {
+ std::string_view Url = Zen["url"sv].AsString();
+ std::string_view Project = Zen["project"sv].AsString();
+ if (Project.empty())
+ {
+ return {nullptr, "Missing project"};
+ }
+ std::string_view Oplog = Zen["oplog"sv].AsString();
+ if (Oplog.empty())
+ {
+ return {nullptr, "Missing oplog"};
+ }
+ ZenRemoteStoreOptions Options = {RemoteStoreOptions{.MaxBlockSize = MaxBlockSize, .MaxChunkEmbedSize = MaxChunkEmbedSize},
+ std::string(Url),
+ std::string(Project),
+ std::string(Oplog)};
+ RemoteStore = CreateZenRemoteStore(Options);
+ }
+
+ if (!RemoteStore)
+ {
+ return {nullptr, "Unknown remote store type"};
+ }
+
+ return {std::move(RemoteStore), ""};
+ }
+
+ std::pair<HttpResponseCode, std::string> ConvertResult(const RemoteProjectStore::Result& Result)
+ {
+ if (Result.ErrorCode == 0)
+ {
+ return {HttpResponseCode::OK, Result.Text};
+ }
+ return {static_cast<HttpResponseCode>(Result.ErrorCode),
+ Result.Reason.empty() ? Result.Text
+ : Result.Text.empty() ? Result.Reason
+ : fmt::format("{}. Reason: '{}'", Result.Text, Result.Reason)};
+ }
+
+ void CSVHeader(bool Details, bool AttachmentDetails, StringBuilderBase& CSVWriter)
+ {
+ if (AttachmentDetails)
+ {
+ CSVWriter << "Project, Oplog, LSN, Key, Cid, Size";
+ }
+ else if (Details)
+ {
+ CSVWriter << "Project, Oplog, LSN, Key, Size, AttachmentCount, AttachmentsSize";
+ }
+ else
+ {
+ CSVWriter << "Project, Oplog, Key";
+ }
+ }
+
+ void CSVWriteOp(CidStore& CidStore,
+ std::string_view ProjectId,
+ std::string_view OplogId,
+ bool Details,
+ bool AttachmentDetails,
+ int LSN,
+ const Oid& Key,
+ CbObject Op,
+ StringBuilderBase& CSVWriter)
+ {
+ StringBuilder<32> KeyStringBuilder;
+ Key.ToString(KeyStringBuilder);
+ const std::string_view KeyString = KeyStringBuilder.ToView();
+
+ SharedBuffer Buffer = Op.GetBuffer();
+ if (AttachmentDetails)
+ {
+ Op.IterateAttachments([&CidStore, &CSVWriter, &ProjectId, &OplogId, LSN, &KeyString](CbFieldView FieldView) {
+ const IoHash AttachmentHash = FieldView.AsAttachment();
+ IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+ CSVWriter << "\r\n"
+ << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << AttachmentHash.ToHexString()
+ << ", " << gsl::narrow<uint64_t>(Attachment.GetSize());
+ });
+ }
+ else if (Details)
+ {
+ uint64_t AttachmentCount = 0;
+ size_t AttachmentsSize = 0;
+ Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) {
+ const IoHash AttachmentHash = FieldView.AsAttachment();
+ AttachmentCount++;
+ IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+ AttachmentsSize += Attachment.GetSize();
+ });
+ CSVWriter << "\r\n"
+ << ProjectId << ", " << OplogId << ", " << LSN << ", " << KeyString << ", " << gsl::narrow<uint64_t>(Buffer.GetSize())
+ << ", " << AttachmentCount << ", " << gsl::narrow<uint64_t>(AttachmentsSize);
+ }
+ else
+ {
+ CSVWriter << "\r\n" << ProjectId << ", " << OplogId << ", " << KeyString;
+ }
+ };
+
+ void CbWriteOp(CidStore& CidStore,
+ bool Details,
+ bool OpDetails,
+ bool AttachmentDetails,
+ int LSN,
+ const Oid& Key,
+ CbObject Op,
+ CbObjectWriter& CbWriter)
+ {
+ CbWriter.BeginObject();
+ {
+ SharedBuffer Buffer = Op.GetBuffer();
+ CbWriter.AddObjectId("key", Key);
+ if (Details)
+ {
+ CbWriter.AddInteger("lsn", LSN);
+ CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Buffer.GetSize()));
+ }
+ if (AttachmentDetails)
+ {
+ CbWriter.BeginArray("attachments");
+ Op.IterateAttachments([&CidStore, &CbWriter](CbFieldView FieldView) {
+ const IoHash AttachmentHash = FieldView.AsAttachment();
+ CbWriter.BeginObject();
+ {
+ IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+ CbWriter.AddString("cid", AttachmentHash.ToHexString());
+ CbWriter.AddInteger("size", gsl::narrow<uint64_t>(Attachment.GetSize()));
+ }
+ CbWriter.EndObject();
+ });
+ CbWriter.EndArray();
+ }
+ else if (Details)
+ {
+ uint64_t AttachmentCount = 0;
+ size_t AttachmentsSize = 0;
+ Op.IterateAttachments([&CidStore, &AttachmentCount, &AttachmentsSize](CbFieldView FieldView) {
+ const IoHash AttachmentHash = FieldView.AsAttachment();
+ AttachmentCount++;
+ IoBuffer Attachment = CidStore.FindChunkByCid(AttachmentHash);
+ AttachmentsSize += Attachment.GetSize();
+ });
+ if (AttachmentCount > 0)
+ {
+ CbWriter.AddInteger("attachments", AttachmentCount);
+ CbWriter.AddInteger("attachmentssize", gsl::narrow<uint64_t>(AttachmentsSize));
+ }
+ }
+ if (OpDetails)
+ {
+ CbWriter.BeginObject("op");
+ for (const CbFieldView& Field : Op)
+ {
+ if (!Field.HasName())
+ {
+ CbWriter.AddField(Field);
+ continue;
+ }
+ std::string_view FieldName = Field.GetName();
+ CbWriter.AddField(FieldName, Field);
+ }
+ CbWriter.EndObject();
+ }
+ }
+ CbWriter.EndObject();
+ };
+
+ void CbWriteOplogOps(CidStore& CidStore,
+ ProjectStore::Oplog& Oplog,
+ bool Details,
+ bool OpDetails,
+ bool AttachmentDetails,
+ CbObjectWriter& Cbo)
+ {
+ Cbo.BeginArray("ops");
+ {
+ Oplog.IterateOplogWithKey([&Cbo, &CidStore, Details, OpDetails, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) {
+ CbWriteOp(CidStore, Details, OpDetails, AttachmentDetails, LSN, Key, Op, Cbo);
+ });
+ }
+ Cbo.EndArray();
+ }
+
+ void CbWriteOplog(CidStore& CidStore,
+ ProjectStore::Oplog& Oplog,
+ bool Details,
+ bool OpDetails,
+ bool AttachmentDetails,
+ CbObjectWriter& Cbo)
+ {
+ Cbo.BeginObject();
+ {
+ Cbo.AddString("name", Oplog.OplogId());
+ CbWriteOplogOps(CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo);
+ }
+ Cbo.EndObject();
+ }
+
+ void CbWriteOplogs(CidStore& CidStore,
+ ProjectStore::Project& Project,
+ std::vector<std::string> OpLogs,
+ bool Details,
+ bool OpDetails,
+ bool AttachmentDetails,
+ CbObjectWriter& Cbo)
+ {
+ Cbo.BeginArray("oplogs");
+ {
+ for (const std::string& OpLogId : OpLogs)
+ {
+ ProjectStore::Oplog* Oplog = Project.OpenOplog(OpLogId);
+ if (Oplog != nullptr)
+ {
+ CbWriteOplog(CidStore, *Oplog, Details, OpDetails, AttachmentDetails, Cbo);
+ }
+ }
+ }
+ Cbo.EndArray();
+ }
+
+ void CbWriteProject(CidStore& CidStore,
+ ProjectStore::Project& Project,
+ std::vector<std::string> OpLogs,
+ bool Details,
+ bool OpDetails,
+ bool AttachmentDetails,
+ CbObjectWriter& Cbo)
+ {
+ Cbo.BeginObject();
+ {
+ Cbo.AddString("name", Project.Identifier);
+ CbWriteOplogs(CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo);
+ }
+ Cbo.EndObject();
+ }
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+
+Oid
+OpKeyStringAsOId(std::string_view OpKey)
+{
+ using namespace std::literals;
+
+ CbObjectWriter Writer;
+ Writer << "key"sv << OpKey;
+
+ XXH3_128Stream KeyHasher;
+ Writer.Save()["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); });
+ XXH3_128 KeyHash = KeyHasher.GetHash();
+
+ Oid OpId;
+ memcpy(OpId.OidBits, &KeyHash, sizeof(OpId.OidBits));
+
+ return OpId;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct ProjectStore::OplogStorage : public RefCounted
+{
+ OplogStorage(ProjectStore::Oplog* OwnerOplog, std::filesystem::path BasePath) : m_OwnerOplog(OwnerOplog), m_OplogStoragePath(BasePath)
+ {
+ }
+
+ ~OplogStorage()
+ {
+ ZEN_INFO("closing oplog storage at {}", m_OplogStoragePath);
+ Flush();
+ }
+
+ [[nodiscard]] bool Exists() { return Exists(m_OplogStoragePath); }
+ [[nodiscard]] static bool Exists(std::filesystem::path BasePath)
+ {
+ return std::filesystem::exists(BasePath / "ops.zlog") && std::filesystem::exists(BasePath / "ops.zops");
+ }
+
+ static bool Delete(std::filesystem::path BasePath) { return DeleteDirectories(BasePath); }
+
+ uint64_t OpBlobsSize() const
+ {
+ RwLock::SharedLockScope _(m_RwLock);
+ return m_NextOpsOffset;
+ }
+
+ void Open(bool IsCreate)
+ {
+ using namespace std::literals;
+
+ ZEN_INFO("initializing oplog storage at '{}'", m_OplogStoragePath);
+
+ if (IsCreate)
+ {
+ DeleteDirectories(m_OplogStoragePath);
+ CreateDirectories(m_OplogStoragePath);
+ }
+
+ m_Oplog.Open(m_OplogStoragePath / "ops.zlog"sv, IsCreate ? CasLogFile::Mode::kTruncate : CasLogFile::Mode::kWrite);
+ m_Oplog.Initialize();
+
+ m_OpBlobs.Open(m_OplogStoragePath / "ops.zops"sv, IsCreate ? BasicFile::Mode::kTruncate : BasicFile::Mode::kWrite);
+
+ ZEN_ASSERT(IsPow2(m_OpsAlign));
+ ZEN_ASSERT(!(m_NextOpsOffset & (m_OpsAlign - 1)));
+ }
+
+ void ReplayLog(std::function<void(CbObject, const OplogEntry&)>&& Handler)
+ {
+ ZEN_TRACE_CPU("ProjectStore::OplogStorage::ReplayLog");
+
+ // This could use memory mapping or do something clever but for now it just reads the file sequentially
+
+ ZEN_INFO("replaying log for '{}'", m_OplogStoragePath);
+
+ Stopwatch Timer;
+
+ uint64_t InvalidEntries = 0;
+
+ IoBuffer OpBuffer;
+ m_Oplog.Replay(
+ [&](const OplogEntry& LogEntry) {
+ if (LogEntry.OpCoreSize == 0)
+ {
+ ++InvalidEntries;
+
+ return;
+ }
+
+ if (OpBuffer.GetSize() < LogEntry.OpCoreSize)
+ {
+ OpBuffer = IoBuffer(LogEntry.OpCoreSize);
+ }
+
+ const uint64_t OpFileOffset = LogEntry.OpCoreOffset * m_OpsAlign;
+
+ m_OpBlobs.Read((void*)OpBuffer.Data(), LogEntry.OpCoreSize, OpFileOffset);
+
+ // Verify checksum, ignore op data if incorrect
+ const auto OpCoreHash = uint32_t(XXH3_64bits(OpBuffer.Data(), LogEntry.OpCoreSize) & 0xffffFFFF);
+
+ if (OpCoreHash != LogEntry.OpCoreHash)
+ {
+ ZEN_WARN("skipping oplog entry with bad checksum!");
+ return;
+ }
+
+ CbObject Op(SharedBuffer::MakeView(OpBuffer.Data(), LogEntry.OpCoreSize));
+
+ m_NextOpsOffset =
+ Max(m_NextOpsOffset.load(std::memory_order_relaxed), RoundUp(OpFileOffset + LogEntry.OpCoreSize, m_OpsAlign));
+ m_MaxLsn = Max(m_MaxLsn.load(std::memory_order_relaxed), LogEntry.OpLsn);
+
+ Handler(Op, LogEntry);
+ },
+ 0);
+
+ if (InvalidEntries)
+ {
+ ZEN_WARN("ignored {} zero-sized oplog entries", InvalidEntries);
+ }
+
+ ZEN_INFO("Oplog replay completed in {} - Max LSN# {}, Next offset: {}",
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
+ m_MaxLsn,
+ m_NextOpsOffset);
+ }
+
+ void ReplayLog(const std::vector<OplogEntryAddress>& Entries, std::function<void(CbObject)>&& Handler)
+ {
+ for (const OplogEntryAddress& Entry : Entries)
+ {
+ CbObject Op = GetOp(Entry);
+ Handler(Op);
+ }
+ }
+
+ CbObject GetOp(const OplogEntryAddress& Entry)
+ {
+ IoBuffer OpBuffer(Entry.Size);
+
+ const uint64_t OpFileOffset = Entry.Offset * m_OpsAlign;
+ m_OpBlobs.Read((void*)OpBuffer.Data(), Entry.Size, OpFileOffset);
+
+ return CbObject(SharedBuffer(std::move(OpBuffer)));
+ }
+
+ OplogEntry AppendOp(SharedBuffer Buffer, uint32_t OpCoreHash, XXH3_128 KeyHash)
+ {
+ ZEN_TRACE_CPU("ProjectStore::OplogStorage::AppendOp");
+
+ using namespace std::literals;
+
+ uint64_t WriteSize = Buffer.GetSize();
+
+ RwLock::ExclusiveLockScope Lock(m_RwLock);
+ const uint64_t WriteOffset = m_NextOpsOffset;
+ const uint32_t OpLsn = ++m_MaxLsn;
+ m_NextOpsOffset = RoundUp(WriteOffset + WriteSize, m_OpsAlign);
+ Lock.ReleaseNow();
+
+ ZEN_ASSERT(IsMultipleOf(WriteOffset, m_OpsAlign));
+
+ OplogEntry Entry = {.OpLsn = OpLsn,
+ .OpCoreOffset = gsl::narrow_cast<uint32_t>(WriteOffset / m_OpsAlign),
+ .OpCoreSize = uint32_t(Buffer.GetSize()),
+ .OpCoreHash = OpCoreHash,
+ .OpKeyHash = KeyHash};
+
+ m_Oplog.Append(Entry);
+ m_OpBlobs.Write(Buffer.GetData(), WriteSize, WriteOffset);
+
+ return Entry;
+ }
+
+ void Flush()
+ {
+ m_Oplog.Flush();
+ m_OpBlobs.Flush();
+ }
+
+ spdlog::logger& Log() { return m_OwnerOplog->Log(); }
+
+private:
+ ProjectStore::Oplog* m_OwnerOplog;
+ std::filesystem::path m_OplogStoragePath;
+ mutable RwLock m_RwLock;
+ TCasLogFile<OplogEntry> m_Oplog;
+ BasicFile m_OpBlobs;
+ std::atomic<uint64_t> m_NextOpsOffset{0};
+ uint64_t m_OpsAlign = 32;
+ std::atomic<uint32_t> m_MaxLsn{0};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+ProjectStore::Oplog::Oplog(std::string_view Id,
+ Project* Project,
+ CidStore& Store,
+ std::filesystem::path BasePath,
+ const std::filesystem::path& MarkerPath)
+: m_OuterProject(Project)
+, m_CidStore(Store)
+, m_BasePath(BasePath)
+, m_MarkerPath(MarkerPath)
+, m_OplogId(Id)
+{
+ using namespace std::literals;
+
+ m_Storage = new OplogStorage(this, m_BasePath);
+ const bool StoreExists = m_Storage->Exists();
+ m_Storage->Open(/* IsCreate */ !StoreExists);
+
+ m_TempPath = m_BasePath / "temp"sv;
+
+ CleanDirectory(m_TempPath);
+}
+
+ProjectStore::Oplog::~Oplog()
+{
+ if (m_Storage)
+ {
+ Flush();
+ }
+}
+
+void
+ProjectStore::Oplog::Flush()
+{
+ ZEN_ASSERT(m_Storage);
+ m_Storage->Flush();
+}
+
+void
+ProjectStore::Oplog::Scrub(ScrubContext& Ctx) const
+{
+ ZEN_UNUSED(Ctx);
+}
+
+void
+ProjectStore::Oplog::GatherReferences(GcContext& GcCtx)
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+
+ std::vector<IoHash> Hashes;
+ Hashes.reserve(Max(m_ChunkMap.size(), m_MetaMap.size()));
+
+ for (const auto& Kv : m_ChunkMap)
+ {
+ Hashes.push_back(Kv.second);
+ }
+
+ GcCtx.AddRetainedCids(Hashes);
+
+ Hashes.clear();
+
+ for (const auto& Kv : m_MetaMap)
+ {
+ Hashes.push_back(Kv.second);
+ }
+
+ GcCtx.AddRetainedCids(Hashes);
+}
+
+uint64_t
+ProjectStore::Oplog::TotalSize() const
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+ if (m_Storage)
+ {
+ return m_Storage->OpBlobsSize();
+ }
+ return 0;
+}
+
+bool
+ProjectStore::Oplog::IsExpired() const
+{
+ if (m_MarkerPath.empty())
+ {
+ return false;
+ }
+ return !std::filesystem::exists(m_MarkerPath);
+}
+
+std::filesystem::path
+ProjectStore::Oplog::PrepareForDelete(bool MoveFolder)
+{
+ RwLock::ExclusiveLockScope _(m_OplogLock);
+ m_ChunkMap.clear();
+ m_MetaMap.clear();
+ m_FileMap.clear();
+ m_OpAddressMap.clear();
+ m_LatestOpMap.clear();
+ m_Storage = {};
+ if (!MoveFolder)
+ {
+ return {};
+ }
+ std::filesystem::path MovedDir;
+ if (PrepareDirectoryDelete(m_BasePath, MovedDir))
+ {
+ return MovedDir;
+ }
+ return {};
+}
+
+bool
+ProjectStore::Oplog::ExistsAt(std::filesystem::path BasePath)
+{
+ using namespace std::literals;
+
+ std::filesystem::path StateFilePath = BasePath / "oplog.zcb"sv;
+ return std::filesystem::is_regular_file(StateFilePath);
+}
+
+void
+ProjectStore::Oplog::Read()
+{
+ using namespace std::literals;
+
+ std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv;
+ if (std::filesystem::is_regular_file(StateFilePath))
+ {
+ ZEN_INFO("reading config for oplog '{}' in project '{}' from {}", m_OplogId, m_OuterProject->Identifier, StateFilePath);
+
+ BasicFile Blob;
+ Blob.Open(StateFilePath, BasicFile::Mode::kRead);
+
+ IoBuffer Obj = Blob.ReadAll();
+ CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All);
+
+ if (ValidationError != CbValidateError::None)
+ {
+ ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), StateFilePath);
+ return;
+ }
+
+ CbObject Cfg = LoadCompactBinaryObject(Obj);
+
+ m_MarkerPath = Cfg["gcpath"sv].AsString();
+ }
+ else
+ {
+ ZEN_INFO("config for oplog '{}' in project '{}' not found at {}. Assuming legacy store",
+ m_OplogId,
+ m_OuterProject->Identifier,
+ StateFilePath);
+ }
+ ReplayLog();
+}
+
+void
+ProjectStore::Oplog::Write()
+{
+ using namespace std::literals;
+
+ BinaryWriter Mem;
+
+ CbObjectWriter Cfg;
+
+ Cfg << "gcpath"sv << PathToUtf8(m_MarkerPath);
+
+ Cfg.Save(Mem);
+
+ std::filesystem::path StateFilePath = m_BasePath / "oplog.zcb"sv;
+
+ ZEN_INFO("persisting config for oplog '{}' in project '{}' to {}", m_OplogId, m_OuterProject->Identifier, StateFilePath);
+
+ BasicFile Blob;
+ Blob.Open(StateFilePath, BasicFile::Mode::kTruncate);
+ Blob.Write(Mem.Data(), Mem.Size(), 0);
+ Blob.Flush();
+}
+
+void
+ProjectStore::Oplog::ReplayLog()
+{
+ RwLock::ExclusiveLockScope OplogLock(m_OplogLock);
+ if (!m_Storage)
+ {
+ return;
+ }
+ m_Storage->ReplayLog(
+ [&](CbObject Op, const OplogEntry& OpEntry) { RegisterOplogEntry(OplogLock, GetMapping(Op), OpEntry, kUpdateReplay); });
+}
+
+IoBuffer
+ProjectStore::Oplog::FindChunk(Oid ChunkId)
+{
+ RwLock::SharedLockScope OplogLock(m_OplogLock);
+ if (!m_Storage)
+ {
+ return IoBuffer{};
+ }
+
+ if (auto ChunkIt = m_ChunkMap.find(ChunkId); ChunkIt != m_ChunkMap.end())
+ {
+ IoHash ChunkHash = ChunkIt->second;
+ OplogLock.ReleaseNow();
+
+ IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkHash);
+ Chunk.SetContentType(ZenContentType::kCompressedBinary);
+
+ return Chunk;
+ }
+
+ if (auto FileIt = m_FileMap.find(ChunkId); FileIt != m_FileMap.end())
+ {
+ std::filesystem::path FilePath = m_OuterProject->RootDir / FileIt->second.ServerPath;
+
+ OplogLock.ReleaseNow();
+
+ IoBuffer FileChunk = IoBufferBuilder::MakeFromFile(FilePath);
+ FileChunk.SetContentType(ZenContentType::kBinary);
+
+ return FileChunk;
+ }
+
+ if (auto MetaIt = m_MetaMap.find(ChunkId); MetaIt != m_MetaMap.end())
+ {
+ IoHash ChunkHash = MetaIt->second;
+ OplogLock.ReleaseNow();
+
+ IoBuffer Chunk = m_CidStore.FindChunkByCid(ChunkHash);
+ Chunk.SetContentType(ZenContentType::kCompressedBinary);
+
+ return Chunk;
+ }
+
+ return {};
+}
+
+void
+ProjectStore::Oplog::IterateFileMap(
+ std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn)
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+ if (!m_Storage)
+ {
+ return;
+ }
+
+ for (const auto& Kv : m_FileMap)
+ {
+ Fn(Kv.first, Kv.second.ServerPath, Kv.second.ClientPath);
+ }
+}
+
+void
+ProjectStore::Oplog::IterateOplog(std::function<void(CbObject)>&& Handler)
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+ if (!m_Storage)
+ {
+ return;
+ }
+
+ std::vector<OplogEntryAddress> Entries;
+ Entries.reserve(m_LatestOpMap.size());
+
+ for (const auto& Kv : m_LatestOpMap)
+ {
+ const auto AddressEntry = m_OpAddressMap.find(Kv.second);
+ ZEN_ASSERT(AddressEntry != m_OpAddressMap.end());
+
+ Entries.push_back(AddressEntry->second);
+ }
+
+ std::sort(Entries.begin(), Entries.end(), [](const OplogEntryAddress& Lhs, const OplogEntryAddress& Rhs) {
+ return Lhs.Offset < Rhs.Offset;
+ });
+
+ m_Storage->ReplayLog(Entries, [&](CbObject Op) { Handler(Op); });
+}
+
+void
+ProjectStore::Oplog::IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Handler)
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+ if (!m_Storage)
+ {
+ return;
+ }
+
+ std::vector<size_t> EntryIndexes;
+ std::vector<OplogEntryAddress> Entries;
+ std::vector<Oid> Keys;
+ std::vector<int> LSNs;
+ Entries.reserve(m_LatestOpMap.size());
+ EntryIndexes.reserve(m_LatestOpMap.size());
+ Keys.reserve(m_LatestOpMap.size());
+ LSNs.reserve(m_LatestOpMap.size());
+
+ for (const auto& Kv : m_LatestOpMap)
+ {
+ const auto AddressEntry = m_OpAddressMap.find(Kv.second);
+ ZEN_ASSERT(AddressEntry != m_OpAddressMap.end());
+
+ Entries.push_back(AddressEntry->second);
+ Keys.push_back(Kv.first);
+ LSNs.push_back(Kv.second);
+ EntryIndexes.push_back(EntryIndexes.size());
+ }
+
+ std::sort(EntryIndexes.begin(), EntryIndexes.end(), [&Entries](const size_t& Lhs, const size_t& Rhs) {
+ const OplogEntryAddress& LhsEntry = Entries[Lhs];
+ const OplogEntryAddress& RhsEntry = Entries[Rhs];
+ return LhsEntry.Offset < RhsEntry.Offset;
+ });
+ std::vector<OplogEntryAddress> SortedEntries;
+ SortedEntries.reserve(EntryIndexes.size());
+ for (size_t Index : EntryIndexes)
+ {
+ SortedEntries.push_back(Entries[Index]);
+ }
+
+ size_t EntryIndex = 0;
+ m_Storage->ReplayLog(SortedEntries, [&](CbObject Op) {
+ Handler(LSNs[EntryIndex], Keys[EntryIndex], Op);
+ EntryIndex++;
+ });
+}
+
+int
+ProjectStore::Oplog::GetOpIndexByKey(const Oid& Key)
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+ if (!m_Storage)
+ {
+ return {};
+ }
+ if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end())
+ {
+ return LatestOp->second;
+ }
+ return -1;
+}
+
+std::optional<CbObject>
+ProjectStore::Oplog::GetOpByKey(const Oid& Key)
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+ if (!m_Storage)
+ {
+ return {};
+ }
+
+ if (const auto LatestOp = m_LatestOpMap.find(Key); LatestOp != m_LatestOpMap.end())
+ {
+ const auto AddressEntry = m_OpAddressMap.find(LatestOp->second);
+ ZEN_ASSERT(AddressEntry != m_OpAddressMap.end());
+
+ return m_Storage->GetOp(AddressEntry->second);
+ }
+
+ return {};
+}
+
+std::optional<CbObject>
+ProjectStore::Oplog::GetOpByIndex(int Index)
+{
+ RwLock::SharedLockScope _(m_OplogLock);
+ if (!m_Storage)
+ {
+ return {};
+ }
+
+ if (const auto AddressEntryIt = m_OpAddressMap.find(Index); AddressEntryIt != m_OpAddressMap.end())
+ {
+ return m_Storage->GetOp(AddressEntryIt->second);
+ }
+
+ return {};
+}
+
+void
+ProjectStore::Oplog::AddFileMapping(const RwLock::ExclusiveLockScope&,
+ Oid FileId,
+ IoHash Hash,
+ std::string_view ServerPath,
+ std::string_view ClientPath)
+{
+ if (Hash != IoHash::Zero)
+ {
+ m_ChunkMap.insert_or_assign(FileId, Hash);
+ }
+
+ FileMapEntry Entry;
+ Entry.ServerPath = ServerPath;
+ Entry.ClientPath = ClientPath;
+
+ m_FileMap[FileId] = std::move(Entry);
+
+ if (Hash != IoHash::Zero)
+ {
+ m_ChunkMap.insert_or_assign(FileId, Hash);
+ }
+}
+
+void
+ProjectStore::Oplog::AddChunkMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash)
+{
+ m_ChunkMap.insert_or_assign(ChunkId, Hash);
+}
+
+void
+ProjectStore::Oplog::AddMetaMapping(const RwLock::ExclusiveLockScope&, Oid ChunkId, IoHash Hash)
+{
+ m_MetaMap.insert_or_assign(ChunkId, Hash);
+}
+
+ProjectStore::Oplog::OplogEntryMapping
+ProjectStore::Oplog::GetMapping(CbObject Core)
+{
+ using namespace std::literals;
+
+ OplogEntryMapping Result;
+
+ // Update chunk id maps
+ CbObjectView PackageObj = Core["package"sv].AsObjectView();
+ CbArrayView BulkDataArray = Core["bulkdata"sv].AsArrayView();
+ CbArrayView PackageDataArray = Core["packagedata"sv].AsArrayView();
+ Result.Chunks.reserve(PackageObj ? 1 : 0 + BulkDataArray.Num() + PackageDataArray.Num());
+
+ if (PackageObj)
+ {
+ Oid Id = PackageObj["id"sv].AsObjectId();
+ IoHash Hash = PackageObj["data"sv].AsBinaryAttachment();
+ Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+ ZEN_DEBUG("package data {} -> {}", Id, Hash);
+ }
+
+ for (CbFieldView& Entry : PackageDataArray)
+ {
+ CbObjectView PackageDataObj = Entry.AsObjectView();
+ Oid Id = PackageDataObj["id"sv].AsObjectId();
+ IoHash Hash = PackageDataObj["data"sv].AsBinaryAttachment();
+ Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+ ZEN_DEBUG("package {} -> {}", Id, Hash);
+ }
+
+ for (CbFieldView& Entry : BulkDataArray)
+ {
+ CbObjectView BulkObj = Entry.AsObjectView();
+ Oid Id = BulkObj["id"sv].AsObjectId();
+ IoHash Hash = BulkObj["data"sv].AsBinaryAttachment();
+ Result.Chunks.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+ ZEN_DEBUG("bulkdata {} -> {}", Id, Hash);
+ }
+
+ CbArrayView FilesArray = Core["files"sv].AsArrayView();
+ Result.Files.reserve(FilesArray.Num());
+ for (CbFieldView& Entry : FilesArray)
+ {
+ CbObjectView FileObj = Entry.AsObjectView();
+
+ std::string_view ServerPath = FileObj["serverpath"sv].AsString();
+ std::string_view ClientPath = FileObj["clientpath"sv].AsString();
+ if (ServerPath.empty() || ClientPath.empty())
+ {
+ ZEN_WARN("invalid file");
+ continue;
+ }
+
+ Oid Id = FileObj["id"sv].AsObjectId();
+ IoHash Hash = FileObj["data"sv].AsBinaryAttachment();
+ Result.Files.emplace_back(
+ OplogEntryMapping::FileMapping{OplogEntryMapping::Mapping{Id, Hash}, std::string(ServerPath), std::string(ClientPath)});
+ ZEN_DEBUG("file {} -> {}, ServerPath: {}, ClientPath: {}", Id, Hash, ServerPath, ClientPath);
+ }
+
+ CbArrayView MetaArray = Core["meta"sv].AsArrayView();
+ Result.Meta.reserve(MetaArray.Num());
+ for (CbFieldView& Entry : MetaArray)
+ {
+ CbObjectView MetaObj = Entry.AsObjectView();
+ Oid Id = MetaObj["id"sv].AsObjectId();
+ IoHash Hash = MetaObj["data"sv].AsBinaryAttachment();
+ Result.Meta.emplace_back(OplogEntryMapping::Mapping{Id, Hash});
+ auto NameString = MetaObj["name"sv].AsString();
+ ZEN_DEBUG("meta data ({}) {} -> {}", NameString, Id, Hash);
+ }
+
+ return Result;
+}
+
+uint32_t
+ProjectStore::Oplog::RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock,
+ const OplogEntryMapping& OpMapping,
+ const OplogEntry& OpEntry,
+ UpdateType TypeOfUpdate)
+{
+ ZEN_TRACE_CPU("ProjectStore::Oplog::RegisterOplogEntry");
+
+ ZEN_UNUSED(TypeOfUpdate);
+
+ // For now we're assuming the update is all in-memory so we can hold an exclusive lock without causing
+ // too many problems. Longer term we'll probably want to ensure we can do concurrent updates however
+
+ using namespace std::literals;
+
+ // Update chunk id maps
+ for (const OplogEntryMapping::Mapping& Chunk : OpMapping.Chunks)
+ {
+ AddChunkMapping(OplogLock, Chunk.Id, Chunk.Hash);
+ }
+
+ for (const OplogEntryMapping::FileMapping& File : OpMapping.Files)
+ {
+ AddFileMapping(OplogLock, File.Id, File.Hash, File.ServerPath, File.ClientPath);
+ }
+
+ for (const OplogEntryMapping::Mapping& Meta : OpMapping.Meta)
+ {
+ AddMetaMapping(OplogLock, Meta.Id, Meta.Hash);
+ }
+
+ m_OpAddressMap.emplace(OpEntry.OpLsn, OplogEntryAddress{.Offset = OpEntry.OpCoreOffset, .Size = OpEntry.OpCoreSize});
+ m_LatestOpMap[OpEntry.OpKeyAsOId()] = OpEntry.OpLsn;
+
+ return OpEntry.OpLsn;
+}
+
+uint32_t
+ProjectStore::Oplog::AppendNewOplogEntry(CbPackage OpPackage)
+{
+ ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry");
+
+ const CbObject& Core = OpPackage.GetObject();
+ const uint32_t EntryId = AppendNewOplogEntry(Core);
+ if (EntryId == 0xffffffffu)
+ {
+ // The oplog has been deleted so just drop this
+ return EntryId;
+ }
+
+ // Persist attachments after oplog entry so GC won't find attachments without references
+
+ uint64_t AttachmentBytes = 0;
+ uint64_t NewAttachmentBytes = 0;
+
+ auto Attachments = OpPackage.GetAttachments();
+
+ for (const auto& Attach : Attachments)
+ {
+ ZEN_ASSERT(Attach.IsCompressedBinary());
+
+ CompressedBuffer AttachmentData = Attach.AsCompressedBinary();
+ const uint64_t AttachmentSize = AttachmentData.DecodeRawSize();
+ CidStore::InsertResult InsertResult = m_CidStore.AddChunk(AttachmentData.GetCompressed().Flatten().AsIoBuffer(), Attach.GetHash());
+
+ if (InsertResult.New)
+ {
+ NewAttachmentBytes += AttachmentSize;
+ }
+ AttachmentBytes += AttachmentSize;
+ }
+
+ ZEN_DEBUG("oplog entry #{} attachments: {} new, {} total", EntryId, NiceBytes(NewAttachmentBytes), NiceBytes(AttachmentBytes));
+
+ return EntryId;
+}
+
+uint32_t
+ProjectStore::Oplog::AppendNewOplogEntry(CbObject Core)
+{
+ ZEN_TRACE_CPU("ProjectStore::Oplog::AppendNewOplogEntry");
+
+ using namespace std::literals;
+
+ OplogEntryMapping Mapping = GetMapping(Core);
+
+ SharedBuffer Buffer = Core.GetBuffer();
+ const uint64_t WriteSize = Buffer.GetSize();
+ const auto OpCoreHash = uint32_t(XXH3_64bits(Buffer.GetData(), WriteSize) & 0xffffFFFF);
+
+ ZEN_ASSERT(WriteSize != 0);
+
+ XXH3_128Stream KeyHasher;
+ Core["key"sv].WriteToStream([&](const void* Data, size_t Size) { KeyHasher.Append(Data, Size); });
+ XXH3_128 KeyHash = KeyHasher.GetHash();
+
+ RefPtr<OplogStorage> Storage;
+ {
+ RwLock::SharedLockScope _(m_OplogLock);
+ Storage = m_Storage;
+ }
+ if (!m_Storage)
+ {
+ return 0xffffffffu;
+ }
+ const OplogEntry OpEntry = m_Storage->AppendOp(Buffer, OpCoreHash, KeyHash);
+
+ RwLock::ExclusiveLockScope OplogLock(m_OplogLock);
+ const uint32_t EntryId = RegisterOplogEntry(OplogLock, Mapping, OpEntry, kUpdateNewEntry);
+
+ return EntryId;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ProjectStore::Project::Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath)
+: m_ProjectStore(PrjStore)
+, m_CidStore(Store)
+, m_OplogStoragePath(BasePath)
+{
+}
+
+ProjectStore::Project::~Project()
+{
+}
+
+bool
+ProjectStore::Project::Exists(std::filesystem::path BasePath)
+{
+ return std::filesystem::exists(BasePath / "Project.zcb");
+}
+
+void
+ProjectStore::Project::Read()
+{
+ using namespace std::literals;
+
+ std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv;
+
+ ZEN_INFO("reading config for project '{}' from {}", Identifier, ProjectStateFilePath);
+
+ BasicFile Blob;
+ Blob.Open(ProjectStateFilePath, BasicFile::Mode::kRead);
+
+ IoBuffer Obj = Blob.ReadAll();
+ CbValidateError ValidationError = ValidateCompactBinary(MemoryView(Obj.Data(), Obj.Size()), CbValidateMode::All);
+
+ if (ValidationError == CbValidateError::None)
+ {
+ CbObject Cfg = LoadCompactBinaryObject(Obj);
+
+ Identifier = Cfg["id"sv].AsString();
+ RootDir = Cfg["root"sv].AsString();
+ ProjectRootDir = Cfg["project"sv].AsString();
+ EngineRootDir = Cfg["engine"sv].AsString();
+ ProjectFilePath = Cfg["projectfile"sv].AsString();
+ }
+ else
+ {
+ ZEN_ERROR("validation error {} hit for '{}'", int(ValidationError), ProjectStateFilePath);
+ }
+}
+
+void
+ProjectStore::Project::Write()
+{
+ using namespace std::literals;
+
+ BinaryWriter Mem;
+
+ CbObjectWriter Cfg;
+ Cfg << "id"sv << Identifier;
+ Cfg << "root"sv << PathToUtf8(RootDir);
+ Cfg << "project"sv << ProjectRootDir;
+ Cfg << "engine"sv << EngineRootDir;
+ Cfg << "projectfile"sv << ProjectFilePath;
+
+ Cfg.Save(Mem);
+
+ CreateDirectories(m_OplogStoragePath);
+
+ std::filesystem::path ProjectStateFilePath = m_OplogStoragePath / "Project.zcb"sv;
+
+ ZEN_INFO("persisting config for project '{}' to {}", Identifier, ProjectStateFilePath);
+
+ BasicFile Blob;
+ Blob.Open(ProjectStateFilePath, BasicFile::Mode::kTruncate);
+ Blob.Write(Mem.Data(), Mem.Size(), 0);
+ Blob.Flush();
+}
+
+spdlog::logger&
+ProjectStore::Project::Log()
+{
+ return m_ProjectStore->Log();
+}
+
+std::filesystem::path
+ProjectStore::Project::BasePathForOplog(std::string_view OplogId)
+{
+ return m_OplogStoragePath / OplogId;
+}
+
+ProjectStore::Oplog*
+ProjectStore::Project::NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath)
+{
+ RwLock::ExclusiveLockScope _(m_ProjectLock);
+
+ std::filesystem::path OplogBasePath = BasePathForOplog(OplogId);
+
+ try
+ {
+ Oplog* Log = m_Oplogs
+ .try_emplace(std::string{OplogId},
+ std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, MarkerPath))
+ .first->second.get();
+
+ Log->Write();
+ return Log;
+ }
+ catch (std::exception&)
+ {
+ // In case of failure we need to ensure there's no half constructed entry around
+ //
+ // (This is probably already ensured by the try_emplace implementation?)
+
+ m_Oplogs.erase(std::string{OplogId});
+
+ return nullptr;
+ }
+}
+
+ProjectStore::Oplog*
+ProjectStore::Project::OpenOplog(std::string_view OplogId)
+{
+ {
+ RwLock::SharedLockScope _(m_ProjectLock);
+
+ auto OplogIt = m_Oplogs.find(std::string(OplogId));
+
+ if (OplogIt != m_Oplogs.end())
+ {
+ return OplogIt->second.get();
+ }
+ }
+
+ RwLock::ExclusiveLockScope _(m_ProjectLock);
+
+ std::filesystem::path OplogBasePath = BasePathForOplog(OplogId);
+
+ if (Oplog::ExistsAt(OplogBasePath))
+ {
+ // Do open of existing oplog
+
+ try
+ {
+ Oplog* Log =
+ m_Oplogs
+ .try_emplace(std::string{OplogId},
+ std::make_unique<ProjectStore::Oplog>(OplogId, this, m_CidStore, OplogBasePath, std::filesystem::path{}))
+ .first->second.get();
+ Log->Read();
+
+ return Log;
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_WARN("failed to open oplog '{}' @ '{}': {}", OplogId, OplogBasePath, ex.what());
+
+ m_Oplogs.erase(std::string{OplogId});
+ }
+ }
+
+ return nullptr;
+}
+
+void
+ProjectStore::Project::DeleteOplog(std::string_view OplogId)
+{
+ std::filesystem::path DeletePath;
+ {
+ RwLock::ExclusiveLockScope _(m_ProjectLock);
+
+ auto OplogIt = m_Oplogs.find(std::string(OplogId));
+
+ if (OplogIt != m_Oplogs.end())
+ {
+ std::unique_ptr<Oplog>& Oplog = OplogIt->second;
+ DeletePath = Oplog->PrepareForDelete(true);
+ m_DeletedOplogs.emplace_back(std::move(Oplog));
+ m_Oplogs.erase(OplogIt);
+ }
+ }
+
+ // Erase content on disk
+ if (!DeletePath.empty())
+ {
+ OplogStorage::Delete(DeletePath);
+ }
+}
+
+std::vector<std::string>
+ProjectStore::Project::ScanForOplogs() const
+{
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_OplogStoragePath, DirectoryContent::IncludeDirsFlag, DirContent);
+ std::vector<std::string> Oplogs;
+ Oplogs.reserve(DirContent.Directories.size());
+ for (const std::filesystem::path& DirPath : DirContent.Directories)
+ {
+ Oplogs.push_back(DirPath.filename().string());
+ }
+ return Oplogs;
+}
+
+void
+ProjectStore::Project::IterateOplogs(std::function<void(const Oplog&)>&& Fn) const
+{
+ RwLock::SharedLockScope _(m_ProjectLock);
+
+ for (auto& Kv : m_Oplogs)
+ {
+ Fn(*Kv.second);
+ }
+}
+
+void
+ProjectStore::Project::IterateOplogs(std::function<void(Oplog&)>&& Fn)
+{
+ RwLock::SharedLockScope _(m_ProjectLock);
+
+ for (auto& Kv : m_Oplogs)
+ {
+ Fn(*Kv.second);
+ }
+}
+
+void
+ProjectStore::Project::Flush()
+{
+ // We only need to flush oplogs that we have already loaded
+ IterateOplogs([&](Oplog& Ops) { Ops.Flush(); });
+}
+
+void
+ProjectStore::Project::Scrub(ScrubContext& Ctx)
+{
+ // Scrubbing needs to check all existing oplogs
+ std::vector<std::string> OpLogs = ScanForOplogs();
+ for (const std::string& OpLogId : OpLogs)
+ {
+ OpenOplog(OpLogId);
+ }
+ IterateOplogs([&](const Oplog& Ops) {
+ if (!Ops.IsExpired())
+ {
+ Ops.Scrub(Ctx);
+ }
+ });
+}
+
+void
+ProjectStore::Project::GatherReferences(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("ProjectStore::Project::GatherReferences");
+
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] {
+ ZEN_DEBUG("gathered references from project store project {} in {}", Identifier, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ // GatherReferences needs to check all existing oplogs
+ std::vector<std::string> OpLogs = ScanForOplogs();
+ for (const std::string& OpLogId : OpLogs)
+ {
+ OpenOplog(OpLogId);
+ }
+ IterateOplogs([&](Oplog& Ops) {
+ if (!Ops.IsExpired())
+ {
+ Ops.GatherReferences(GcCtx);
+ }
+ });
+}
+
+uint64_t
+ProjectStore::Project::TotalSize() const
+{
+ uint64_t Result = 0;
+ {
+ RwLock::SharedLockScope _(m_ProjectLock);
+ for (const auto& It : m_Oplogs)
+ {
+ Result += It.second->TotalSize();
+ }
+ }
+ return Result;
+}
+
+bool
+ProjectStore::Project::PrepareForDelete(std::filesystem::path& OutDeletePath)
+{
+ RwLock::ExclusiveLockScope _(m_ProjectLock);
+
+ for (auto& It : m_Oplogs)
+ {
+ // We don't care about the moved folder
+ It.second->PrepareForDelete(false);
+ m_DeletedOplogs.emplace_back(std::move(It.second));
+ }
+
+ m_Oplogs.clear();
+
+ bool Success = PrepareDirectoryDelete(m_OplogStoragePath, OutDeletePath);
+ if (!Success)
+ {
+ return false;
+ }
+ m_OplogStoragePath.clear();
+ return true;
+}
+
+bool
+ProjectStore::Project::IsExpired() const
+{
+ if (ProjectFilePath.empty())
+ {
+ return false;
+ }
+ return !std::filesystem::exists(ProjectFilePath);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+ProjectStore::ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc)
+: GcStorage(Gc)
+, GcContributor(Gc)
+, m_Log(logging::Get("project"))
+, m_CidStore(Store)
+, m_ProjectBasePath(BasePath)
+{
+ ZEN_INFO("initializing project store at '{}'", BasePath);
+ // m_Log.set_level(spdlog::level::debug);
+}
+
+ProjectStore::~ProjectStore()
+{
+ ZEN_INFO("closing project store ('{}')", m_ProjectBasePath);
+}
+
+std::filesystem::path
+ProjectStore::BasePathForProject(std::string_view ProjectId)
+{
+ return m_ProjectBasePath / ProjectId;
+}
+
+void
+ProjectStore::DiscoverProjects()
+{
+ if (!std::filesystem::exists(m_ProjectBasePath))
+ {
+ return;
+ }
+
+ DirectoryContent DirContent;
+ GetDirectoryContent(m_ProjectBasePath, DirectoryContent::IncludeDirsFlag, DirContent);
+
+ for (const std::filesystem::path& DirPath : DirContent.Directories)
+ {
+ std::string DirName = PathToUtf8(DirPath.filename());
+ OpenProject(DirName);
+ }
+}
+
+void
+ProjectStore::IterateProjects(std::function<void(Project& Prj)>&& Fn)
+{
+ RwLock::SharedLockScope _(m_ProjectsLock);
+
+ for (auto& Kv : m_Projects)
+ {
+ Fn(*Kv.second.Get());
+ }
+}
+
+void
+ProjectStore::Flush()
+{
+ std::vector<Ref<Project>> Projects;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ Projects.reserve(m_Projects.size());
+
+ for (auto& Kv : m_Projects)
+ {
+ Projects.push_back(Kv.second);
+ }
+ }
+ for (const Ref<Project>& Project : Projects)
+ {
+ Project->Flush();
+ }
+}
+
+void
+ProjectStore::Scrub(ScrubContext& Ctx)
+{
+ DiscoverProjects();
+
+ std::vector<Ref<Project>> Projects;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ Projects.reserve(m_Projects.size());
+
+ for (auto& Kv : m_Projects)
+ {
+ if (Kv.second->IsExpired())
+ {
+ continue;
+ }
+ Projects.push_back(Kv.second);
+ }
+ }
+ for (const Ref<Project>& Project : Projects)
+ {
+ Project->Scrub(Ctx);
+ }
+}
+
+void
+ProjectStore::GatherReferences(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("ProjectStore::GatherReferences");
+
+ size_t ProjectCount = 0;
+ size_t ExpiredProjectCount = 0;
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] {
+ ZEN_DEBUG("gathered references from '{}' in {}, found {} active projects and {} expired projects",
+ m_ProjectBasePath.string(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
+ ProjectCount,
+ ExpiredProjectCount);
+ });
+
+ DiscoverProjects();
+
+ std::vector<Ref<Project>> Projects;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ Projects.reserve(m_Projects.size());
+
+ for (auto& Kv : m_Projects)
+ {
+ if (Kv.second->IsExpired())
+ {
+ ExpiredProjectCount++;
+ continue;
+ }
+ Projects.push_back(Kv.second);
+ }
+ }
+ ProjectCount = Projects.size();
+ for (const Ref<Project>& Project : Projects)
+ {
+ Project->GatherReferences(GcCtx);
+ }
+}
+
+void
+ProjectStore::CollectGarbage(GcContext& GcCtx)
+{
+ ZEN_TRACE_CPU("ProjectStore::CollectGarbage");
+
+ size_t ProjectCount = 0;
+ size_t ExpiredProjectCount = 0;
+
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] {
+ ZEN_DEBUG("garbage collect from '{}' DONE after {}, found {} active projects and {} expired projects",
+ m_ProjectBasePath.string(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
+ ProjectCount,
+ ExpiredProjectCount);
+ });
+ std::vector<Ref<Project>> ExpiredProjects;
+ std::vector<Ref<Project>> Projects;
+
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ for (auto& Kv : m_Projects)
+ {
+ if (Kv.second->IsExpired())
+ {
+ ExpiredProjects.push_back(Kv.second);
+ ExpiredProjectCount++;
+ continue;
+ }
+ Projects.push_back(Kv.second);
+ ProjectCount++;
+ }
+ }
+
+ if (!GcCtx.IsDeletionMode())
+ {
+ ZEN_DEBUG("garbage collect DISABLED, for '{}' ", m_ProjectBasePath.string());
+ return;
+ }
+
+ for (const Ref<Project>& Project : Projects)
+ {
+ std::vector<std::string> ExpiredOplogs;
+ {
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+ Project->IterateOplogs([&ExpiredOplogs](ProjectStore::Oplog& Oplog) {
+ if (Oplog.IsExpired())
+ {
+ ExpiredOplogs.push_back(Oplog.OplogId());
+ }
+ });
+ }
+ for (const std::string& OplogId : ExpiredOplogs)
+ {
+ ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected oplog '{}' in project '{}'. Removing storage on disk",
+ OplogId,
+ Project->Identifier);
+ Project->DeleteOplog(OplogId);
+ }
+ }
+
+ if (ExpiredProjects.empty())
+ {
+ ZEN_DEBUG("garbage collect for '{}', no expired projects found", m_ProjectBasePath.string());
+ return;
+ }
+
+ for (const Ref<Project>& Project : ExpiredProjects)
+ {
+ std::filesystem::path PathToRemove;
+ std::string ProjectId;
+ {
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+ if (!Project->IsExpired())
+ {
+ ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. Project no longer expired.", ProjectId);
+ continue;
+ }
+ bool Success = Project->PrepareForDelete(PathToRemove);
+ if (!Success)
+ {
+ ZEN_DEBUG("ProjectStore::CollectGarbage skipped garbage collect of project '{}'. Project folder is locked.", ProjectId);
+ continue;
+ }
+ m_Projects.erase(Project->Identifier);
+ ProjectId = Project->Identifier;
+ }
+
+ ZEN_DEBUG("ProjectStore::CollectGarbage garbage collected project '{}'. Removing storage on disk", ProjectId);
+ if (PathToRemove.empty())
+ {
+ continue;
+ }
+
+ DeleteDirectories(PathToRemove);
+ }
+}
+
+GcStorageSize
+ProjectStore::StorageSize() const
+{
+ GcStorageSize Result;
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+ for (auto& Kv : m_Projects)
+ {
+ const Ref<Project>& Project = Kv.second;
+ Result.DiskSize += Project->TotalSize();
+ }
+ }
+ return Result;
+}
+
+Ref<ProjectStore::Project>
+ProjectStore::OpenProject(std::string_view ProjectId)
+{
+ {
+ RwLock::SharedLockScope _(m_ProjectsLock);
+
+ auto ProjIt = m_Projects.find(std::string{ProjectId});
+
+ if (ProjIt != m_Projects.end())
+ {
+ return ProjIt->second;
+ }
+ }
+
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+
+ std::filesystem::path BasePath = BasePathForProject(ProjectId);
+
+ if (Project::Exists(BasePath))
+ {
+ try
+ {
+ ZEN_INFO("opening project {} @ {}", ProjectId, BasePath);
+
+ Ref<Project>& Prj =
+ m_Projects
+ .try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath)))
+ .first->second;
+ Prj->Identifier = ProjectId;
+ Prj->Read();
+ return Prj;
+ }
+ catch (std::exception& e)
+ {
+ ZEN_WARN("failed to open {} @ {} ({})", ProjectId, BasePath, e.what());
+ m_Projects.erase(std::string{ProjectId});
+ }
+ }
+
+ return {};
+}
+
+Ref<ProjectStore::Project>
+ProjectStore::NewProject(std::filesystem::path BasePath,
+ std::string_view ProjectId,
+ std::string_view RootDir,
+ std::string_view EngineRootDir,
+ std::string_view ProjectRootDir,
+ std::string_view ProjectFilePath)
+{
+ RwLock::ExclusiveLockScope _(m_ProjectsLock);
+
+ Ref<Project>& Prj =
+ m_Projects.try_emplace(std::string{ProjectId}, Ref<ProjectStore::Project>(new ProjectStore::Project(this, m_CidStore, BasePath)))
+ .first->second;
+ Prj->Identifier = ProjectId;
+ Prj->RootDir = RootDir;
+ Prj->EngineRootDir = EngineRootDir;
+ Prj->ProjectRootDir = ProjectRootDir;
+ Prj->ProjectFilePath = ProjectFilePath;
+ Prj->Write();
+
+ return Prj;
+}
+
+bool
+ProjectStore::DeleteProject(std::string_view ProjectId)
+{
+ ZEN_INFO("deleting project {}", ProjectId);
+
+ RwLock::ExclusiveLockScope ProjectsLock(m_ProjectsLock);
+
+ auto ProjIt = m_Projects.find(std::string{ProjectId});
+
+ if (ProjIt == m_Projects.end())
+ {
+ return true;
+ }
+
+ std::filesystem::path DeletePath;
+ bool Success = ProjIt->second->PrepareForDelete(DeletePath);
+
+ if (!Success)
+ {
+ return false;
+ }
+ m_Projects.erase(ProjIt);
+ ProjectsLock.ReleaseNow();
+
+ if (!DeletePath.empty())
+ {
+ DeleteDirectories(DeletePath);
+ }
+ return true;
+}
+
+bool
+ProjectStore::Exists(std::string_view ProjectId)
+{
+ return Project::Exists(BasePathForProject(ProjectId));
+}
+
+CbArray
+ProjectStore::GetProjectsList()
+{
+ using namespace std::literals;
+
+ DiscoverProjects();
+
+ CbWriter Response;
+ Response.BeginArray();
+
+ IterateProjects([&Response](ProjectStore::Project& Prj) {
+ Response.BeginObject();
+ Response << "Id"sv << Prj.Identifier;
+ Response << "RootDir"sv << Prj.RootDir.string();
+ Response << "ProjectRootDir"sv << Prj.ProjectRootDir;
+ Response << "EngineRootDir"sv << Prj.EngineRootDir;
+ Response << "ProjectFilePath"sv << Prj.ProjectFilePath;
+ Response.EndObject();
+ });
+ Response.EndArray();
+ return Response.Save().AsArray();
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::GetProjectFiles(const std::string_view ProjectId, const std::string_view OplogId, bool FilterClient, CbObject& OutPayload)
+{
+ using namespace std::literals;
+
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Project files request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Project files for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ CbObjectWriter Response;
+ Response.BeginArray("files"sv);
+
+ FoundLog->IterateFileMap([&](const Oid& Id, const std::string_view& ServerPath, const std::string_view& ClientPath) {
+ Response.BeginObject();
+ Response << "id"sv << Id;
+ Response << "clientpath"sv << ClientPath;
+ if (!FilterClient)
+ {
+ Response << "serverpath"sv << ServerPath;
+ }
+ Response.EndObject();
+ });
+
+ Response.EndArray();
+ OutPayload = Response.Save();
+ return {HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::GetChunkInfo(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view ChunkId,
+ CbObject& OutPayload)
+{
+ using namespace std::literals;
+
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk info request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+ if (ChunkId.size() != 2 * sizeof(Oid::OidBits))
+ {
+ return {HttpResponseCode::BadRequest,
+ fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)};
+ }
+
+ const Oid Obj = Oid::FromHexString(ChunkId);
+
+ IoBuffer Chunk = FoundLog->FindChunk(Obj);
+ if (!Chunk)
+ {
+ return {HttpResponseCode::NotFound, {}};
+ }
+
+ uint64_t ChunkSize = Chunk.GetSize();
+ if (Chunk.GetContentType() == HttpContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ bool IsCompressed = CompressedBuffer::ValidateCompressedHeader(Chunk, RawHash, RawSize);
+ ZEN_ASSERT(IsCompressed);
+ ChunkSize = RawSize;
+ }
+
+ CbObjectWriter Response;
+ Response << "size"sv << ChunkSize;
+ OutPayload = Response.Save();
+ return {HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::GetChunkRange(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view ChunkId,
+ uint64_t Offset,
+ uint64_t Size,
+ ZenContentType AcceptType,
+ IoBuffer& OutChunk)
+{
+ bool IsOffset = Offset != 0 || Size != ~(0ull);
+
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ if (ChunkId.size() != 2 * sizeof(Oid::OidBits))
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId)};
+ }
+
+ const Oid Obj = Oid::FromHexString(ChunkId);
+
+ IoBuffer Chunk = FoundLog->FindChunk(Obj);
+ if (!Chunk)
+ {
+ return {HttpResponseCode::NotFound, {}};
+ }
+
+ OutChunk = Chunk;
+ HttpContentType ContentType = Chunk.GetContentType();
+
+ if (Chunk.GetContentType() == HttpContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(std::move(Chunk)), RawHash, RawSize);
+ ZEN_ASSERT(!Compressed.IsNull());
+
+ if (IsOffset)
+ {
+ if ((Offset + Size) > RawSize)
+ {
+ Size = RawSize - Offset;
+ }
+
+ if (AcceptType == HttpContentType::kBinary)
+ {
+ OutChunk = Compressed.Decompress(Offset, Size).AsIoBuffer();
+ OutChunk.SetContentType(HttpContentType::kBinary);
+ }
+ else
+ {
+ // Value will be a range of compressed blocks that covers the requested range
+ // The client will have to compensate for any offsets that do not land on an even block size multiple
+ OutChunk = Compressed.CopyRange(Offset, Size).GetCompressed().Flatten().AsIoBuffer();
+ OutChunk.SetContentType(HttpContentType::kCompressedBinary);
+ }
+ }
+ else
+ {
+ if (AcceptType == HttpContentType::kBinary)
+ {
+ OutChunk = Compressed.Decompress().AsIoBuffer();
+ OutChunk.SetContentType(HttpContentType::kBinary);
+ }
+ else
+ {
+ OutChunk = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ OutChunk.SetContentType(HttpContentType::kCompressedBinary);
+ }
+ }
+ }
+ else if (IsOffset)
+ {
+ if ((Offset + Size) > Chunk.GetSize())
+ {
+ Size = Chunk.GetSize() - Offset;
+ }
+ OutChunk = IoBuffer(std::move(Chunk), Offset, Size);
+ OutChunk.SetContentType(ContentType);
+ }
+
+ return {HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::GetChunk(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view Cid,
+ ZenContentType AcceptType,
+ IoBuffer& OutChunk)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ if (Cid.length() != IoHash::StringLength)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, Cid)};
+ }
+
+ const IoHash Hash = IoHash::FromHexString(Cid);
+ OutChunk = m_CidStore.FindChunkByCid(Hash);
+
+ if (!OutChunk)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("chunk - '{}' MISSING", Cid)};
+ }
+
+ if (AcceptType == ZenContentType::kUnknownContentType || AcceptType == ZenContentType::kBinary)
+ {
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(OutChunk));
+ OutChunk = Compressed.Decompress().AsIoBuffer();
+ OutChunk.SetContentType(ZenContentType::kBinary);
+ }
+ else
+ {
+ OutChunk.SetContentType(ZenContentType::kCompressedBinary);
+ }
+ return {HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::PutChunk(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view Cid,
+ ZenContentType ContentType,
+ IoBuffer&& Chunk)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Chunk put request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ if (Cid.length() != IoHash::StringLength)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk put request for invalid chunk hash '{}'", Cid)};
+ }
+
+ const IoHash Hash = IoHash::FromHexString(Cid);
+
+ if (ContentType != HttpContentType::kCompressedBinary)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid content type for chunk '{}'", Cid)};
+ }
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize);
+ if (RawHash != Hash)
+ {
+ return {HttpResponseCode::BadRequest, fmt::format("Chunk request for invalid payload format for chunk '{}'", Cid)};
+ }
+
+ CidStore::InsertResult Result = m_CidStore.AddChunk(Chunk, Hash);
+ return {Result.New ? HttpResponseCode::Created : HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::WriteOplog(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload, CbObject& OutResponse)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);
+
+ if (!Oplog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Write oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ CbObject ContainerObject = LoadCompactBinaryObject(Payload);
+ if (!ContainerObject)
+ {
+ return {HttpResponseCode::BadRequest, "Invalid payload format"};
+ }
+
+ CidStore& ChunkStore = m_CidStore;
+ RwLock AttachmentsLock;
+ std::unordered_set<IoHash, IoHash::Hasher> Attachments;
+
+ auto HasAttachment = [&ChunkStore](const IoHash& RawHash) { return ChunkStore.ContainsChunk(RawHash); };
+ auto OnNeedBlock = [&AttachmentsLock, &Attachments](const IoHash& BlockHash, const std::vector<IoHash>&& ChunkHashes) {
+ RwLock::ExclusiveLockScope _(AttachmentsLock);
+ if (BlockHash != IoHash::Zero)
+ {
+ Attachments.insert(BlockHash);
+ }
+ else
+ {
+ Attachments.insert(ChunkHashes.begin(), ChunkHashes.end());
+ }
+ };
+ auto OnNeedAttachment = [&AttachmentsLock, &Attachments](const IoHash& RawHash) {
+ RwLock::ExclusiveLockScope _(AttachmentsLock);
+ Attachments.insert(RawHash);
+ };
+
+ RemoteProjectStore::Result RemoteResult = SaveOplogContainer(*Oplog, ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment);
+
+ if (RemoteResult.ErrorCode)
+ {
+ return ConvertResult(RemoteResult);
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+ {
+ for (const IoHash& Hash : Attachments)
+ {
+ ZEN_DEBUG("Need attachment {}", Hash);
+ Cbo << Hash;
+ }
+ }
+ Cbo.EndArray(); // "need"
+
+ OutResponse = Cbo.Save();
+ return {HttpResponseCode::OK, {}};
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::ReadOplog(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const HttpServerRequest::QueryParams& Params,
+ CbObject& OutResponse)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);
+
+ if (!Oplog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Read oplog request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ size_t MaxBlockSize = 128u * 1024u * 1024u;
+ if (auto Param = Params.GetValue("maxblocksize"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<size_t>(Param))
+ {
+ MaxBlockSize = Value.value();
+ }
+ }
+ size_t MaxChunkEmbedSize = 1024u * 1024u;
+ if (auto Param = Params.GetValue("maxchunkembedsize"); Param.empty() == false)
+ {
+ if (auto Value = ParseInt<size_t>(Param))
+ {
+ MaxChunkEmbedSize = Value.value();
+ }
+ }
+
+ CidStore& ChunkStore = m_CidStore;
+
+ RemoteProjectStore::LoadContainerResult ContainerResult = BuildContainer(
+ ChunkStore,
+ *Oplog,
+ MaxBlockSize,
+ MaxChunkEmbedSize,
+ false,
+ [](CompressedBuffer&&, const IoHash) {},
+ [](const IoHash&) {},
+ [](const std::unordered_set<IoHash, IoHash::Hasher>) {});
+
+ OutResponse = std::move(ContainerResult.ContainerObject);
+ return ConvertResult(ContainerResult);
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::WriteBlock(const std::string_view ProjectId, const std::string_view OplogId, IoBuffer&& Payload)
+{
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Write block request for unknown project '{}'", ProjectId)};
+ }
+
+ ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);
+
+ if (!Oplog)
+ {
+ return {HttpResponseCode::NotFound, fmt::format("Write block request for unknown oplog '{}/{}'", ProjectId, OplogId)};
+ }
+
+ if (!IterateBlock(std::move(Payload), [this](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) {
+ IoBuffer Compressed = Chunk.GetCompressed().Flatten().AsIoBuffer();
+ m_CidStore.AddChunk(Compressed, AttachmentRawHash);
+ ZEN_DEBUG("Saved attachment {} from block, size {}", AttachmentRawHash, Compressed.GetSize());
+ }))
+ {
+ return {HttpResponseCode::BadRequest, "Invalid chunk in block"};
+ }
+
+ return {HttpResponseCode::OK, {}};
+}
+
+void
+ProjectStore::Rpc(HttpServerRequest& HttpReq,
+ const std::string_view ProjectId,
+ const std::string_view OplogId,
+ IoBuffer&& Payload,
+ AuthMgr& AuthManager)
+{
+ using namespace std::literals;
+ HttpContentType PayloadContentType = HttpReq.RequestContentType();
+ CbPackage Package;
+ CbObject Cb;
+ switch (PayloadContentType)
+ {
+ case HttpContentType::kJSON:
+ case HttpContentType::kUnknownContentType:
+ case HttpContentType::kText:
+ {
+ std::string JsonText(reinterpret_cast<const char*>(Payload.GetData()), Payload.GetSize());
+ Cb = LoadCompactBinaryFromJson(JsonText).AsObject();
+ if (!Cb)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Content format not supported, expected JSON format");
+ }
+ }
+ break;
+ case HttpContentType::kCbObject:
+ Cb = LoadCompactBinaryObject(Payload);
+ if (!Cb)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Content format not supported, expected compact binary format");
+ }
+ break;
+ case HttpContentType::kCbPackage:
+ Package = ParsePackageMessage(Payload);
+ Cb = Package.GetObject();
+ if (!Cb)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ "Content format not supported, expected package message format");
+ }
+ break;
+ default:
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid request content type");
+ }
+
+ Ref<ProjectStore::Project> Project = OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("Rpc oplog request for unknown project '{}'", ProjectId));
+ }
+
+ ProjectStore::Oplog* Oplog = Project->OpenOplog(OplogId);
+
+ if (!Oplog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("Rpc oplog request for unknown oplog '{}/{}'", ProjectId, OplogId));
+ }
+
+ std::string_view Method = Cb["method"sv].AsString();
+
+ if (Method == "import")
+ {
+ std::pair<HttpResponseCode, std::string> Result = Import(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager);
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ }
+ else if (Method == "export")
+ {
+ std::pair<HttpResponseCode, std::string> Result = Export(*Project.Get(), *Oplog, Cb["params"sv].AsObjectView(), AuthManager);
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ }
+ else if (Method == "getchunks")
+ {
+ CbPackage ResponsePackage;
+ {
+ CbArrayView ChunksArray = Cb["chunks"sv].AsArrayView();
+ CbObjectWriter ResponseWriter;
+ ResponseWriter.BeginArray("chunks"sv);
+ for (CbFieldView FieldView : ChunksArray)
+ {
+ IoHash RawHash = FieldView.AsHash();
+ IoBuffer ChunkBuffer = m_CidStore.FindChunkByCid(RawHash);
+ if (ChunkBuffer)
+ {
+ ResponseWriter.AddHash(RawHash);
+ ResponsePackage.AddAttachment(
+ CbAttachment(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkBuffer)), RawHash));
+ }
+ }
+ ResponseWriter.EndArray();
+ ResponsePackage.SetObject(ResponseWriter.Save());
+ }
+ CompositeBuffer RpcResponseBuffer = FormatPackageMessageBuffer(ResponsePackage, FormatFlags::kDefault);
+ return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kCbPackage, RpcResponseBuffer);
+ }
+ else if (Method == "putchunks")
+ {
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ IoHash RawHash = Attachment.GetHash();
+ CompressedBuffer Compressed = Attachment.AsCompressedBinary();
+ m_CidStore.AddChunk(Compressed.GetCompressed().Flatten().AsIoBuffer(), RawHash, CidStore::InsertMode::kCopyOnly);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, fmt::format("Unknown rpc method '{}'", Method));
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::Export(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager)
+{
+ using namespace std::literals;
+
+ size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u);
+ size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u);
+ bool Force = Params["force"sv].AsBool(false);
+
+ std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult =
+ CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize);
+
+ if (RemoteStoreResult.first == nullptr)
+ {
+ return {HttpResponseCode::BadRequest, RemoteStoreResult.second};
+ }
+ std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first);
+ RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo();
+
+ ZEN_INFO("Saving oplog '{}/{}' to {}, maxblocksize {}, maxchunkembedsize {}",
+ Project.Identifier,
+ Oplog.OplogId(),
+ StoreInfo.Description,
+ NiceBytes(MaxBlockSize),
+ NiceBytes(MaxChunkEmbedSize));
+
+ RemoteProjectStore::Result Result = SaveOplog(m_CidStore,
+ *RemoteStore,
+ Oplog,
+ MaxBlockSize,
+ MaxChunkEmbedSize,
+ StoreInfo.CreateBlocks,
+ StoreInfo.UseTempBlockFiles,
+ Force);
+
+ return ConvertResult(Result);
+}
+
+std::pair<HttpResponseCode, std::string>
+ProjectStore::Import(ProjectStore::Project& Project, ProjectStore::Oplog& Oplog, CbObjectView&& Params, AuthMgr& AuthManager)
+{
+ using namespace std::literals;
+
+ size_t MaxBlockSize = Params["maxblocksize"sv].AsUInt64(128u * 1024u * 1024u);
+ size_t MaxChunkEmbedSize = Params["maxchunkembedsize"sv].AsUInt64(1024u * 1024u);
+ bool Force = Params["force"sv].AsBool(false);
+
+ std::pair<std::unique_ptr<RemoteProjectStore>, std::string> RemoteStoreResult =
+ CreateRemoteStore(Params, AuthManager, MaxBlockSize, MaxChunkEmbedSize);
+
+ if (RemoteStoreResult.first == nullptr)
+ {
+ return {HttpResponseCode::BadRequest, RemoteStoreResult.second};
+ }
+ std::unique_ptr<RemoteProjectStore> RemoteStore = std::move(RemoteStoreResult.first);
+ RemoteProjectStore::RemoteStoreInfo StoreInfo = RemoteStore->GetInfo();
+
+ ZEN_INFO("Loading oplog '{}/{}' from {}", Project.Identifier, Oplog.OplogId(), StoreInfo.Description);
+ RemoteProjectStore::Result Result = LoadOplog(m_CidStore, *RemoteStore, Oplog, Force);
+ return ConvertResult(Result);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpProjectService::HttpProjectService(CidStore& Store, ProjectStore* Projects, HttpStatsService& StatsService, AuthMgr& AuthMgr)
+: m_Log(logging::Get("project"))
+, m_CidStore(Store)
+, m_ProjectStore(Projects)
+, m_StatsService(StatsService)
+, m_AuthMgr(AuthMgr)
+{
+ using namespace std::literals;
+
+ m_StatsService.RegisterHandler("prj", *this);
+
+ m_Router.AddPattern("project", "([[:alnum:]_.]+)");
+ m_Router.AddPattern("log", "([[:alnum:]_.]+)");
+ m_Router.AddPattern("op", "([[:digit:]]+?)");
+ m_Router.AddPattern("chunk", "([[:xdigit:]]{24})");
+ m_Router.AddPattern("hash", "([[:xdigit:]]{40})");
+
+ m_Router.RegisterRoute(
+ "",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "list",
+ [this](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK, m_ProjectStore->GetProjectsList()); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/batch",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ // Parse Request
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ BinaryReader Reader(Payload);
+
+ struct RequestHeader
+ {
+ enum
+ {
+ kMagic = 0xAAAA'77AC
+ };
+ uint32_t Magic;
+ uint32_t ChunkCount;
+ uint32_t Reserved1;
+ uint32_t Reserved2;
+ };
+
+ struct RequestChunkEntry
+ {
+ Oid ChunkId;
+ uint32_t CorrelationId;
+ uint64_t Offset;
+ uint64_t RequestBytes;
+ };
+
+ if (Payload.Size() <= sizeof(RequestHeader))
+ {
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ RequestHeader RequestHdr;
+ Reader.Read(&RequestHdr, sizeof RequestHdr);
+
+ if (RequestHdr.Magic != RequestHeader::kMagic)
+ {
+ HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ std::vector<RequestChunkEntry> RequestedChunks;
+ RequestedChunks.resize(RequestHdr.ChunkCount);
+ Reader.Read(RequestedChunks.data(), sizeof(RequestChunkEntry) * RequestHdr.ChunkCount);
+
+ // Make Response
+
+ struct ResponseHeader
+ {
+ uint32_t Magic = 0xbada'b00f;
+ uint32_t ChunkCount;
+ uint32_t Reserved1 = 0;
+ uint32_t Reserved2 = 0;
+ };
+
+ struct ResponseChunkEntry
+ {
+ uint32_t CorrelationId;
+ uint32_t Flags = 0;
+ uint64_t ChunkSize;
+ };
+
+ std::vector<IoBuffer> OutBlobs;
+ OutBlobs.emplace_back(sizeof(ResponseHeader) + RequestHdr.ChunkCount * sizeof(ResponseChunkEntry));
+ for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex)
+ {
+ const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex];
+ IoBuffer FoundChunk = FoundLog->FindChunk(RequestedChunk.ChunkId);
+ if (FoundChunk)
+ {
+ if (RequestedChunk.Offset > 0 || RequestedChunk.RequestBytes < uint64_t(-1))
+ {
+ uint64_t Offset = RequestedChunk.Offset;
+ if (Offset > FoundChunk.Size())
+ {
+ Offset = FoundChunk.Size();
+ }
+ uint64_t Size = RequestedChunk.RequestBytes;
+ if ((Offset + Size) > FoundChunk.Size())
+ {
+ Size = FoundChunk.Size() - Offset;
+ }
+ FoundChunk = IoBuffer(FoundChunk, Offset, Size);
+ }
+ }
+ OutBlobs.emplace_back(std::move(FoundChunk));
+ }
+ uint8_t* ResponsePtr = reinterpret_cast<uint8_t*>(OutBlobs[0].MutableData());
+ ResponseHeader ResponseHdr;
+ ResponseHdr.ChunkCount = RequestHdr.ChunkCount;
+ memcpy(ResponsePtr, &ResponseHdr, sizeof(ResponseHdr));
+ ResponsePtr += sizeof(ResponseHdr);
+ for (uint32_t ChunkIndex = 0; ChunkIndex < RequestHdr.ChunkCount; ++ChunkIndex)
+ {
+ // const RequestChunkEntry& RequestedChunk = RequestedChunks[ChunkIndex];
+ const IoBuffer& FoundChunk(OutBlobs[ChunkIndex + 1]);
+ ResponseChunkEntry ResponseChunk;
+ ResponseChunk.CorrelationId = ChunkIndex;
+ if (FoundChunk)
+ {
+ ResponseChunk.ChunkSize = FoundChunk.Size();
+ }
+ else
+ {
+ ResponseChunk.ChunkSize = uint64_t(-1);
+ }
+ memcpy(ResponsePtr, &ResponseChunk, sizeof(ResponseChunk));
+ ResponsePtr += sizeof(ResponseChunk);
+ }
+ return HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, OutBlobs);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/files",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ // File manifest fetch, returns the client file list
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+
+ const bool FilterClient = Params.GetValue("filter"sv) == "client"sv;
+
+ CbObject ResponsePayload;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->GetProjectFiles(ProjectId, OplogId, FilterClient, ResponsePayload);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{chunk}/info",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& ChunkId = Req.GetCapture(3);
+
+ CbObject ResponsePayload;
+ std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->GetChunkInfo(ProjectId, OplogId, ChunkId, ResponsePayload);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, ResponsePayload);
+ }
+ else if (Result.first == HttpResponseCode::NotFound)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{chunk}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& ChunkId = Req.GetCapture(3);
+
+ uint64_t Offset = 0;
+ uint64_t Size = ~(0ull);
+
+ auto QueryParms = Req.ServerRequest().GetQueryParams();
+
+ if (auto OffsetParm = QueryParms.GetValue("offset"); OffsetParm.empty() == false)
+ {
+ if (auto OffsetVal = ParseInt<uint64_t>(OffsetParm))
+ {
+ Offset = OffsetVal.value();
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ }
+
+ if (auto SizeParm = QueryParms.GetValue("size"); SizeParm.empty() == false)
+ {
+ if (auto SizeVal = ParseInt<uint64_t>(SizeParm))
+ {
+ Size = SizeVal.value();
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+ }
+
+ HttpContentType AcceptType = HttpReq.AcceptContentType();
+
+ IoBuffer Chunk;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->GetChunkRange(ProjectId, OplogId, ChunkId, Offset, Size, AcceptType, Chunk);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' '{}'", ProjectId, OplogId, ChunkId, ToString(Chunk.GetContentType()));
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Chunk.GetContentType(), Chunk);
+ }
+ else if (Result.first == HttpResponseCode::NotFound)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, ChunkId);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet | HttpVerb::kHead);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{hash}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& Cid = Req.GetCapture(3);
+ HttpContentType AcceptType = HttpReq.AcceptContentType();
+ HttpContentType RequestType = HttpReq.RequestContentType();
+
+ switch (Req.ServerRequest().RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ IoBuffer Value;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->GetChunk(ProjectId, OplogId, Cid, AcceptType, Value);
+
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Value.GetContentType(), Value);
+ }
+ else if (Result.first == HttpResponseCode::NotFound)
+ {
+ ZEN_DEBUG("chunk - '{}/{}/{}' MISSING", ProjectId, OplogId, Cid);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ }
+ case HttpVerb::kPost:
+ {
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->PutChunk(ProjectId, OplogId, Cid, RequestType, HttpReq.ReadPayload());
+ if (Result.first == HttpResponseCode::OK || Result.first == HttpResponseCode::Created)
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ else
+ {
+ ZEN_DEBUG("Request {}: '{}' failed with {}. Reason: `{}`",
+ ToString(HttpReq.RequestVerb()),
+ HttpReq.QueryString(),
+ static_cast<int>(Result.first),
+ Result.second);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ }
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/prep",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ // This operation takes a list of referenced hashes and decides which
+ // chunks are not present on this server. This list is then returned in
+ // the "need" list in the response
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+ CbObject RequestObject = LoadCompactBinaryObject(Payload);
+
+ std::vector<IoHash> NeedList;
+
+ for (auto Entry : RequestObject["have"sv])
+ {
+ const IoHash FileHash = Entry.AsHash();
+
+ if (!m_CidStore.ContainsChunk(FileHash))
+ {
+ ZEN_DEBUG("prep - NEED: {}", FileHash);
+
+ NeedList.push_back(FileHash);
+ }
+ }
+
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("need");
+
+ for (const IoHash& Hash : NeedList)
+ {
+ Cbo << Hash;
+ }
+
+ Cbo.EndArray();
+ CbObject Response = Cbo.Save();
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/new",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+
+ bool IsUsingSalt = false;
+ IoHash SaltHash = IoHash::Zero;
+
+ if (std::string_view SaltParam = Params.GetValue("salt"); SaltParam.empty() == false)
+ {
+ const uint32_t Salt = std::stoi(std::string(SaltParam));
+ SaltHash = IoHash::HashBuffer(&Salt, sizeof Salt);
+ IsUsingSalt = true;
+ }
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog& Oplog = *FoundLog;
+
+ IoBuffer Payload = HttpReq.ReadPayload();
+
+ // This will attempt to open files which may not exist for the case where
+ // the prep step rejected the chunk. This should be fixed since there's
+ // a performance cost associated with any file system activity
+
+ bool IsValid = true;
+ std::vector<IoHash> MissingChunks;
+
+ CbPackage::AttachmentResolver Resolver = [&](const IoHash& Hash) -> SharedBuffer {
+ if (m_CidStore.ContainsChunk(Hash))
+ {
+ // Return null attachment as we already have it, no point in reading it and storing it again
+ return {};
+ }
+
+ IoHash AttachmentId;
+ if (IsUsingSalt)
+ {
+ IoHash AttachmentSpec[]{SaltHash, Hash};
+ AttachmentId = IoHash::HashBuffer(MakeMemoryView(AttachmentSpec));
+ }
+ else
+ {
+ AttachmentId = Hash;
+ }
+
+ std::filesystem::path AttachmentPath = Oplog.TempPath() / AttachmentId.ToHexString();
+ if (IoBuffer Data = IoBufferBuilder::MakeFromTemporaryFile(AttachmentPath))
+ {
+ return SharedBuffer(std::move(Data));
+ }
+ else
+ {
+ IsValid = false;
+ MissingChunks.push_back(Hash);
+
+ return {};
+ }
+ };
+
+ CbPackage Package;
+
+ if (!legacy::TryLoadCbPackage(Package, Payload, &UniqueBuffer::Alloc, &Resolver))
+ {
+ std::filesystem::path BadPackagePath =
+ Oplog.TempPath() / "bad_packages"sv / fmt::format("session{}_request{}"sv, HttpReq.SessionId(), HttpReq.RequestId());
+
+ ZEN_WARN("Received malformed package! Saving payload to '{}'", BadPackagePath);
+
+ WriteFile(BadPackagePath, Payload);
+
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid package");
+ }
+
+ if (!IsValid)
+ {
+ // TODO: emit diagnostics identifying missing chunks
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound, HttpContentType::kText, "Missing chunk reference");
+ }
+
+ CbObject Core = Package.GetObject();
+
+ if (!Core["key"sv])
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "No oplog entry key specified");
+ }
+
+ // Write core to oplog
+
+ const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Package);
+
+ if (OpLsn == ProjectStore::Oplog::kInvalidOp)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest);
+ }
+
+ ZEN_DEBUG("'{}/{}' op #{} ({}) - '{}'", ProjectId, OplogId, OpLsn, NiceBytes(Payload.Size()), Core["key"sv].AsString());
+
+ HttpReq.WriteResponse(HttpResponseCode::Created);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/{op}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const std::string& ProjectId = Req.GetCapture(1);
+ const std::string& OplogId = Req.GetCapture(2);
+ const std::string& OpIdString = Req.GetCapture(3);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog& Oplog = *FoundLog;
+
+ if (const std::optional<int32_t> OpId = ParseInt<uint32_t>(OpIdString))
+ {
+ if (std::optional<CbObject> MaybeOp = Oplog.GetOpByIndex(OpId.value()))
+ {
+ CbObject& Op = MaybeOp.value();
+ if (Req.ServerRequest().AcceptContentType() == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+ Package.SetObject(Op);
+
+ Op.IterateAttachments([&](CbFieldView FieldView) {
+ const IoHash AttachmentHash = FieldView.AsAttachment();
+ IoBuffer Payload = m_CidStore.FindChunkByCid(AttachmentHash);
+
+ // We force this for now as content type is not consistently tracked (will
+ // be fixed in CidStore refactor)
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+
+ if (Payload)
+ {
+ switch (Payload.GetContentType())
+ {
+ case ZenContentType::kCbObject:
+ if (CbObject Object = LoadCompactBinaryObject(Payload))
+ {
+ Package.AddAttachment(CbAttachment(Object));
+ }
+ else
+ {
+ // Error - malformed object
+
+ ZEN_WARN("malformed object returned for {}", AttachmentHash);
+ }
+ break;
+
+ case ZenContentType::kCompressedBinary:
+ if (CompressedBuffer Compressed = CompressedBuffer::FromCompressedNoValidate(std::move(Payload)))
+ {
+ Package.AddAttachment(CbAttachment(Compressed, AttachmentHash));
+ }
+ else
+ {
+ // Error - not compressed!
+
+ ZEN_WARN("invalid compressed binary returned for {}", AttachmentHash);
+ }
+ break;
+
+ default:
+ Package.AddAttachment(CbAttachment(SharedBuffer(Payload)));
+ break;
+ }
+ }
+ });
+
+ return HttpReq.WriteResponse(HttpResponseCode::Accepted, Package);
+ }
+ else
+ {
+ // Client cannot accept a package, so we only send the core object
+ return HttpReq.WriteResponse(HttpResponseCode::Accepted, Op);
+ }
+ }
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}",
+ [this](HttpRouterRequest& Req) {
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+
+ if (!Project)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("project {} not found", ProjectId));
+ }
+
+ switch (Req.ServerRequest().RequestVerb())
+ {
+ case HttpVerb::kGet:
+ {
+ ProjectStore::Oplog* OplogIt = Project->OpenOplog(OplogId);
+
+ if (!OplogIt)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("oplog {} not found in project {}", OplogId, ProjectId));
+ }
+
+ ProjectStore::Oplog& Log = *OplogIt;
+
+ CbObjectWriter Cb;
+ Cb << "id"sv << Log.OplogId() << "project"sv << Project->Identifier << "tempdir"sv << Log.TempPath().c_str()
+ << "markerpath"sv << Log.MarkerPath().c_str() << "totalsize"sv << Log.TotalSize() << "opcount"
+ << Log.OplogCount() << "expired"sv << Log.IsExpired();
+
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cb.Save());
+ }
+ break;
+
+ case HttpVerb::kPost:
+ {
+ std::filesystem::path OplogMarkerPath;
+ if (CbObject Params = Req.ServerRequest().ReadPayloadObject())
+ {
+ OplogMarkerPath = Params["gcpath"sv].AsString();
+ }
+
+ ProjectStore::Oplog* OplogIt = Project->OpenOplog(OplogId);
+
+ if (!OplogIt)
+ {
+ if (!Project->NewOplog(OplogId, OplogMarkerPath))
+ {
+ // TODO: indicate why the operation failed!
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ ZEN_INFO("established oplog '{}/{}', gc marker file at '{}'", ProjectId, OplogId, OplogMarkerPath);
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::Created);
+ }
+
+ // I guess this should ultimately be used to execute RPCs but for now, it
+ // does absolutely nothing
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::BadRequest);
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ {
+ ZEN_INFO("deleting oplog '{}/{}'", ProjectId, OplogId);
+
+ Project->DeleteOplog(OplogId);
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::OK);
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kPost | HttpVerb::kGet | HttpVerb::kDelete);
+
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/entries",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+ if (!Project)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Oplog* FoundLog = Project->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ CbObjectWriter Response;
+
+ if (FoundLog->OplogCount() > 0)
+ {
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+
+ if (auto OpKey = Params.GetValue("opkey"); !OpKey.empty())
+ {
+ Oid OpKeyId = OpKeyStringAsOId(OpKey);
+ std::optional<CbObject> Op = FoundLog->GetOpByKey(OpKeyId);
+
+ if (Op.has_value())
+ {
+ Response << "entry"sv << Op.value();
+ }
+ else
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ }
+ else
+ {
+ Response.BeginArray("entries"sv);
+
+ FoundLog->IterateOplog([&Response](CbObject Op) { Response << Op; });
+
+ Response.EndArray();
+ }
+ }
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "{project}",
+ [this](HttpRouterRequest& Req) {
+ const std::string ProjectId = Req.GetCapture(1);
+
+ switch (Req.ServerRequest().RequestVerb())
+ {
+ case HttpVerb::kPost:
+ {
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+ CbObject Params = LoadCompactBinaryObject(Payload);
+ std::string_view Id = Params["id"sv].AsString();
+ std::string_view Root = Params["root"sv].AsString();
+ std::string_view EngineRoot = Params["engine"sv].AsString();
+ std::string_view ProjectRoot = Params["project"sv].AsString();
+ std::string_view ProjectFilePath = Params["projectfile"sv].AsString();
+
+ const std::filesystem::path BasePath = m_ProjectStore->BasePath() / ProjectId;
+ m_ProjectStore->NewProject(BasePath, ProjectId, Root, EngineRoot, ProjectRoot, ProjectFilePath);
+
+ ZEN_INFO("established project - {} (id: '{}', roots: '{}', '{}', '{}', '{}'{})",
+ ProjectId,
+ Id,
+ Root,
+ EngineRoot,
+ ProjectRoot,
+ ProjectFilePath,
+ ProjectFilePath.empty() ? ", project will not be GCd due to empty project file path" : "");
+
+ Req.ServerRequest().WriteResponse(HttpResponseCode::Created);
+ }
+ break;
+
+ case HttpVerb::kGet:
+ {
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+
+ if (!Project)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("project {} not found", ProjectId));
+ }
+
+ std::vector<std::string> OpLogs = Project->ScanForOplogs();
+
+ CbObjectWriter Response;
+ Response << "id"sv << Project->Identifier;
+ Response << "root"sv << PathToUtf8(Project->RootDir);
+ Response << "engine"sv << PathToUtf8(Project->EngineRootDir);
+ Response << "project"sv << PathToUtf8(Project->ProjectRootDir);
+ Response << "projectfile"sv << PathToUtf8(Project->ProjectFilePath);
+
+ Response.BeginArray("oplogs"sv);
+ for (const std::string& OplogId : OpLogs)
+ {
+ Response.BeginObject();
+ Response << "id"sv << OplogId;
+ Response.EndObject();
+ }
+ Response.EndArray(); // oplogs
+
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Response.Save());
+ }
+ break;
+
+ case HttpVerb::kDelete:
+ {
+ Ref<ProjectStore::Project> Project = m_ProjectStore->OpenProject(ProjectId);
+
+ if (!Project)
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NotFound,
+ HttpContentType::kText,
+ fmt::format("project {} not found", ProjectId));
+ }
+
+ ZEN_INFO("deleting project '{}'", ProjectId);
+ if (!m_ProjectStore->DeleteProject(ProjectId))
+ {
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::Locked,
+ HttpContentType::kText,
+ fmt::format("project {} is in use", ProjectId));
+ }
+
+ return Req.ServerRequest().WriteResponse(HttpResponseCode::NoContent);
+ }
+ break;
+
+ default:
+ break;
+ }
+ },
+ HttpVerb::kGet | HttpVerb::kPost | HttpVerb::kDelete);
+
+ // Push a oplog container
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/save",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ if (HttpReq.RequestContentType() != HttpContentType::kCbObject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid content type");
+ }
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+
+ CbObject Response;
+ std::pair<HttpResponseCode, std::string> Result = m_ProjectStore->WriteOplog(ProjectId, OplogId, std::move(Payload), Response);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kPost);
+
+ // Pull a oplog container
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/load",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ if (HttpReq.AcceptContentType() != HttpContentType::kCbObject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::BadRequest, HttpContentType::kText, "Invalid accept content type");
+ }
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+
+ CbObject Response;
+ std::pair<HttpResponseCode, std::string> Result =
+ m_ProjectStore->ReadOplog(ProjectId, OplogId, Req.ServerRequest().GetQueryParams(), Response);
+ if (Result.first == HttpResponseCode::OK)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Response);
+ }
+ if (Result.second.empty())
+ {
+ return HttpReq.WriteResponse(Result.first);
+ }
+ return HttpReq.WriteResponse(Result.first, HttpContentType::kText, Result.second);
+ },
+ HttpVerb::kGet);
+
+ // Do an rpc style operation on project/oplog
+ m_Router.RegisterRoute(
+ "{project}/oplog/{log}/rpc",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ IoBuffer Payload = Req.ServerRequest().ReadPayload();
+
+ m_ProjectStore->Rpc(HttpReq, ProjectId, OplogId, std::move(Payload), m_AuthMgr);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "details\\$",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) {
+ Project.IterateOplogs([&](ProjectStore::Oplog& Oplog) {
+ Oplog.IterateOplogWithKey(
+ [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) {
+ CSVWriteOp(m_CidStore,
+ Project.Identifier,
+ Oplog.OplogId(),
+ Details,
+ AttachmentDetails,
+ LSN,
+ Key,
+ Op,
+ CSVWriter);
+ });
+ });
+ });
+
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("projects");
+ {
+ m_ProjectStore->DiscoverProjects();
+
+ m_ProjectStore->IterateProjects([&](ProjectStore::Project& Project) {
+ std::vector<std::string> OpLogs = Project.ScanForOplogs();
+ CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo);
+ });
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "details\\$/{project}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId);
+ if (!FoundProject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ ProjectStore::Project& Project = *FoundProject.Get();
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ FoundProject->IterateOplogs([&](ProjectStore::Oplog& Oplog) {
+ Oplog.IterateOplogWithKey([this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN,
+ const Oid& Key,
+ CbObject Op) {
+ CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter);
+ });
+ });
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ std::vector<std::string> OpLogs = FoundProject->ScanForOplogs();
+ Cbo.BeginArray("projects");
+ {
+ CbWriteProject(m_CidStore, Project, OpLogs, Details, OpDetails, AttachmentDetails, Cbo);
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "details\\$/{project}/{log}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId);
+ if (!FoundProject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ ProjectStore::Project& Project = *FoundProject.Get();
+ ProjectStore::Oplog& Oplog = *FoundLog;
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ Oplog.IterateOplogWithKey(
+ [this, &Project, &Oplog, &CSVWriter, Details, AttachmentDetails](int LSN, const Oid& Key, CbObject Op) {
+ CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, Key, Op, CSVWriter);
+ });
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("oplogs");
+ {
+ CbWriteOplog(m_CidStore, Oplog, Details, OpDetails, AttachmentDetails, Cbo);
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "details\\$/{project}/{log}/{chunk}",
+ [this](HttpRouterRequest& Req) {
+ HttpServerRequest& HttpReq = Req.ServerRequest();
+ const auto& ProjectId = Req.GetCapture(1);
+ const auto& OplogId = Req.GetCapture(2);
+ const auto& ChunkId = Req.GetCapture(3);
+
+ HttpServerRequest::QueryParams Params = HttpReq.GetQueryParams();
+ bool CSV = Params.GetValue("csv") == "true";
+ bool Details = Params.GetValue("details") == "true";
+ bool OpDetails = Params.GetValue("opdetails") == "true";
+ bool AttachmentDetails = Params.GetValue("attachmentdetails") == "true";
+
+ Ref<ProjectStore::Project> FoundProject = m_ProjectStore->OpenProject(ProjectId);
+ if (!FoundProject)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ ProjectStore::Oplog* FoundLog = FoundProject->OpenOplog(OplogId);
+
+ if (!FoundLog)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ if (ChunkId.size() != 2 * sizeof(Oid::OidBits))
+ {
+ return HttpReq.WriteResponse(
+ HttpResponseCode::BadRequest,
+ HttpContentType::kText,
+ fmt::format("Chunk info request for invalid chunk id '{}/{}'/'{}'", ProjectId, OplogId, ChunkId));
+ }
+
+ const Oid ObjId = Oid::FromHexString(ChunkId);
+ ProjectStore::Project& Project = *FoundProject.Get();
+ ProjectStore::Oplog& Oplog = *FoundLog;
+
+ int LSN = Oplog.GetOpIndexByKey(ObjId);
+ if (LSN == -1)
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+ std::optional<CbObject> Op = Oplog.GetOpByIndex(LSN);
+ if (!Op.has_value())
+ {
+ return HttpReq.WriteResponse(HttpResponseCode::NotFound);
+ }
+
+ if (CSV)
+ {
+ ExtendableStringBuilder<4096> CSVWriter;
+ CSVHeader(Details, AttachmentDetails, CSVWriter);
+
+ CSVWriteOp(m_CidStore, Project.Identifier, Oplog.OplogId(), Details, AttachmentDetails, LSN, ObjId, Op.value(), CSVWriter);
+ HttpReq.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, CSVWriter.ToView());
+ }
+ else
+ {
+ CbObjectWriter Cbo;
+ Cbo.BeginArray("ops");
+ {
+ CbWriteOp(m_CidStore, Details, OpDetails, AttachmentDetails, LSN, ObjId, Op.value(), Cbo);
+ }
+ Cbo.EndArray();
+ HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+ },
+ HttpVerb::kGet);
+}
+
+HttpProjectService::~HttpProjectService()
+{
+ m_StatsService.UnregisterHandler("prj", *this);
+}
+
+const char*
+HttpProjectService::BaseUri() const
+{
+ return "/prj/";
+}
+
+void
+HttpProjectService::HandleRequest(HttpServerRequest& Request)
+{
+ if (m_Router.HandleRequest(Request) == false)
+ {
+ ZEN_WARN("No route found for {0}", Request.RelativeUri());
+ }
+}
+
+void
+HttpProjectService::HandleStatsRequest(HttpServerRequest& HttpReq)
+{
+ const GcStorageSize StoreSize = m_ProjectStore->StorageSize();
+ const CidStoreSize CidSize = m_CidStore.TotalSize();
+
+ CbObjectWriter Cbo;
+ Cbo.BeginObject("store");
+ {
+ Cbo.BeginObject("size");
+ {
+ Cbo << "disk" << StoreSize.DiskSize;
+ Cbo << "memory" << StoreSize.MemorySize;
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndObject();
+
+ Cbo.BeginObject("cid");
+ {
+ Cbo.BeginObject("size");
+ {
+ Cbo << "tiny" << CidSize.TinySize;
+ Cbo << "small" << CidSize.SmallSize;
+ Cbo << "large" << CidSize.LargeSize;
+ Cbo << "total" << CidSize.TotalSize;
+ }
+ Cbo.EndObject();
+ }
+ Cbo.EndObject();
+
+ return HttpReq.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+namespace testutils {
+ using namespace std::literals;
+
+ std::string OidAsString(const Oid& Id)
+ {
+ StringBuilder<25> OidStringBuilder;
+ Id.ToString(OidStringBuilder);
+ return OidStringBuilder.ToString();
+ }
+
+ CbPackage CreateOplogPackage(const Oid& Id, const std::span<const std::pair<Oid, CompressedBuffer>>& Attachments)
+ {
+ CbPackage Package;
+ CbObjectWriter Object;
+ Object << "key"sv << OidAsString(Id);
+ if (!Attachments.empty())
+ {
+ Object.BeginArray("bulkdata");
+ for (const auto& Attachment : Attachments)
+ {
+ CbAttachment Attach(Attachment.second, Attachment.second.DecodeRawHash());
+ Object.BeginObject();
+ Object << "id"sv << Attachment.first;
+ Object << "type"sv
+ << "Standard"sv;
+ Object << "data"sv << Attach;
+ Object.EndObject();
+
+ Package.AddAttachment(Attach);
+ }
+ Object.EndArray();
+ }
+ Package.SetObject(Object.Save());
+ return Package;
+ };
+
+ std::vector<std::pair<Oid, CompressedBuffer>> CreateAttachments(const std::span<const size_t>& Sizes)
+ {
+ std::vector<std::pair<Oid, CompressedBuffer>> Result;
+ Result.reserve(Sizes.size());
+ for (size_t Size : Sizes)
+ {
+ std::vector<uint8_t> Data;
+ Data.resize(Size);
+ uint16_t* DataPtr = reinterpret_cast<uint16_t*>(Data.data());
+ for (size_t Idx = 0; Idx < Size / 2; ++Idx)
+ {
+ DataPtr[Idx] = static_cast<uint16_t>(Idx % 0xffffu);
+ }
+ if (Size & 1)
+ {
+ Data[Size - 1] = static_cast<uint8_t>((Size - 1) & 0xff);
+ }
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer::MakeView(Data.data(), Data.size()));
+ Result.emplace_back(std::pair<Oid, CompressedBuffer>(Oid::NewOid(), Compressed));
+ }
+ return Result;
+ }
+
+ uint64 GetCompressedOffset(const CompressedBuffer& Buffer, uint64 RawOffset)
+ {
+ if (RawOffset > 0)
+ {
+ uint64 BlockSize = 0;
+ OodleCompressor Compressor;
+ OodleCompressionLevel CompressionLevel;
+ if (!Buffer.TryGetCompressParameters(Compressor, CompressionLevel, BlockSize))
+ {
+ return 0;
+ }
+ return BlockSize > 0 ? RawOffset % BlockSize : 0;
+ }
+ return 0;
+ }
+
+} // namespace testutils
+
+TEST_CASE("project.store.create")
+{
+ using namespace std::literals;
+
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+ CidStore.Initialize(CidConfig);
+
+ std::string_view ProjectName("proj1"sv);
+ std::filesystem::path BasePath = TempDir.Path() / "projectstore";
+ ProjectStore ProjectStore(CidStore, BasePath, Gc);
+ std::filesystem::path RootDir = TempDir.Path() / "root";
+ std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+ std::filesystem::path ProjectRootDir = TempDir.Path() / "game";
+ std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject";
+
+ Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / ProjectName,
+ ProjectName,
+ RootDir.string(),
+ EngineRootDir.string(),
+ ProjectRootDir.string(),
+ ProjectFilePath.string()));
+ CHECK(ProjectStore.DeleteProject(ProjectName));
+ CHECK(!Project->Exists(BasePath));
+}
+
+TEST_CASE("project.store.lifetimes")
+{
+ using namespace std::literals;
+
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+ CidStore.Initialize(CidConfig);
+
+ std::filesystem::path BasePath = TempDir.Path() / "projectstore";
+ ProjectStore ProjectStore(CidStore, BasePath, Gc);
+ std::filesystem::path RootDir = TempDir.Path() / "root";
+ std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+ std::filesystem::path ProjectRootDir = TempDir.Path() / "game";
+ std::filesystem::path ProjectFilePath = TempDir.Path() / "game" / "game.uproject";
+
+ Ref<ProjectStore::Project> Project(ProjectStore.NewProject(BasePath / "proj1"sv,
+ "proj1"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ ProjectRootDir.string(),
+ ProjectFilePath.string()));
+ ProjectStore::Oplog* Oplog = Project->NewOplog("oplog1", {});
+ CHECK(Oplog != nullptr);
+
+ std::filesystem::path DeletePath;
+ CHECK(Project->PrepareForDelete(DeletePath));
+ CHECK(!DeletePath.empty());
+ CHECK(Project->OpenOplog("oplog1") == nullptr);
+ // Oplog is now invalid, but pointer can still be accessed since we store old oplog pointers
+ CHECK(Oplog->OplogCount() == 0);
+ // Project is still valid since we have a Ref to it
+ CHECK(Project->Identifier == "proj1"sv);
+}
+
+TEST_CASE("project.store.gc")
+{
+ using namespace std::literals;
+ using namespace testutils;
+
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas", .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+ CidStore.Initialize(CidConfig);
+
+ std::filesystem::path BasePath = TempDir.Path() / "projectstore";
+ ProjectStore ProjectStore(CidStore, BasePath, Gc);
+ std::filesystem::path RootDir = TempDir.Path() / "root";
+ std::filesystem::path EngineRootDir = TempDir.Path() / "engine";
+
+ std::filesystem::path Project1RootDir = TempDir.Path() / "game1";
+ std::filesystem::path Project1FilePath = TempDir.Path() / "game1" / "game.uproject";
+ {
+ CreateDirectories(Project1FilePath.parent_path());
+ BasicFile ProjectFile;
+ ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate);
+ }
+
+ std::filesystem::path Project2RootDir = TempDir.Path() / "game2";
+ std::filesystem::path Project2FilePath = TempDir.Path() / "game2" / "game.uproject";
+ {
+ CreateDirectories(Project2FilePath.parent_path());
+ BasicFile ProjectFile;
+ ProjectFile.Open(Project2FilePath, BasicFile::Mode::kTruncate);
+ }
+
+ {
+ Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv,
+ "proj1"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ Project1RootDir.string(),
+ Project1FilePath.string()));
+ ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1", {});
+ CHECK(Oplog != nullptr);
+
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {}));
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{77})));
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{7123, 583, 690, 99})));
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{55, 122})));
+ }
+
+ {
+ Ref<ProjectStore::Project> Project2(ProjectStore.NewProject(BasePath / "proj2"sv,
+ "proj2"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ Project2RootDir.string(),
+ Project2FilePath.string()));
+ ProjectStore::Oplog* Oplog = Project2->NewOplog("oplog1", {});
+ CHECK(Oplog != nullptr);
+
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), {}));
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{177})));
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{9123, 383, 590, 96})));
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(Oid::NewOid(), CreateAttachments(std::initializer_list<size_t>{535, 221})));
+ }
+
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ ProjectStore.GatherReferences(GcCtx);
+ size_t RefCount = 0;
+ GcCtx.IterateCids([&RefCount](const IoHash&) { RefCount++; });
+ CHECK(RefCount == 14);
+ ProjectStore.CollectGarbage(GcCtx);
+ CHECK(ProjectStore.OpenProject("proj1"sv));
+ CHECK(ProjectStore.OpenProject("proj2"sv));
+ }
+
+ std::filesystem::remove(Project1FilePath);
+
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ ProjectStore.GatherReferences(GcCtx);
+ size_t RefCount = 0;
+ GcCtx.IterateCids([&RefCount](const IoHash&) { RefCount++; });
+ CHECK(RefCount == 7);
+ ProjectStore.CollectGarbage(GcCtx);
+ CHECK(!ProjectStore.OpenProject("proj1"sv));
+ CHECK(ProjectStore.OpenProject("proj2"sv));
+ }
+}
+
+TEST_CASE("project.store.partial.read")
+{
+ using namespace std::literals;
+ using namespace testutils;
+
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+ CidStoreConfiguration CidConfig = {.RootDirectory = TempDir.Path() / "cas"sv, .TinyValueThreshold = 1024, .HugeValueThreshold = 4096};
+ CidStore.Initialize(CidConfig);
+
+ std::filesystem::path BasePath = TempDir.Path() / "projectstore"sv;
+ ProjectStore ProjectStore(CidStore, BasePath, Gc);
+ std::filesystem::path RootDir = TempDir.Path() / "root"sv;
+ std::filesystem::path EngineRootDir = TempDir.Path() / "engine"sv;
+
+ std::filesystem::path Project1RootDir = TempDir.Path() / "game1"sv;
+ std::filesystem::path Project1FilePath = TempDir.Path() / "game1"sv / "game.uproject"sv;
+ {
+ CreateDirectories(Project1FilePath.parent_path());
+ BasicFile ProjectFile;
+ ProjectFile.Open(Project1FilePath, BasicFile::Mode::kTruncate);
+ }
+
+ std::vector<Oid> OpIds;
+ OpIds.insert(OpIds.end(), {Oid::NewOid(), Oid::NewOid(), Oid::NewOid(), Oid::NewOid()});
+ std::unordered_map<Oid, std::vector<std::pair<Oid, CompressedBuffer>>, Oid::Hasher> Attachments;
+ {
+ Ref<ProjectStore::Project> Project1(ProjectStore.NewProject(BasePath / "proj1"sv,
+ "proj1"sv,
+ RootDir.string(),
+ EngineRootDir.string(),
+ Project1RootDir.string(),
+ Project1FilePath.string()));
+ ProjectStore::Oplog* Oplog = Project1->NewOplog("oplog1"sv, {});
+ CHECK(Oplog != nullptr);
+ Attachments[OpIds[0]] = {};
+ Attachments[OpIds[1]] = CreateAttachments(std::initializer_list<size_t>{77});
+ Attachments[OpIds[2]] = CreateAttachments(std::initializer_list<size_t>{7123, 9583, 690, 99});
+ Attachments[OpIds[3]] = CreateAttachments(std::initializer_list<size_t>{55, 122});
+ for (auto It : Attachments)
+ {
+ Oplog->AppendNewOplogEntry(CreateOplogPackage(It.first, It.second));
+ }
+ }
+ {
+ IoBuffer Chunk;
+ CHECK(ProjectStore
+ .GetChunk("proj1"sv,
+ "oplog1"sv,
+ Attachments[OpIds[1]][0].second.DecodeRawHash().ToHexString(),
+ HttpContentType::kCompressedBinary,
+ Chunk)
+ .first == HttpResponseCode::OK);
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Attachment = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), RawHash, RawSize);
+ CHECK(RawSize == Attachments[OpIds[1]][0].second.DecodeRawSize());
+ }
+
+ IoBuffer ChunkResult;
+ CHECK(ProjectStore
+ .GetChunkRange("proj1"sv,
+ "oplog1"sv,
+ OidAsString(Attachments[OpIds[2]][1].first),
+ 0,
+ ~0ull,
+ HttpContentType::kCompressedBinary,
+ ChunkResult)
+ .first == HttpResponseCode::OK);
+ CHECK(ChunkResult);
+ CHECK(CompressedBuffer::FromCompressedNoValidate(std::move(ChunkResult)).DecodeRawSize() ==
+ Attachments[OpIds[2]][1].second.DecodeRawSize());
+
+ IoBuffer PartialChunkResult;
+ CHECK(ProjectStore
+ .GetChunkRange("proj1"sv,
+ "oplog1"sv,
+ OidAsString(Attachments[OpIds[2]][1].first),
+ 5,
+ 1773,
+ HttpContentType::kCompressedBinary,
+ PartialChunkResult)
+ .first == HttpResponseCode::OK);
+ CHECK(PartialChunkResult);
+ IoHash PartialRawHash;
+ uint64_t PartialRawSize;
+ CompressedBuffer PartialCompressedResult =
+ CompressedBuffer::FromCompressed(SharedBuffer(PartialChunkResult), PartialRawHash, PartialRawSize);
+ CHECK(PartialRawSize >= 1773);
+
+ uint64_t RawOffsetInPartialCompressed = GetCompressedOffset(PartialCompressedResult, 5);
+ SharedBuffer PartialDecompressed = PartialCompressedResult.Decompress(RawOffsetInPartialCompressed);
+ SharedBuffer FullDecompressed = Attachments[OpIds[2]][1].second.Decompress();
+ const uint8_t* FullDataPtr = &(reinterpret_cast<const uint8_t*>(FullDecompressed.GetView().GetData())[5]);
+ const uint8_t* PartialDataPtr = reinterpret_cast<const uint8_t*>(PartialDecompressed.GetView().GetData());
+ CHECK(FullDataPtr[0] == PartialDataPtr[0]);
+}
+
+TEST_CASE("project.store.block")
+{
+ using namespace std::literals;
+ using namespace testutils;
+
+ std::vector<std::size_t> AttachmentSizes({7633, 6825, 5738, 8031, 7225, 566, 3656, 6006, 24, 3466, 1093, 4269, 2257, 3685, 3489,
+ 7194, 6151, 5482, 6217, 3511, 6738, 5061, 7537, 2759, 1916, 8210, 2235, 4024, 1582, 5251,
+ 491, 5464, 4607, 8135, 3767, 4045, 4415, 5007, 8876, 6761, 3359, 8526, 4097, 4855, 8225});
+
+ std::vector<std::pair<Oid, CompressedBuffer>> AttachmentsWithId = CreateAttachments(AttachmentSizes);
+ std::vector<SharedBuffer> Chunks;
+ Chunks.reserve(AttachmentSizes.size());
+ for (const auto& It : AttachmentsWithId)
+ {
+ Chunks.push_back(It.second.GetCompressed().Flatten());
+ }
+ CompressedBuffer Block = GenerateBlock(std::move(Chunks));
+ IoBuffer BlockBuffer = Block.GetCompressed().Flatten().AsIoBuffer();
+ CHECK(IterateBlock(std::move(BlockBuffer), [](CompressedBuffer&&, const IoHash&) {}));
+}
+
+#endif
+
+void
+prj_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/projectstore.h b/src/zenserver/projectstore/projectstore.h
new file mode 100644
index 000000000..e4f664b85
--- /dev/null
+++ b/src/zenserver/projectstore/projectstore.h
@@ -0,0 +1,372 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/uid.h>
+#include <zencore/xxhash.h>
+#include <zenhttp/httpserver.h>
+#include <zenstore/gc.h>
+
+#include "monitoring/httpstats.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+class CbPackage;
+class CidStore;
+class AuthMgr;
+class ScrubContext;
+
+struct OplogEntry
+{
+ uint32_t OpLsn;
+ uint32_t OpCoreOffset; // note: Multiple of alignment!
+ uint32_t OpCoreSize;
+ uint32_t OpCoreHash; // Used as checksum
+ XXH3_128 OpKeyHash; // XXH128_canonical_t
+
+ inline Oid OpKeyAsOId() const
+ {
+ Oid Id;
+ memcpy(Id.OidBits, &OpKeyHash, sizeof Id.OidBits);
+ return Id;
+ }
+};
+
+struct OplogEntryAddress
+{
+ uint64_t Offset;
+ uint64_t Size;
+};
+
+static_assert(IsPow2(sizeof(OplogEntry)));
+
+/** Project Store
+
+ A project store consists of a number of Projects.
+
+ Each project contains a number of oplogs (short for "operation log"). UE uses
+ one oplog per target platform to store the output of the cook process.
+
+ An oplog consists of a sequence of "op" entries. Each entry is a structured object
+ containing references to attachments. Attachments are typically the serialized
+ package data split into separate chunks for bulk data, exports and header
+ information.
+ */
+class ProjectStore : public RefCounted, public GcStorage, public GcContributor
+{
+ struct OplogStorage;
+
+public:
+ ProjectStore(CidStore& Store, std::filesystem::path BasePath, GcManager& Gc);
+ ~ProjectStore();
+
+ struct Project;
+
+ struct Oplog
+ {
+ Oplog(std::string_view Id,
+ Project* Project,
+ CidStore& Store,
+ std::filesystem::path BasePath,
+ const std::filesystem::path& MarkerPath);
+ ~Oplog();
+
+ [[nodiscard]] static bool ExistsAt(std::filesystem::path BasePath);
+
+ void Read();
+ void Write();
+
+ void IterateFileMap(std::function<void(const Oid&, const std::string_view& ServerPath, const std::string_view& ClientPath)>&& Fn);
+ void IterateOplog(std::function<void(CbObject)>&& Fn);
+ void IterateOplogWithKey(std::function<void(int, const Oid&, CbObject)>&& Fn);
+ std::optional<CbObject> GetOpByKey(const Oid& Key);
+ std::optional<CbObject> GetOpByIndex(int Index);
+ int GetOpIndexByKey(const Oid& Key);
+
+ IoBuffer FindChunk(Oid ChunkId);
+
+ inline static const uint32_t kInvalidOp = ~0u;
+
+ /** Persist a new oplog entry
+ *
+ * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected
+ */
+ uint32_t AppendNewOplogEntry(CbPackage Op);
+
+ uint32_t AppendNewOplogEntry(CbObject Core);
+
+ enum UpdateType
+ {
+ kUpdateNewEntry,
+ kUpdateReplay
+ };
+
+ const std::string& OplogId() const { return m_OplogId; }
+
+ const std::filesystem::path& TempPath() const { return m_TempPath; }
+ const std::filesystem::path& MarkerPath() const { return m_MarkerPath; }
+
+ spdlog::logger& Log() { return m_OuterProject->Log(); }
+ void Flush();
+ void Scrub(ScrubContext& Ctx) const;
+ void GatherReferences(GcContext& GcCtx);
+ uint64_t TotalSize() const;
+
+ std::size_t OplogCount() const
+ {
+ RwLock::SharedLockScope _(m_OplogLock);
+ return m_LatestOpMap.size();
+ }
+
+ bool IsExpired() const;
+ std::filesystem::path PrepareForDelete(bool MoveFolder);
+
+ private:
+ struct FileMapEntry
+ {
+ std::string ServerPath;
+ std::string ClientPath;
+ };
+
+ template<class V>
+ using OidMap = tsl::robin_map<Oid, V, Oid::Hasher>;
+
+ Project* m_OuterProject = nullptr;
+ CidStore& m_CidStore;
+ std::filesystem::path m_BasePath;
+ std::filesystem::path m_MarkerPath;
+ std::filesystem::path m_TempPath;
+
+ mutable RwLock m_OplogLock;
+ OidMap<IoHash> m_ChunkMap; // output data chunk id -> CAS address
+ OidMap<IoHash> m_MetaMap; // meta chunk id -> CAS address
+ OidMap<FileMapEntry> m_FileMap; // file id -> file map entry
+ int32_t m_ManifestVersion; // File system manifest version
+ tsl::robin_map<int, OplogEntryAddress> m_OpAddressMap; // Index LSN -> op data in ops blob file
+ OidMap<int> m_LatestOpMap; // op key -> latest op LSN for key
+
+ RefPtr<OplogStorage> m_Storage;
+ std::string m_OplogId;
+
+ /** Scan oplog and register each entry, thus updating the in-memory tracking tables
+ */
+ void ReplayLog();
+
+ struct OplogEntryMapping
+ {
+ struct Mapping
+ {
+ Oid Id;
+ IoHash Hash;
+ };
+ struct FileMapping : public Mapping
+ {
+ std::string ServerPath;
+ std::string ClientPath;
+ };
+ std::vector<Mapping> Chunks;
+ std::vector<Mapping> Meta;
+ std::vector<FileMapping> Files;
+ };
+
+ OplogEntryMapping GetMapping(CbObject Core);
+
+ /** Update tracking metadata for a new oplog entry
+ *
+ * This is used during replay (and gets called as part of new op append)
+ *
+ * Returns the oplog LSN assigned to the new entry, or kInvalidOp if the entry is rejected
+ */
+ uint32_t RegisterOplogEntry(RwLock::ExclusiveLockScope& OplogLock,
+ const OplogEntryMapping& OpMapping,
+ const OplogEntry& OpEntry,
+ UpdateType TypeOfUpdate);
+
+ void AddFileMapping(const RwLock::ExclusiveLockScope& OplogLock,
+ Oid FileId,
+ IoHash Hash,
+ std::string_view ServerPath,
+ std::string_view ClientPath);
+ void AddChunkMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash);
+ void AddMetaMapping(const RwLock::ExclusiveLockScope& OplogLock, Oid ChunkId, IoHash Hash);
+ };
+
+ struct Project : public RefCounted
+ {
+ std::string Identifier;
+ std::filesystem::path RootDir;
+ std::string EngineRootDir;
+ std::string ProjectRootDir;
+ std::string ProjectFilePath;
+
+ Oplog* NewOplog(std::string_view OplogId, const std::filesystem::path& MarkerPath);
+ Oplog* OpenOplog(std::string_view OplogId);
+ void DeleteOplog(std::string_view OplogId);
+ void IterateOplogs(std::function<void(const Oplog&)>&& Fn) const;
+ void IterateOplogs(std::function<void(Oplog&)>&& Fn);
+ std::vector<std::string> ScanForOplogs() const;
+ bool IsExpired() const;
+
+ Project(ProjectStore* PrjStore, CidStore& Store, std::filesystem::path BasePath);
+ virtual ~Project();
+
+ void Read();
+ void Write();
+ [[nodiscard]] static bool Exists(std::filesystem::path BasePath);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ spdlog::logger& Log();
+ void GatherReferences(GcContext& GcCtx);
+ uint64_t TotalSize() const;
+ bool PrepareForDelete(std::filesystem::path& OutDeletePath);
+
+ private:
+ ProjectStore* m_ProjectStore;
+ CidStore& m_CidStore;
+ mutable RwLock m_ProjectLock;
+ std::map<std::string, std::unique_ptr<Oplog>> m_Oplogs;
+ std::vector<std::unique_ptr<Oplog>> m_DeletedOplogs;
+ std::filesystem::path m_OplogStoragePath;
+
+ std::filesystem::path BasePathForOplog(std::string_view OplogId);
+ };
+
+ // Oplog* OpenProjectOplog(std::string_view ProjectId, std::string_view OplogId);
+
+ Ref<Project> OpenProject(std::string_view ProjectId);
+ Ref<Project> NewProject(std::filesystem::path BasePath,
+ std::string_view ProjectId,
+ std::string_view RootDir,
+ std::string_view EngineRootDir,
+ std::string_view ProjectRootDir,
+ std::string_view ProjectFilePath);
+ bool DeleteProject(std::string_view ProjectId);
+ bool Exists(std::string_view ProjectId);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ void DiscoverProjects();
+ void IterateProjects(std::function<void(Project& Prj)>&& Fn);
+
+ spdlog::logger& Log() { return m_Log; }
+ const std::filesystem::path& BasePath() const { return m_ProjectBasePath; }
+
+ virtual void GatherReferences(GcContext& GcCtx) override;
+ virtual void CollectGarbage(GcContext& GcCtx) override;
+ virtual GcStorageSize StorageSize() const override;
+
+ CbArray GetProjectsList();
+ std::pair<HttpResponseCode, std::string> GetProjectFiles(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ bool FilterClient,
+ CbObject& OutPayload);
+ std::pair<HttpResponseCode, std::string> GetChunkInfo(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view ChunkId,
+ CbObject& OutPayload);
+ std::pair<HttpResponseCode, std::string> GetChunkRange(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view ChunkId,
+ uint64_t Offset,
+ uint64_t Size,
+ ZenContentType AcceptType,
+ IoBuffer& OutChunk);
+ std::pair<HttpResponseCode, std::string> GetChunk(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view Cid,
+ ZenContentType AcceptType,
+ IoBuffer& OutChunk);
+
+ std::pair<HttpResponseCode, std::string> PutChunk(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const std::string_view Cid,
+ ZenContentType ContentType,
+ IoBuffer&& Chunk);
+
+ std::pair<HttpResponseCode, std::string> WriteOplog(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ IoBuffer&& Payload,
+ CbObject& OutResponse);
+
+ std::pair<HttpResponseCode, std::string> ReadOplog(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ const HttpServerRequest::QueryParams& Params,
+ CbObject& OutResponse);
+
+ std::pair<HttpResponseCode, std::string> WriteBlock(const std::string_view ProjectId,
+ const std::string_view OplogId,
+ IoBuffer&& Payload);
+
+ void Rpc(HttpServerRequest& HttpReq,
+ const std::string_view ProjectId,
+ const std::string_view OplogId,
+ IoBuffer&& Payload,
+ AuthMgr& AuthManager);
+
+ std::pair<HttpResponseCode, std::string> Export(ProjectStore::Project& Project,
+ ProjectStore::Oplog& Oplog,
+ CbObjectView&& Params,
+ AuthMgr& AuthManager);
+
+ std::pair<HttpResponseCode, std::string> Import(ProjectStore::Project& Project,
+ ProjectStore::Oplog& Oplog,
+ CbObjectView&& Params,
+ AuthMgr& AuthManager);
+
+private:
+ spdlog::logger& m_Log;
+ CidStore& m_CidStore;
+ std::filesystem::path m_ProjectBasePath;
+ mutable RwLock m_ProjectsLock;
+ std::map<std::string, Ref<Project>> m_Projects;
+
+ std::filesystem::path BasePathForProject(std::string_view ProjectId);
+};
+
+//////////////////////////////////////////////////////////////////////////
+//
+// {project} a project identifier
+// {target} a variation of the project, typically a build target
+// {lsn} oplog entry sequence number
+//
+// /prj/{project}
+// /prj/{project}/oplog/{target}
+// /prj/{project}/oplog/{target}/{lsn}
+//
+// oplog entry
+//
+// id: {id}
+// key: {}
+// meta: {}
+// data: []
+// refs:
+//
+
+class HttpProjectService : public HttpService, public IHttpStatsProvider
+{
+public:
+ HttpProjectService(CidStore& Store, ProjectStore* InProjectStore, HttpStatsService& StatsService, AuthMgr& AuthMgr);
+ ~HttpProjectService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+
+ virtual void HandleStatsRequest(HttpServerRequest& Request) override;
+
+private:
+ inline spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ CidStore& m_CidStore;
+ HttpRequestRouter m_Router;
+ Ref<ProjectStore> m_ProjectStore;
+ HttpStatsService& m_StatsService;
+ AuthMgr& m_AuthMgr;
+};
+
+void prj_forcelink();
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/remoteprojectstore.cpp b/src/zenserver/projectstore/remoteprojectstore.cpp
new file mode 100644
index 000000000..1e6ca51a1
--- /dev/null
+++ b/src/zenserver/projectstore/remoteprojectstore.cpp
@@ -0,0 +1,1036 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "remoteprojectstore.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compress.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/workthreadpool.h>
+#include <zenstore/cidstore.h>
+
+namespace zen {
+
+/*
+ OplogContainer
+ Binary("ops") // Compressed CompactBinary object to hide attachment references, also makes the oplog smaller
+ {
+ CbArray("ops")
+ {
+ CbObject Op
+ (CbFieldType::BinaryAttachment Attachments[])
+ (OpData)
+ }
+ }
+ CbArray("blocks")
+ CbObject
+ CbFieldType::BinaryAttachment "rawhash" // Optional, only if we are creating blocks (Jupiter/File)
+ CbArray("chunks")
+ CbFieldType::Hash // Chunk hashes
+ CbArray("chunks") // Optional, only if we are not creating blocks (Zen)
+ CbFieldType::BinaryAttachment // Chunk attachment hashes
+
+ CompressedBinary ChunkBlock
+ {
+ VarUInt ChunkCount
+ VarUInt ChunkSizes[ChunkCount]
+ uint8_t[chunksize])[ChunkCount]
+ }
+*/
+
+////////////////////////////// AsyncRemoteResult
+
+struct AsyncRemoteResult
+{
+ void SetError(int32_t ErrorCode, const std::string& ErrorReason, const std::string ErrorText)
+ {
+ int32_t Expected = 0;
+ if (m_ErrorCode.compare_exchange_weak(Expected, ErrorCode ? ErrorCode : -1))
+ {
+ m_ErrorReason = ErrorReason;
+ m_ErrorText = ErrorText;
+ }
+ }
+ bool IsError() const { return m_ErrorCode.load() != 0; }
+ int GetError() const { return m_ErrorCode.load(); };
+ const std::string& GetErrorReason() const { return m_ErrorReason; };
+ const std::string& GetErrorText() const { return m_ErrorText; };
+ RemoteProjectStore::Result ConvertResult(double ElapsedSeconds = 0.0) const
+ {
+ return RemoteProjectStore::Result{m_ErrorCode, ElapsedSeconds, m_ErrorReason, m_ErrorText};
+ }
+
+private:
+ std::atomic<int32_t> m_ErrorCode = 0;
+ std::string m_ErrorReason;
+ std::string m_ErrorText;
+};
+
+bool
+IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor)
+{
+ IoBuffer BlockPayload = CompressedBuffer::FromCompressedNoValidate(std::move(CompressedBlock)).Decompress().AsIoBuffer();
+
+ MemoryView BlockView = BlockPayload.GetView();
+ const uint8_t* ReadPtr = reinterpret_cast<const uint8_t*>(BlockView.GetData());
+ uint32_t NumberSize;
+ uint64_t ChunkCount = ReadVarUInt(ReadPtr, NumberSize);
+ ReadPtr += NumberSize;
+ std::vector<uint64_t> ChunkSizes;
+ ChunkSizes.reserve(ChunkCount);
+ while (ChunkCount--)
+ {
+ ChunkSizes.push_back(ReadVarUInt(ReadPtr, NumberSize));
+ ReadPtr += NumberSize;
+ }
+ ptrdiff_t TempBufferLength = std::distance(reinterpret_cast<const uint8_t*>(BlockView.GetData()), ReadPtr);
+ ZEN_ASSERT(TempBufferLength > 0);
+ for (uint64_t ChunkSize : ChunkSizes)
+ {
+ IoBuffer Chunk(IoBuffer::Wrap, ReadPtr, ChunkSize);
+ IoHash AttachmentRawHash;
+ uint64_t AttachmentRawSize;
+ CompressedBuffer CompressedChunk = CompressedBuffer::FromCompressed(SharedBuffer(Chunk), AttachmentRawHash, AttachmentRawSize);
+
+ if (!CompressedChunk)
+ {
+ ZEN_ERROR("Invalid chunk in block");
+ return false;
+ }
+ Visitor(std::move(CompressedChunk), AttachmentRawHash);
+ ReadPtr += ChunkSize;
+ ZEN_ASSERT(ReadPtr <= BlockView.GetDataEnd());
+ }
+ return true;
+};
+
+CompressedBuffer
+GenerateBlock(std::vector<SharedBuffer>&& Chunks)
+{
+ size_t ChunkCount = Chunks.size();
+ SharedBuffer SizeBuffer;
+ {
+ IoBuffer TempBuffer(ChunkCount * 9);
+ MutableMemoryView View = TempBuffer.GetMutableView();
+ uint8_t* BufferStartPtr = reinterpret_cast<uint8_t*>(View.GetData());
+ uint8_t* BufferEndPtr = BufferStartPtr;
+ BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(ChunkCount), BufferEndPtr);
+ auto It = Chunks.begin();
+ while (It != Chunks.end())
+ {
+ BufferEndPtr += WriteVarUInt(gsl::narrow<uint64_t>(It->GetSize()), BufferEndPtr);
+ It++;
+ }
+ ZEN_ASSERT(BufferEndPtr <= View.GetDataEnd());
+ ptrdiff_t TempBufferLength = std::distance(BufferStartPtr, BufferEndPtr);
+ SizeBuffer = SharedBuffer(IoBuffer(TempBuffer, 0, gsl::narrow<size_t>(TempBufferLength)));
+ }
+ CompositeBuffer AllBuffers(std::move(SizeBuffer), CompositeBuffer(std::move(Chunks)));
+
+ CompressedBuffer CompressedBlock =
+ CompressedBuffer::Compress(std::move(AllBuffers), OodleCompressor::Mermaid, OodleCompressionLevel::None);
+
+ return CompressedBlock;
+}
+
+struct Block
+{
+ IoHash BlockHash;
+ std::vector<IoHash> ChunksInBlock;
+};
+
+void
+CreateBlock(WorkerThreadPool& WorkerPool,
+ Latch& OpSectionsLatch,
+ std::vector<SharedBuffer>&& ChunksInBlock,
+ RwLock& SectionsLock,
+ std::vector<Block>& Blocks,
+ size_t BlockIndex,
+ const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+ AsyncRemoteResult& RemoteResult)
+{
+ OpSectionsLatch.AddCount(1);
+ WorkerPool.ScheduleWork(
+ [&Blocks, &SectionsLock, &OpSectionsLatch, BlockIndex, Chunks = std::move(ChunksInBlock), &AsyncOnBlock, &RemoteResult]() mutable {
+ auto _ = MakeGuard([&OpSectionsLatch] { OpSectionsLatch.CountDown(); });
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+ if (!Chunks.empty())
+ {
+ CompressedBuffer CompressedBlock = GenerateBlock(std::move(Chunks)); // Move to callback and return IoHash
+ IoHash BlockHash = CompressedBlock.DecodeRawHash();
+ AsyncOnBlock(std::move(CompressedBlock), BlockHash);
+ {
+ // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index
+ RwLock::SharedLockScope __(SectionsLock);
+ Blocks[BlockIndex].BlockHash = BlockHash;
+ }
+ }
+ });
+}
+
+size_t
+AddBlock(RwLock& BlocksLock, std::vector<Block>& Blocks)
+{
+ size_t BlockIndex;
+ {
+ RwLock::ExclusiveLockScope _(BlocksLock);
+ BlockIndex = Blocks.size();
+ Blocks.resize(BlockIndex + 1);
+ }
+ return BlockIndex;
+}
+
+CbObject
+BuildContainer(CidStore& ChunkStore,
+ ProjectStore::Oplog& Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize,
+ bool BuildBlocks,
+ WorkerThreadPool& WorkerPool,
+ const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+ const std::function<void(const IoHash&)>& OnLargeAttachment,
+ const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks,
+ AsyncRemoteResult& RemoteResult)
+{
+ using namespace std::literals;
+
+ std::unordered_set<IoHash, IoHash::Hasher> LargeChunkHashes;
+ CbObjectWriter SectionOpsWriter;
+ SectionOpsWriter.BeginArray("ops"sv);
+
+ size_t OpCount = 0;
+
+ CbObject OplogContainerObject;
+ {
+ RwLock BlocksLock;
+ std::vector<Block> Blocks;
+ CompressedBuffer OpsBuffer;
+
+ Latch BlockCreateLatch(1);
+
+ std::unordered_set<IoHash, IoHash::Hasher> BlockAttachmentHashes;
+
+ size_t BlockSize = 0;
+ std::vector<SharedBuffer> ChunksInBlock;
+
+ std::unordered_set<IoHash, IoHash::Hasher> Attachments;
+ Oplog.IterateOplog([&Attachments, &SectionOpsWriter, &OpCount](CbObject Op) {
+ Op.IterateAttachments([&](CbFieldView FieldView) { Attachments.insert(FieldView.AsAttachment()); });
+ (SectionOpsWriter) << Op;
+ OpCount++;
+ });
+
+ for (const IoHash& AttachmentHash : Attachments)
+ {
+ IoBuffer Payload = ChunkStore.FindChunkByCid(AttachmentHash);
+ if (!Payload)
+ {
+ RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound),
+ fmt::format("Failed to find attachment {} for op", AttachmentHash),
+ {});
+ ZEN_ERROR("Failed to build container ({}). Reason: '{}'", RemoteResult.GetError(), RemoteResult.GetErrorReason());
+ return {};
+ }
+ uint64_t PayloadSize = Payload.GetSize();
+ if (PayloadSize > MaxChunkEmbedSize)
+ {
+ if (LargeChunkHashes.insert(AttachmentHash).second)
+ {
+ OnLargeAttachment(AttachmentHash);
+ }
+ continue;
+ }
+
+ if (!BlockAttachmentHashes.insert(AttachmentHash).second)
+ {
+ continue;
+ }
+
+ BlockSize += PayloadSize;
+ if (BuildBlocks)
+ {
+ ChunksInBlock.emplace_back(SharedBuffer(std::move(Payload)));
+ }
+ else
+ {
+ Payload = {};
+ }
+
+ if (BlockSize >= MaxBlockSize)
+ {
+ size_t BlockIndex = AddBlock(BlocksLock, Blocks);
+ if (BuildBlocks)
+ {
+ CreateBlock(WorkerPool,
+ BlockCreateLatch,
+ std::move(ChunksInBlock),
+ BlocksLock,
+ Blocks,
+ BlockIndex,
+ AsyncOnBlock,
+ RemoteResult);
+ }
+ else
+ {
+ OnBlockChunks(BlockAttachmentHashes);
+ }
+ {
+ // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index
+ RwLock::SharedLockScope _(BlocksLock);
+ Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(),
+ BlockAttachmentHashes.begin(),
+ BlockAttachmentHashes.end());
+ }
+ BlockAttachmentHashes.clear();
+ ChunksInBlock.clear();
+ BlockSize = 0;
+ }
+ }
+ if (BlockSize > 0)
+ {
+ size_t BlockIndex = AddBlock(BlocksLock, Blocks);
+ if (BuildBlocks)
+ {
+ CreateBlock(WorkerPool,
+ BlockCreateLatch,
+ std::move(ChunksInBlock),
+ BlocksLock,
+ Blocks,
+ BlockIndex,
+ AsyncOnBlock,
+ RemoteResult);
+ }
+ else
+ {
+ OnBlockChunks(BlockAttachmentHashes);
+ }
+ {
+ // We can share the lock as we are not resizing the vector and only touch BlockHash at our own index
+ RwLock::SharedLockScope _(BlocksLock);
+ Blocks[BlockIndex].ChunksInBlock.insert(Blocks[BlockIndex].ChunksInBlock.end(),
+ BlockAttachmentHashes.begin(),
+ BlockAttachmentHashes.end());
+ }
+ BlockAttachmentHashes.clear();
+ ChunksInBlock.clear();
+ BlockSize = 0;
+ }
+ SectionOpsWriter.EndArray(); // "ops"
+
+ CompressedBuffer CompressedOpsSection = CompressedBuffer::Compress(SectionOpsWriter.Save().GetBuffer());
+ ZEN_DEBUG("Added oplog section {}, {}", CompressedOpsSection.DecodeRawHash(), NiceBytes(CompressedOpsSection.GetCompressedSize()));
+
+ BlockCreateLatch.CountDown();
+ while (!BlockCreateLatch.Wait(1000))
+ {
+ ZEN_INFO("Creating blocks, {} remaining...", BlockCreateLatch.Remaining());
+ }
+
+ if (!RemoteResult.IsError())
+ {
+ CbObjectWriter OplogContinerWriter;
+ RwLock::SharedLockScope _(BlocksLock);
+ OplogContinerWriter.AddBinary("ops"sv, CompressedOpsSection.GetCompressed().Flatten().AsIoBuffer());
+
+ OplogContinerWriter.BeginArray("blocks"sv);
+ {
+ for (const Block& B : Blocks)
+ {
+ ZEN_ASSERT(!B.ChunksInBlock.empty());
+ if (BuildBlocks)
+ {
+ ZEN_ASSERT(B.BlockHash != IoHash::Zero);
+
+ OplogContinerWriter.BeginObject();
+ {
+ OplogContinerWriter.AddBinaryAttachment("rawhash"sv, B.BlockHash);
+ OplogContinerWriter.BeginArray("chunks"sv);
+ {
+ for (const IoHash& RawHash : B.ChunksInBlock)
+ {
+ OplogContinerWriter.AddHash(RawHash);
+ }
+ }
+ OplogContinerWriter.EndArray(); // "chunks"
+ }
+ OplogContinerWriter.EndObject();
+ continue;
+ }
+
+ ZEN_ASSERT(B.BlockHash == IoHash::Zero);
+ OplogContinerWriter.BeginObject();
+ {
+ OplogContinerWriter.BeginArray("chunks"sv);
+ {
+ for (const IoHash& RawHash : B.ChunksInBlock)
+ {
+ OplogContinerWriter.AddBinaryAttachment(RawHash);
+ }
+ }
+ OplogContinerWriter.EndArray();
+ }
+ OplogContinerWriter.EndObject();
+ }
+ }
+ OplogContinerWriter.EndArray(); // "blocks"sv
+
+ OplogContinerWriter.BeginArray("chunks"sv);
+ {
+ for (const IoHash& AttachmentHash : LargeChunkHashes)
+ {
+ OplogContinerWriter.AddBinaryAttachment(AttachmentHash);
+ }
+ }
+ OplogContinerWriter.EndArray(); // "chunks"
+
+ OplogContainerObject = OplogContinerWriter.Save();
+ }
+ }
+ return OplogContainerObject;
+}
+
+RemoteProjectStore::LoadContainerResult
+BuildContainer(CidStore& ChunkStore,
+ ProjectStore::Oplog& Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize,
+ bool BuildBlocks,
+ const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+ const std::function<void(const IoHash&)>& OnLargeAttachment,
+ const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks)
+{
+ // We are creating a worker thread pool here since we are uploading a lot of attachments in one go and we dont want to keep a
+ // WorkerThreadPool alive
+ size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u);
+ WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount));
+
+ AsyncRemoteResult RemoteResult;
+ CbObject ContainerObject = BuildContainer(ChunkStore,
+ Oplog,
+ MaxBlockSize,
+ MaxChunkEmbedSize,
+ BuildBlocks,
+ WorkerPool,
+ AsyncOnBlock,
+ OnLargeAttachment,
+ OnBlockChunks,
+ RemoteResult);
+ return RemoteProjectStore::LoadContainerResult{RemoteResult.ConvertResult(), ContainerObject};
+}
+
+RemoteProjectStore::Result
+SaveOplog(CidStore& ChunkStore,
+ RemoteProjectStore& RemoteStore,
+ ProjectStore::Oplog& Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize,
+ bool BuildBlocks,
+ bool UseTempBlocks,
+ bool ForceUpload)
+{
+ using namespace std::literals;
+
+ Stopwatch Timer;
+
+ // We are creating a worker thread pool here since we are uploading a lot of attachments in one go
+ // Doing upload is a rare and transient occation so we don't want to keep a WorkerThreadPool alive.
+ size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u);
+ WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount));
+
+ std::filesystem::path AttachmentTempPath;
+ if (UseTempBlocks)
+ {
+ AttachmentTempPath = Oplog.TempPath();
+ AttachmentTempPath.append(".pending");
+ CreateDirectories(AttachmentTempPath);
+ }
+
+ AsyncRemoteResult RemoteResult;
+ RwLock AttachmentsLock;
+ std::unordered_set<IoHash, IoHash::Hasher> LargeAttachments;
+ std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> CreatedBlocks;
+
+ auto MakeTempBlock = [AttachmentTempPath, &RemoteResult, &AttachmentsLock, &CreatedBlocks](CompressedBuffer&& CompressedBlock,
+ const IoHash& BlockHash) {
+ std::filesystem::path BlockPath = AttachmentTempPath;
+ BlockPath.append(BlockHash.ToHexString());
+ if (!std::filesystem::exists(BlockPath))
+ {
+ IoBuffer BlockBuffer;
+ try
+ {
+ BasicFile BlockFile;
+ BlockFile.Open(BlockPath, BasicFile::Mode::kTruncateDelete);
+ uint64_t Offset = 0;
+ for (const SharedBuffer& Buffer : CompressedBlock.GetCompressed().GetSegments())
+ {
+ BlockFile.Write(Buffer.GetView(), Offset);
+ Offset += Buffer.GetSize();
+ }
+ void* FileHandle = BlockFile.Detach();
+ BlockBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, Offset);
+ }
+ catch (std::exception& Ex)
+ {
+ RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
+ Ex.what(),
+ "Unable to create temp block file");
+ return;
+ }
+
+ BlockBuffer.MarkAsDeleteOnClose();
+ {
+ RwLock::ExclusiveLockScope __(AttachmentsLock);
+ CreatedBlocks.insert({BlockHash, std::move(BlockBuffer)});
+ }
+ ZEN_DEBUG("Saved temp block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize()));
+ }
+ };
+
+ auto UploadBlock = [&RemoteStore, &RemoteResult](CompressedBuffer&& CompressedBlock, const IoHash& BlockHash) {
+ RemoteProjectStore::SaveAttachmentResult Result = RemoteStore.SaveAttachment(CompressedBlock.GetCompressed(), BlockHash);
+ if (Result.ErrorCode)
+ {
+ RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+ ZEN_ERROR("Failed to save attachment ({}). Reason: '{}'", RemoteResult.GetErrorReason(), RemoteResult.GetError());
+ return;
+ }
+ ZEN_DEBUG("Saved block {}, {}", BlockHash, NiceBytes(CompressedBlock.GetCompressedSize()));
+ };
+
+ std::vector<std::vector<IoHash>> BlockChunks;
+ auto OnBlockChunks = [&BlockChunks](const std::unordered_set<IoHash, IoHash::Hasher>& Chunks) {
+ BlockChunks.push_back({Chunks.begin(), Chunks.end()});
+ ZEN_DEBUG("Found {} block chunks", Chunks.size());
+ };
+
+ auto OnLargeAttachment = [&AttachmentsLock, &LargeAttachments](const IoHash& AttachmentHash) {
+ {
+ RwLock::ExclusiveLockScope _(AttachmentsLock);
+ LargeAttachments.insert(AttachmentHash);
+ }
+ ZEN_DEBUG("Found attachment {}", AttachmentHash);
+ };
+
+ std::function<void(CompressedBuffer&&, const IoHash&)> OnBlock;
+ if (UseTempBlocks)
+ {
+ OnBlock = MakeTempBlock;
+ }
+ else
+ {
+ OnBlock = UploadBlock;
+ }
+
+ CbObject OplogContainerObject = BuildContainer(ChunkStore,
+ Oplog,
+ MaxBlockSize,
+ MaxChunkEmbedSize,
+ BuildBlocks,
+ WorkerPool,
+ OnBlock,
+ OnLargeAttachment,
+ OnBlockChunks,
+ RemoteResult);
+
+ if (!RemoteResult.IsError())
+ {
+ uint64_t ChunkCount = OplogContainerObject["chunks"sv].AsArrayView().Num();
+ uint64_t BlockCount = OplogContainerObject["blocks"sv].AsArrayView().Num();
+ ZEN_INFO("Saving oplog container with {} attachments and {} blocks...", ChunkCount, BlockCount);
+ RemoteProjectStore::SaveResult ContainerSaveResult = RemoteStore.SaveContainer(OplogContainerObject.GetBuffer().AsIoBuffer());
+ if (ContainerSaveResult.ErrorCode)
+ {
+ RemoteResult.SetError(ContainerSaveResult.ErrorCode, ContainerSaveResult.Reason, "Failed to save oplog container");
+ ZEN_ERROR("Failed to save oplog container ({}). Reason: '{}'", RemoteResult.GetErrorReason(), RemoteResult.GetError());
+ }
+ ZEN_DEBUG("Saved container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerSaveResult.ElapsedSeconds * 1000)));
+ if (!ContainerSaveResult.Needs.empty())
+ {
+ ZEN_INFO("Filtering needed attachments...");
+ std::vector<IoHash> NeededLargeAttachments;
+ std::unordered_set<IoHash, IoHash::Hasher> NeededOtherAttachments;
+ NeededLargeAttachments.reserve(LargeAttachments.size());
+ NeededOtherAttachments.reserve(CreatedBlocks.size());
+ if (ForceUpload)
+ {
+ NeededLargeAttachments.insert(NeededLargeAttachments.end(), LargeAttachments.begin(), LargeAttachments.end());
+ }
+ else
+ {
+ for (const IoHash& RawHash : ContainerSaveResult.Needs)
+ {
+ if (LargeAttachments.contains(RawHash))
+ {
+ NeededLargeAttachments.push_back(RawHash);
+ continue;
+ }
+ NeededOtherAttachments.insert(RawHash);
+ }
+ }
+
+ Latch SaveAttachmentsLatch(1);
+ if (!NeededLargeAttachments.empty())
+ {
+ ZEN_INFO("Saving large attachments...");
+ for (const IoHash& RawHash : NeededLargeAttachments)
+ {
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+ SaveAttachmentsLatch.AddCount(1);
+ WorkerPool.ScheduleWork([&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, RawHash, &CreatedBlocks]() {
+ auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); });
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+
+ IoBuffer Payload;
+ if (auto It = CreatedBlocks.find(RawHash); It != CreatedBlocks.end())
+ {
+ Payload = std::move(It->second);
+ }
+ else
+ {
+ Payload = ChunkStore.FindChunkByCid(RawHash);
+ }
+ if (!Payload)
+ {
+ RemoteResult.SetError(gsl::narrow<int>(HttpResponseCode::NotFound),
+ fmt::format("Failed to find attachment {}", RawHash),
+ {});
+ ZEN_ERROR("Failed to build container ({}). Reason: '{}'",
+ RemoteResult.GetErrorReason(),
+ RemoteResult.GetError());
+ return;
+ }
+
+ RemoteProjectStore::SaveAttachmentResult Result =
+ RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash);
+ if (Result.ErrorCode)
+ {
+ RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+ ZEN_ERROR("Failed to save attachment '{}', {} ({}). Reason: '{}'",
+ RawHash,
+ NiceBytes(Payload.GetSize()),
+ RemoteResult.GetError(),
+ RemoteResult.GetErrorReason());
+ return;
+ }
+ ZEN_DEBUG("Saved attachment {}, {} in {}",
+ RawHash,
+ NiceBytes(Payload.GetSize()),
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+ return;
+ });
+ }
+ }
+
+ if (!CreatedBlocks.empty())
+ {
+ ZEN_INFO("Saving created block attachments...");
+ for (auto& It : CreatedBlocks)
+ {
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+ const IoHash& RawHash = It.first;
+ if (ForceUpload || NeededOtherAttachments.contains(RawHash))
+ {
+ IoBuffer Payload = It.second;
+ ZEN_ASSERT(Payload);
+ SaveAttachmentsLatch.AddCount(1);
+ WorkerPool.ScheduleWork(
+ [&ChunkStore, &RemoteStore, &SaveAttachmentsLatch, &RemoteResult, Payload = std::move(Payload), RawHash]() {
+ auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); });
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+
+ RemoteProjectStore::SaveAttachmentResult Result =
+ RemoteStore.SaveAttachment(CompositeBuffer(SharedBuffer(Payload)), RawHash);
+ if (Result.ErrorCode)
+ {
+ RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+ ZEN_ERROR("Failed to save attachment '{}', {} ({}). Reason: '{}'",
+ RawHash,
+ NiceBytes(Payload.GetSize()),
+ RemoteResult.GetError(),
+ RemoteResult.GetErrorReason());
+ return;
+ }
+
+ ZEN_DEBUG("Saved attachment {}, {} in {}",
+ RawHash,
+ NiceBytes(Payload.GetSize()),
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+ return;
+ });
+ }
+ It.second = {};
+ }
+ }
+
+ if (!BlockChunks.empty())
+ {
+ ZEN_INFO("Saving chunk block attachments...");
+ for (const std::vector<IoHash>& Chunks : BlockChunks)
+ {
+ if (RemoteResult.IsError())
+ {
+ break;
+ }
+ std::vector<IoHash> NeededChunks;
+ if (ForceUpload)
+ {
+ NeededChunks = Chunks;
+ }
+ else
+ {
+ NeededChunks.reserve(Chunks.size());
+ for (const IoHash& Chunk : Chunks)
+ {
+ if (NeededOtherAttachments.contains(Chunk))
+ {
+ NeededChunks.push_back(Chunk);
+ }
+ }
+ if (NeededChunks.empty())
+ {
+ continue;
+ }
+ }
+ SaveAttachmentsLatch.AddCount(1);
+ WorkerPool.ScheduleWork([&RemoteStore,
+ &ChunkStore,
+ &SaveAttachmentsLatch,
+ &RemoteResult,
+ &Chunks,
+ NeededChunks = std::move(NeededChunks),
+ ForceUpload]() {
+ auto _ = MakeGuard([&SaveAttachmentsLatch] { SaveAttachmentsLatch.CountDown(); });
+ std::vector<SharedBuffer> ChunkBuffers;
+ ChunkBuffers.reserve(NeededChunks.size());
+ for (const IoHash& Chunk : NeededChunks)
+ {
+ IoBuffer ChunkPayload = ChunkStore.FindChunkByCid(Chunk);
+ if (!ChunkPayload)
+ {
+ RemoteResult.SetError(static_cast<int32_t>(HttpResponseCode::NotFound),
+ fmt::format("Missing chunk {}"sv, Chunk),
+ fmt::format("Unable to fetch attachment {} required by the oplog"sv, Chunk));
+ ChunkBuffers.clear();
+ break;
+ }
+ ChunkBuffers.emplace_back(SharedBuffer(std::move(ChunkPayload)));
+ }
+ RemoteProjectStore::SaveAttachmentsResult Result = RemoteStore.SaveAttachments(ChunkBuffers);
+ if (Result.ErrorCode)
+ {
+ RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+ ZEN_ERROR("Failed to save attachments with {} chunks ({}). Reason: '{}'",
+ Chunks.size(),
+ RemoteResult.GetError(),
+ RemoteResult.GetErrorReason());
+ return;
+ }
+ ZEN_DEBUG("Saved {} bulk attachments in {}",
+ Chunks.size(),
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+ });
+ }
+ }
+ SaveAttachmentsLatch.CountDown();
+ while (!SaveAttachmentsLatch.Wait(1000))
+ {
+ ZEN_INFO("Saving attachments, {} remaining...", SaveAttachmentsLatch.Remaining());
+ }
+ SaveAttachmentsLatch.Wait();
+ }
+
+ if (!RemoteResult.IsError())
+ {
+ ZEN_INFO("Finalizing oplog container...");
+ RemoteProjectStore::Result ContainerFinalizeResult = RemoteStore.FinalizeContainer(ContainerSaveResult.RawHash);
+ if (ContainerFinalizeResult.ErrorCode)
+ {
+ RemoteResult.SetError(ContainerFinalizeResult.ErrorCode, ContainerFinalizeResult.Reason, ContainerFinalizeResult.Text);
+ ZEN_ERROR("Failed to finalize oplog container {} ({}). Reason: '{}'",
+ ContainerSaveResult.RawHash,
+ RemoteResult.GetError(),
+ RemoteResult.GetErrorReason());
+ }
+ ZEN_DEBUG("Finalized container in {}", NiceTimeSpanMs(static_cast<uint64_t>(ContainerFinalizeResult.ElapsedSeconds * 1000)));
+ }
+ }
+
+ RemoteProjectStore::Result Result = RemoteResult.ConvertResult();
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ ZEN_INFO("Saved oplog {} in {}",
+ RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE",
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+ return Result;
+};
+
+RemoteProjectStore::Result
+SaveOplogContainer(ProjectStore::Oplog& Oplog,
+ const CbObject& ContainerObject,
+ const std::function<bool(const IoHash& RawHash)>& HasAttachment,
+ const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock,
+ const std::function<void(const IoHash& RawHash)>& OnNeedAttachment)
+{
+ using namespace std::literals;
+
+ Stopwatch Timer;
+
+ CbArrayView LargeChunksArray = ContainerObject["chunks"sv].AsArrayView();
+ for (CbFieldView LargeChunksField : LargeChunksArray)
+ {
+ IoHash AttachmentHash = LargeChunksField.AsBinaryAttachment();
+ if (HasAttachment(AttachmentHash))
+ {
+ continue;
+ }
+ OnNeedAttachment(AttachmentHash);
+ };
+
+ CbArrayView BlocksArray = ContainerObject["blocks"sv].AsArrayView();
+ for (CbFieldView BlockField : BlocksArray)
+ {
+ CbObjectView BlockView = BlockField.AsObjectView();
+ IoHash BlockHash = BlockView["rawhash"sv].AsBinaryAttachment();
+
+ CbArrayView ChunksArray = BlockView["chunks"sv].AsArrayView();
+ if (BlockHash == IoHash::Zero)
+ {
+ std::vector<IoHash> NeededChunks;
+ NeededChunks.reserve(ChunksArray.GetSize());
+ for (CbFieldView ChunkField : ChunksArray)
+ {
+ IoHash ChunkHash = ChunkField.AsBinaryAttachment();
+ if (HasAttachment(ChunkHash))
+ {
+ continue;
+ }
+ NeededChunks.emplace_back(ChunkHash);
+ }
+
+ if (!NeededChunks.empty())
+ {
+ OnNeedBlock(IoHash::Zero, std::move(NeededChunks));
+ }
+ continue;
+ }
+
+ for (CbFieldView ChunkField : ChunksArray)
+ {
+ IoHash ChunkHash = ChunkField.AsHash();
+ if (HasAttachment(ChunkHash))
+ {
+ continue;
+ }
+
+ OnNeedBlock(BlockHash, {});
+ break;
+ }
+ };
+
+ MemoryView OpsSection = ContainerObject["ops"sv].AsBinaryView();
+ IoBuffer OpsBuffer(IoBuffer::Wrap, OpsSection.GetData(), OpsSection.GetSize());
+ IoBuffer SectionPayload = CompressedBuffer::FromCompressedNoValidate(std::move(OpsBuffer)).Decompress().AsIoBuffer();
+
+ CbObject SectionObject = LoadCompactBinaryObject(SectionPayload);
+ if (!SectionObject)
+ {
+ ZEN_ERROR("Failed to save oplog container. Reason: '{}'", "Section has unexpected data type");
+ return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest),
+ Timer.GetElapsedTimeMs() / 1000.500,
+ "Section has unexpected data type",
+ "Failed to save oplog container"};
+ }
+
+ CbArrayView OpsArray = SectionObject["ops"sv].AsArrayView();
+ for (CbFieldView OpEntry : OpsArray)
+ {
+ CbObjectView Core = OpEntry.AsObjectView();
+ BinaryWriter Writer;
+ Core.CopyTo(Writer);
+ MemoryView OpView = Writer.GetView();
+ IoBuffer OpBuffer(IoBuffer::Wrap, OpView.GetData(), OpView.GetSize());
+ CbObject Op(SharedBuffer(OpBuffer), CbFieldType::HasFieldType);
+ const uint32_t OpLsn = Oplog.AppendNewOplogEntry(Op);
+ if (OpLsn == ProjectStore::Oplog::kInvalidOp)
+ {
+ return RemoteProjectStore::Result{gsl::narrow<int>(HttpResponseCode::BadRequest),
+ Timer.GetElapsedTimeMs() / 1000.500,
+ "Failed saving op",
+ "Failed to save oplog container"};
+ }
+ ZEN_DEBUG("oplog entry #{}", OpLsn);
+ }
+ return RemoteProjectStore::Result{.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500};
+}
+
+RemoteProjectStore::Result
+LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload)
+{
+ using namespace std::literals;
+
+ Stopwatch Timer;
+
+ // We are creating a worker thread pool here since we are download a lot of attachments in one go and we dont want to keep a
+ // WorkerThreadPool alive
+ size_t WorkerCount = Min(std::thread::hardware_concurrency(), 16u);
+ WorkerThreadPool WorkerPool(gsl::narrow<int>(WorkerCount));
+
+ std::unordered_set<IoHash, IoHash::Hasher> Attachments;
+ std::vector<std::vector<IoHash>> ChunksInBlocks;
+
+ RemoteProjectStore::LoadContainerResult LoadContainerResult = RemoteStore.LoadContainer();
+ if (LoadContainerResult.ErrorCode)
+ {
+ ZEN_WARN("Failed to load oplog container, reason: '{}', error code: {}", LoadContainerResult.Reason, LoadContainerResult.ErrorCode);
+ return RemoteProjectStore::Result{.ErrorCode = LoadContainerResult.ErrorCode,
+ .ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500,
+ .Reason = LoadContainerResult.Reason,
+ .Text = LoadContainerResult.Text};
+ }
+ ZEN_DEBUG("Loaded container in {}", NiceTimeSpanMs(static_cast<uint64_t>(LoadContainerResult.ElapsedSeconds * 1000)));
+
+ AsyncRemoteResult RemoteResult;
+ Latch AttachmentsWorkLatch(1);
+
+ auto HasAttachment = [&ChunkStore, ForceDownload](const IoHash& RawHash) {
+ return !ForceDownload && ChunkStore.ContainsChunk(RawHash);
+ };
+ auto OnNeedBlock = [&RemoteStore, &ChunkStore, &WorkerPool, &ChunksInBlocks, &AttachmentsWorkLatch, &RemoteResult](
+ const IoHash& BlockHash,
+ std::vector<IoHash>&& Chunks) {
+ if (BlockHash == IoHash::Zero)
+ {
+ AttachmentsWorkLatch.AddCount(1);
+ WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &AttachmentsWorkLatch, &RemoteResult, Chunks = std::move(Chunks)]() {
+ auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); });
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+
+ RemoteProjectStore::LoadAttachmentsResult Result = RemoteStore.LoadAttachments(Chunks);
+ if (Result.ErrorCode)
+ {
+ RemoteResult.SetError(Result.ErrorCode, Result.Reason, Result.Text);
+ ZEN_ERROR("Failed to attachments with {} chunks ({}). Reason: '{}'",
+ Chunks.size(),
+ RemoteResult.GetError(),
+ RemoteResult.GetErrorReason());
+ return;
+ }
+ ZEN_DEBUG("Loaded {} bulk attachments in {}",
+ Chunks.size(),
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000)));
+ for (const auto& It : Result.Chunks)
+ {
+ ChunkStore.AddChunk(It.second.GetCompressed().Flatten().AsIoBuffer(), It.first, CidStore::InsertMode::kCopyOnly);
+ }
+ });
+ return;
+ }
+ AttachmentsWorkLatch.AddCount(1);
+ WorkerPool.ScheduleWork([&AttachmentsWorkLatch, &ChunkStore, &RemoteStore, BlockHash, &RemoteResult]() {
+ auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); });
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+ RemoteProjectStore::LoadAttachmentResult BlockResult = RemoteStore.LoadAttachment(BlockHash);
+ if (BlockResult.ErrorCode)
+ {
+ RemoteResult.SetError(BlockResult.ErrorCode, BlockResult.Reason, BlockResult.Text);
+ ZEN_ERROR("Failed to load oplog container, missing attachment {} ({}). Reason: '{}'",
+ BlockHash,
+ RemoteResult.GetError(),
+ RemoteResult.GetErrorReason());
+ return;
+ }
+ ZEN_DEBUG("Loaded block attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(BlockResult.ElapsedSeconds * 1000)));
+
+ if (!IterateBlock(std::move(BlockResult.Bytes), [&ChunkStore](CompressedBuffer&& Chunk, const IoHash& AttachmentRawHash) {
+ ChunkStore.AddChunk(Chunk.GetCompressed().Flatten().AsIoBuffer(), AttachmentRawHash);
+ }))
+ {
+ RemoteResult.SetError(gsl::narrow<int32_t>(HttpResponseCode::InternalServerError),
+ fmt::format("Invalid format for block {}", BlockHash),
+ {});
+ ZEN_ERROR("Failed to load oplog container, attachment {} has invalid format ({}). Reason: '{}'",
+ BlockHash,
+ RemoteResult.GetError(),
+ RemoteResult.GetErrorReason());
+ return;
+ }
+ });
+ };
+
+ auto OnNeedAttachment =
+ [&RemoteStore, &ChunkStore, &WorkerPool, &AttachmentsWorkLatch, &RemoteResult, &Attachments](const IoHash& RawHash) {
+ if (!Attachments.insert(RawHash).second)
+ {
+ return;
+ }
+
+ AttachmentsWorkLatch.AddCount(1);
+ WorkerPool.ScheduleWork([&RemoteStore, &ChunkStore, &RemoteResult, &AttachmentsWorkLatch, RawHash]() {
+ auto _ = MakeGuard([&AttachmentsWorkLatch] { AttachmentsWorkLatch.CountDown(); });
+ if (RemoteResult.IsError())
+ {
+ return;
+ }
+ RemoteProjectStore::LoadAttachmentResult AttachmentResult = RemoteStore.LoadAttachment(RawHash);
+ if (AttachmentResult.ErrorCode)
+ {
+ RemoteResult.SetError(AttachmentResult.ErrorCode, AttachmentResult.Reason, AttachmentResult.Text);
+ ZEN_ERROR("Failed to download attachment {}, reason: '{}', error code: {}",
+ RawHash,
+ AttachmentResult.Reason,
+ AttachmentResult.ErrorCode);
+ return;
+ }
+ ZEN_DEBUG("Loaded attachment in {}", NiceTimeSpanMs(static_cast<uint64_t>(AttachmentResult.ElapsedSeconds * 1000)));
+ ChunkStore.AddChunk(AttachmentResult.Bytes, RawHash);
+ });
+ };
+
+ RemoteProjectStore::Result Result =
+ SaveOplogContainer(Oplog, LoadContainerResult.ContainerObject, HasAttachment, OnNeedBlock, OnNeedAttachment);
+
+ AttachmentsWorkLatch.CountDown();
+ while (!AttachmentsWorkLatch.Wait(1000))
+ {
+ ZEN_INFO("Loading attachments, {} remaining...", AttachmentsWorkLatch.Remaining());
+ }
+ AttachmentsWorkLatch.Wait();
+ if (Result.ErrorCode == 0)
+ {
+ Result = RemoteResult.ConvertResult();
+ }
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+
+ ZEN_INFO("Loaded oplog {} in {}",
+ RemoteResult.GetError() == 0 ? "SUCCESS" : "FAILURE",
+ NiceTimeSpanMs(static_cast<uint64_t>(Result.ElapsedSeconds * 1000.0)));
+
+ return Result;
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/remoteprojectstore.h b/src/zenserver/projectstore/remoteprojectstore.h
new file mode 100644
index 000000000..dcabaedd4
--- /dev/null
+++ b/src/zenserver/projectstore/remoteprojectstore.h
@@ -0,0 +1,111 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "projectstore.h"
+
+#include <unordered_set>
+
+namespace zen {
+
+class CidStore;
+class WorkerThreadPool;
+
+class RemoteProjectStore
+{
+public:
+ struct Result
+ {
+ int32_t ErrorCode{};
+ double ElapsedSeconds{};
+ std::string Reason;
+ std::string Text;
+ };
+
+ struct SaveResult : public Result
+ {
+ std::unordered_set<IoHash, IoHash::Hasher> Needs;
+ IoHash RawHash;
+ };
+
+ struct SaveAttachmentResult : public Result
+ {
+ };
+
+ struct SaveAttachmentsResult : public Result
+ {
+ };
+
+ struct LoadAttachmentResult : public Result
+ {
+ IoBuffer Bytes;
+ };
+
+ struct LoadContainerResult : public Result
+ {
+ CbObject ContainerObject;
+ };
+
+ struct LoadAttachmentsResult : public Result
+ {
+ std::vector<std::pair<IoHash, CompressedBuffer>> Chunks;
+ };
+
+ struct RemoteStoreInfo
+ {
+ bool CreateBlocks;
+ bool UseTempBlockFiles;
+ std::string Description;
+ };
+
+ virtual ~RemoteProjectStore() {}
+
+ virtual RemoteStoreInfo GetInfo() const = 0;
+
+ virtual SaveResult SaveContainer(const IoBuffer& Payload) = 0;
+ virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) = 0;
+ virtual Result FinalizeContainer(const IoHash& RawHash) = 0;
+ virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Payloads) = 0;
+
+ virtual LoadContainerResult LoadContainer() = 0;
+ virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) = 0;
+ virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) = 0;
+};
+
+struct RemoteStoreOptions
+{
+ size_t MaxBlockSize = 128u * 1024u * 1024u;
+ size_t MaxChunkEmbedSize = 1024u * 1024u;
+};
+
+RemoteProjectStore::LoadContainerResult BuildContainer(
+ CidStore& ChunkStore,
+ ProjectStore::Oplog& Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize,
+ bool BuildBlocks,
+ const std::function<void(CompressedBuffer&&, const IoHash&)>& AsyncOnBlock,
+ const std::function<void(const IoHash&)>& OnLargeAttachment,
+ const std::function<void(const std::unordered_set<IoHash, IoHash::Hasher>)>& OnBlockChunks);
+
+RemoteProjectStore::Result SaveOplogContainer(ProjectStore::Oplog& Oplog,
+ const CbObject& ContainerObject,
+ const std::function<bool(const IoHash& RawHash)>& HasAttachment,
+ const std::function<void(const IoHash& BlockHash, std::vector<IoHash>&& Chunks)>& OnNeedBlock,
+ const std::function<void(const IoHash& RawHash)>& OnNeedAttachment);
+
+RemoteProjectStore::Result SaveOplog(CidStore& ChunkStore,
+ RemoteProjectStore& RemoteStore,
+ ProjectStore::Oplog& Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize,
+ bool BuildBlocks,
+ bool UseTempBlocks,
+ bool ForceUpload);
+
+RemoteProjectStore::Result LoadOplog(CidStore& ChunkStore, RemoteProjectStore& RemoteStore, ProjectStore::Oplog& Oplog, bool ForceDownload);
+
+CompressedBuffer GenerateBlock(std::vector<SharedBuffer>&& Chunks);
+bool IterateBlock(IoBuffer&& CompressedBlock, std::function<void(CompressedBuffer&& Chunk, const IoHash& AttachmentHash)> Visitor);
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/zenremoteprojectstore.cpp b/src/zenserver/projectstore/zenremoteprojectstore.cpp
new file mode 100644
index 000000000..6ff471ae5
--- /dev/null
+++ b/src/zenserver/projectstore/zenremoteprojectstore.cpp
@@ -0,0 +1,341 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenremoteprojectstore.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/fmtutils.h>
+#include <zencore/scopeguard.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zenhttp/httpshared.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+using namespace std::literals;
+
+class ZenRemoteStore : public RemoteProjectStore
+{
+public:
+ ZenRemoteStore(std::string_view HostAddress,
+ std::string_view Project,
+ std::string_view Oplog,
+ size_t MaxBlockSize,
+ size_t MaxChunkEmbedSize)
+ : m_HostAddress(HostAddress)
+ , m_ProjectStoreUrl(fmt::format("{}/prj"sv, m_HostAddress))
+ , m_Project(Project)
+ , m_Oplog(Oplog)
+ , m_MaxBlockSize(MaxBlockSize)
+ , m_MaxChunkEmbedSize(MaxChunkEmbedSize)
+ {
+ }
+
+ virtual RemoteStoreInfo GetInfo() const override
+ {
+ return {.CreateBlocks = false, .UseTempBlockFiles = false, .Description = fmt::format("[zen] {}"sv, m_HostAddress)};
+ }
+
+ virtual SaveResult SaveContainer(const IoBuffer& Payload) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/save"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+ Session->SetUrl({SaveRequest});
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))}});
+ MemoryView Data(Payload.GetView());
+ Session->SetBody({reinterpret_cast<const char*>(Data.GetData()), Data.GetSize()});
+ cpr::Response Response = Session->Post();
+ SaveResult Result = SaveResult{ConvertResult(Response)};
+
+ if (Result.ErrorCode)
+ {
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+ IoBuffer ResponsePayload(IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ CbObject ResponseObject = LoadCompactBinaryObject(ResponsePayload);
+ if (!ResponseObject)
+ {
+ Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv,
+ m_ProjectStoreUrl,
+ m_Project,
+ m_Oplog);
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+ CbArrayView NeedsArray = ResponseObject["need"sv].AsArrayView();
+ for (CbFieldView FieldView : NeedsArray)
+ {
+ IoHash ChunkHash = FieldView.AsHash();
+ Result.Needs.insert(ChunkHash);
+ }
+
+ Result.RawHash = IoHash::HashBuffer(Payload);
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual SaveAttachmentResult SaveAttachment(const CompositeBuffer& Payload, const IoHash& RawHash) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash);
+ Session->SetUrl({SaveRequest});
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}});
+ uint64_t SizeLeft = Payload.GetSize();
+ CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
+ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, SizeLeft);
+ MutableMemoryView Data(buffer, size);
+ Payload.CopyTo(Data, BufferIt);
+ SizeLeft -= size;
+ return true;
+ };
+ Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback));
+ cpr::Response Response = Session->Post();
+ SaveAttachmentResult Result = SaveAttachmentResult{ConvertResult(Response)};
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual SaveAttachmentsResult SaveAttachments(const std::vector<SharedBuffer>& Chunks) override
+ {
+ Stopwatch Timer;
+
+ CbPackage RequestPackage;
+ {
+ CbObjectWriter RequestWriter;
+ RequestWriter.AddString("method"sv, "putchunks"sv);
+ RequestWriter.BeginArray("chunks"sv);
+ {
+ for (const SharedBuffer& Chunk : Chunks)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(Chunk, RawHash, RawSize);
+ RequestWriter.AddHash(RawHash);
+ RequestPackage.AddAttachment(CbAttachment(Compressed, RawHash));
+ }
+ }
+ RequestWriter.EndArray(); // "chunks"
+ RequestPackage.SetObject(RequestWriter.Save());
+ }
+ CompositeBuffer Payload = FormatPackageMessageBuffer(RequestPackage, FormatFlags::kDefault);
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+ Session->SetUrl({SaveRequest});
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}});
+
+ uint64_t SizeLeft = Payload.GetSize();
+ CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
+ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, SizeLeft);
+ MutableMemoryView Data(buffer, size);
+ Payload.CopyTo(Data, BufferIt);
+ SizeLeft -= size;
+ return true;
+ };
+ Session->SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback));
+ cpr::Response Response = Session->Post();
+ SaveAttachmentsResult Result = SaveAttachmentsResult{ConvertResult(Response)};
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual LoadAttachmentsResult LoadAttachments(const std::vector<IoHash>& RawHashes) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/rpc"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+
+ CbObject Request;
+ {
+ CbObjectWriter RequestWriter;
+ RequestWriter.AddString("method"sv, "getchunks"sv);
+ RequestWriter.BeginArray("chunks"sv);
+ {
+ for (const IoHash& RawHash : RawHashes)
+ {
+ RequestWriter.AddHash(RawHash);
+ }
+ }
+ RequestWriter.EndArray(); // "chunks"
+ Request = RequestWriter.Save();
+ }
+ IoBuffer Payload = Request.GetBuffer().AsIoBuffer();
+ Session->SetBody(cpr::Body{(const char*)Payload.GetData(), Payload.GetSize()});
+ Session->SetUrl(SaveRequest);
+ Session->SetHeader({{"Content-Type", std::string(MapContentTypeToString(HttpContentType::kCbObject))},
+ {"Accept", std::string(MapContentTypeToString(HttpContentType::kCbPackage))}});
+
+ cpr::Response Response = Session->Post();
+ LoadAttachmentsResult Result = LoadAttachmentsResult{ConvertResult(Response)};
+ if (!Result.ErrorCode)
+ {
+ CbPackage Package = ParsePackageMessage(IoBuffer(IoBuffer::Wrap, Response.text.data(), Response.text.size()));
+ std::span<const CbAttachment> Attachments = Package.GetAttachments();
+ Result.Chunks.reserve(Attachments.size());
+ for (const CbAttachment& Attachment : Attachments)
+ {
+ Result.Chunks.emplace_back(
+ std::pair<IoHash, CompressedBuffer>{Attachment.GetHash(), Attachment.AsCompressedBinary().MakeOwned()});
+ }
+ }
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ };
+
+ virtual Result FinalizeContainer(const IoHash&) override
+ {
+ Stopwatch Timer;
+
+ RwLock::ExclusiveLockScope _(SessionsLock);
+ Sessions.clear();
+ return {.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500};
+ }
+
+ virtual LoadContainerResult LoadContainer() override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+ std::string SaveRequest = fmt::format("{}/{}/oplog/{}/load"sv, m_ProjectStoreUrl, m_Project, m_Oplog);
+ Session->SetUrl(SaveRequest);
+ Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCbObject))}});
+ Session->SetParameters(
+ {{"maxblocksize", fmt::format("{}", m_MaxBlockSize)}, {"maxchunkembedsize", fmt::format("{}", m_MaxChunkEmbedSize)}});
+ cpr::Response Response = Session->Get();
+
+ LoadContainerResult Result = LoadContainerResult{ConvertResult(Response)};
+ if (!Result.ErrorCode)
+ {
+ Result.ContainerObject = LoadCompactBinaryObject(IoBuffer(IoBuffer::Clone, Response.text.data(), Response.text.size()));
+ if (!Result.ContainerObject)
+ {
+ Result.Reason = fmt::format("The response for {}/{}/{} is not formatted as a compact binary object"sv,
+ m_ProjectStoreUrl,
+ m_Project,
+ m_Oplog);
+ Result.ErrorCode = gsl::narrow<int32_t>(HttpResponseCode::InternalServerError);
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+ }
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+ virtual LoadAttachmentResult LoadAttachment(const IoHash& RawHash) override
+ {
+ Stopwatch Timer;
+
+ std::unique_ptr<cpr::Session> Session(AllocateSession());
+ auto _ = MakeGuard([this, &Session]() { ReleaseSession(std::move(Session)); });
+
+ std::string LoadRequest = fmt::format("{}/{}/oplog/{}/{}"sv, m_ProjectStoreUrl, m_Project, m_Oplog, RawHash);
+ Session->SetUrl({LoadRequest});
+ Session->SetHeader({{"Accept", std::string(MapContentTypeToString(HttpContentType::kCompressedBinary))}});
+ cpr::Response Response = Session->Get();
+ LoadAttachmentResult Result = LoadAttachmentResult{ConvertResult(Response)};
+ if (!Result.ErrorCode)
+ {
+ Result.Bytes = IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size());
+ }
+ Result.ElapsedSeconds = Timer.GetElapsedTimeMs() / 1000.500;
+ return Result;
+ }
+
+private:
+ std::unique_ptr<cpr::Session> AllocateSession()
+ {
+ RwLock::ExclusiveLockScope _(SessionsLock);
+ if (Sessions.empty())
+ {
+ Sessions.emplace_back(std::make_unique<cpr::Session>());
+ }
+ std::unique_ptr<cpr::Session> Session = std::move(Sessions.back());
+ Sessions.pop_back();
+ return Session;
+ }
+
+ void ReleaseSession(std::unique_ptr<cpr::Session>&& Session)
+ {
+ RwLock::ExclusiveLockScope _(SessionsLock);
+ Sessions.emplace_back(std::move(Session));
+ }
+
+ static Result ConvertResult(const cpr::Response& Response)
+ {
+ std::string Text;
+ std::string Reason = Response.reason;
+ int32_t ErrorCode = 0;
+ if (Response.error.code != cpr::ErrorCode::OK)
+ {
+ ErrorCode = static_cast<int32_t>(Response.error.code);
+ if (!Response.error.message.empty())
+ {
+ Reason = Response.error.message;
+ }
+ }
+ else if (!IsHttpSuccessCode(Response.status_code))
+ {
+ ErrorCode = static_cast<int32_t>(Response.status_code);
+
+ if (auto It = Response.header.find("Content-Type"); It != Response.header.end())
+ {
+ zen::HttpContentType ContentType = zen::ParseContentType(It->second);
+ if (ContentType == zen::HttpContentType::kText)
+ {
+ Text = Response.text;
+ }
+ }
+
+ Reason = fmt::format("{}"sv, Response.status_code);
+ }
+ return {.ErrorCode = ErrorCode, .ElapsedSeconds = Response.elapsed, .Reason = Reason, .Text = Text};
+ }
+
+ RwLock SessionsLock;
+ std::vector<std::unique_ptr<cpr::Session>> Sessions;
+
+ const std::string m_HostAddress;
+ const std::string m_ProjectStoreUrl;
+ const std::string m_Project;
+ const std::string m_Oplog;
+ const size_t m_MaxBlockSize;
+ const size_t m_MaxChunkEmbedSize;
+};
+
+std::unique_ptr<RemoteProjectStore>
+CreateZenRemoteStore(const ZenRemoteStoreOptions& Options)
+{
+ std::string Url = Options.Url;
+ if (Url.find("://"sv) == std::string::npos)
+ {
+ // Assume https URL
+ Url = fmt::format("http://{}"sv, Url);
+ }
+ std::unique_ptr<RemoteProjectStore> RemoteStore =
+ std::make_unique<ZenRemoteStore>(Url, Options.ProjectId, Options.OplogId, Options.MaxBlockSize, Options.MaxChunkEmbedSize);
+ return RemoteStore;
+}
+
+} // namespace zen
diff --git a/src/zenserver/projectstore/zenremoteprojectstore.h b/src/zenserver/projectstore/zenremoteprojectstore.h
new file mode 100644
index 000000000..ef9dcad8c
--- /dev/null
+++ b/src/zenserver/projectstore/zenremoteprojectstore.h
@@ -0,0 +1,18 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "remoteprojectstore.h"
+
+namespace zen {
+
+struct ZenRemoteStoreOptions : RemoteStoreOptions
+{
+ std::string Url;
+ std::string ProjectId;
+ std::string OplogId;
+};
+
+std::unique_ptr<RemoteProjectStore> CreateZenRemoteStore(const ZenRemoteStoreOptions& Options);
+
+} // namespace zen
diff --git a/src/zenserver/resource.h b/src/zenserver/resource.h
new file mode 100644
index 000000000..f2e3b471b
--- /dev/null
+++ b/src/zenserver/resource.h
@@ -0,0 +1,18 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+//{{NO_DEPENDENCIES}}
+// Microsoft Visual C++ generated include file.
+// Used by zenserver.rc
+//
+#define IDI_ICON1 101
+
+// Next default values for new objects
+//
+#ifdef APSTUDIO_INVOKED
+# ifndef APSTUDIO_READONLY_SYMBOLS
+# define _APS_NEXT_RESOURCE_VALUE 102
+# define _APS_NEXT_COMMAND_VALUE 40001
+# define _APS_NEXT_CONTROL_VALUE 1001
+# define _APS_NEXT_SYMED_VALUE 101
+# endif
+#endif
diff --git a/src/zenserver/targetver.h b/src/zenserver/targetver.h
new file mode 100644
index 000000000..d432d6993
--- /dev/null
+++ b/src/zenserver/targetver.h
@@ -0,0 +1,10 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+// Including SDKDDKVer.h defines the highest available Windows platform.
+
+// If you wish to build your application for a previous Windows platform, include WinSDKVer.h and
+// set the _WIN32_WINNT macro to the platform you wish to support before including SDKDDKVer.h.
+
+#include <SDKDDKVer.h>
diff --git a/src/zenserver/testing/httptest.cpp b/src/zenserver/testing/httptest.cpp
new file mode 100644
index 000000000..349a95ab3
--- /dev/null
+++ b/src/zenserver/testing/httptest.cpp
@@ -0,0 +1,207 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "httptest.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/timer.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpTestingService::HttpTestingService()
+{
+ m_Router.RegisterRoute(
+ "hello",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponse(HttpResponseCode::OK); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_slow",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
+ Stopwatch Timer;
+ Sleep(1000);
+ Request.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kText,
+ fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())));
+ });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_veryslow",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest& Request) {
+ Stopwatch Timer;
+ Sleep(60000);
+ Request.WriteResponse(HttpResponseCode::OK,
+ HttpContentType::kText,
+ fmt::format("hello, took me {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())));
+ });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_throw",
+ [](HttpRouterRequest& Req) {
+ Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) { throw std::runtime_error("intentional error"); });
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "hello_noresponse",
+ [](HttpRouterRequest& Req) { Req.ServerRequest().WriteResponseAsync([](HttpServerRequest&) {}); },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "metrics",
+ [this](HttpRouterRequest& Req) {
+ metrics::OperationTiming::Scope _(m_TimingStats);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK);
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "get_metrics",
+ [this](HttpRouterRequest& Req) {
+ CbObjectWriter Cbo;
+ EmitSnapshot("requests", m_TimingStats, Cbo);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "json",
+ [this](HttpRouterRequest& Req) {
+ CbObjectWriter Obj;
+ Obj.AddBool("ok", true);
+ Obj.AddInteger("counter", ++m_Counter);
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Obj.Save());
+ },
+ HttpVerb::kGet);
+
+ m_Router.RegisterRoute(
+ "echo",
+ [](HttpRouterRequest& Req) {
+ IoBuffer Body = Req.ServerRequest().ReadPayload();
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, HttpContentType::kBinary, Body);
+ },
+ HttpVerb::kPost);
+
+ m_Router.RegisterRoute(
+ "package",
+ [](HttpRouterRequest& Req) {
+ CbPackage Pkg = Req.ServerRequest().ReadPayloadPackage();
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Pkg);
+ },
+ HttpVerb::kPost);
+}
+
+HttpTestingService::~HttpTestingService()
+{
+}
+
+const char*
+HttpTestingService::BaseUri() const
+{
+ return "/testing/";
+}
+
+void
+HttpTestingService::HandleRequest(HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+Ref<IHttpPackageHandler>
+HttpTestingService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
+{
+ RwLock::ExclusiveLockScope _(m_RwLock);
+
+ const uint32_t RequestId = HttpServiceRequest.RequestId();
+
+ if (auto It = m_HandlerMap.find(RequestId); It != m_HandlerMap.end())
+ {
+ Ref<HttpTestingService::PackageHandler> Handler = std::move(It->second);
+
+ m_HandlerMap.erase(It);
+
+ return Handler;
+ }
+
+ auto InsertResult = m_HandlerMap.insert({RequestId, Ref<PackageHandler>()});
+
+ _.ReleaseNow();
+
+ return (InsertResult.first->second = Ref<PackageHandler>(new PackageHandler(*this, RequestId)));
+}
+
+void
+HttpTestingService::RegisterHandlers(WebSocketServer& Server)
+{
+ Server.RegisterRequestHandler("SayHello"sv, *this);
+}
+
+bool
+HttpTestingService::HandleRequest(const WebSocketMessage& RequestMsg)
+{
+ CbObjectView Request = RequestMsg.Body().GetObject();
+
+ std::string_view Method = Request["Method"].AsString();
+
+ if (Method != "SayHello"sv)
+ {
+ return false;
+ }
+
+ CbObjectWriter Response;
+ Response.AddString("Result"sv, "Hello Friend!!");
+
+ WebSocketMessage ResponseMsg;
+ ResponseMsg.SetMessageType(WebSocketMessageType::kResponse);
+ ResponseMsg.SetCorrelationId(RequestMsg.CorrelationId());
+ ResponseMsg.SetSocketId(RequestMsg.SocketId());
+ ResponseMsg.SetBody(Response.Save());
+
+ SocketServer().SendResponse(std::move(ResponseMsg));
+
+ return true;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpTestingService::PackageHandler::PackageHandler(HttpTestingService& Svc, uint32_t RequestId) : m_Svc(Svc), m_RequestId(RequestId)
+{
+}
+
+HttpTestingService::PackageHandler::~PackageHandler()
+{
+}
+
+void
+HttpTestingService::PackageHandler::FilterOffer(std::vector<IoHash>& OfferCids)
+{
+ ZEN_UNUSED(OfferCids);
+ // No-op
+ return;
+}
+void
+HttpTestingService::PackageHandler::OnRequestBegin()
+{
+}
+
+void
+HttpTestingService::PackageHandler::OnRequestComplete()
+{
+}
+
+IoBuffer
+HttpTestingService::PackageHandler::CreateTarget(const IoHash& Cid, uint64_t StorageSize)
+{
+ ZEN_UNUSED(Cid);
+ return IoBuffer{StorageSize};
+}
+
+} // namespace zen
diff --git a/src/zenserver/testing/httptest.h b/src/zenserver/testing/httptest.h
new file mode 100644
index 000000000..57d2d63f3
--- /dev/null
+++ b/src/zenserver/testing/httptest.h
@@ -0,0 +1,55 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/logging.h>
+#include <zencore/stats.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+
+#include <atomic>
+
+namespace zen {
+
+/**
+ * Test service to facilitate testing the HTTP framework and client interactions
+ */
+class HttpTestingService : public HttpService, public WebSocketService
+{
+public:
+ HttpTestingService();
+ ~HttpTestingService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(HttpServerRequest& Request) override;
+ virtual Ref<IHttpPackageHandler> HandlePackageRequest(HttpServerRequest& HttpServiceRequest) override;
+
+ class PackageHandler : public IHttpPackageHandler
+ {
+ public:
+ PackageHandler(HttpTestingService& Svc, uint32_t RequestId);
+ ~PackageHandler();
+
+ virtual void FilterOffer(std::vector<IoHash>& OfferCids) override;
+ virtual void OnRequestBegin() override;
+ virtual IoBuffer CreateTarget(const IoHash& Cid, uint64_t StorageSize) override;
+ virtual void OnRequestComplete() override;
+
+ private:
+ HttpTestingService& m_Svc;
+ uint32_t m_RequestId;
+ };
+
+private:
+ virtual void RegisterHandlers(WebSocketServer& Server) override;
+ virtual bool HandleRequest(const WebSocketMessage& Request) override;
+
+ HttpRequestRouter m_Router;
+ std::atomic<uint32_t> m_Counter{0};
+ metrics::OperationTiming m_TimingStats;
+
+ RwLock m_RwLock;
+ std::unordered_map<uint32_t, Ref<PackageHandler>> m_HandlerMap;
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/hordecompute.cpp b/src/zenserver/upstream/hordecompute.cpp
new file mode 100644
index 000000000..64d9fff72
--- /dev/null
+++ b/src/zenserver/upstream/hordecompute.cpp
@@ -0,0 +1,1457 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamapply.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include "jupiter.h"
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/compactbinarypackage.h>
+# include <zencore/compactbinaryvalidation.h>
+# include <zencore/fmtutils.h>
+# include <zencore/session.h>
+# include <zencore/stream.h>
+# include <zencore/thread.h>
+# include <zencore/timer.h>
+# include <zencore/workthreadpool.h>
+
+# include <zenstore/cidstore.h>
+
+# include <auth/authmgr.h>
+# include <upstream/upstreamcache.h>
+
+# include "cache/structuredcachestore.h"
+# include "diag/logging.h"
+
+# include <fmt/format.h>
+
+# include <algorithm>
+# include <atomic>
+# include <set>
+# include <stack>
+
+namespace zen {
+
+using namespace std::literals;
+
+static const IoBuffer EmptyBuffer;
+static const IoHash EmptyBufferId = IoHash::HashBuffer(EmptyBuffer);
+
+namespace detail {
+
+ class HordeUpstreamApplyEndpoint final : public UpstreamApplyEndpoint
+ {
+ public:
+ HordeUpstreamApplyEndpoint(const CloudCacheClientOptions& ComputeOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ CidStore& CidStore,
+ AuthMgr& Mgr)
+ : m_Log(logging::Get("upstream-apply"))
+ , m_CidStore(CidStore)
+ , m_AuthMgr(Mgr)
+ {
+ m_DisplayName = fmt::format("{} - '{}'+'{}'", ComputeOptions.Name, ComputeOptions.ServiceUrl, StorageOptions.ServiceUrl);
+ m_ChannelId = fmt::format("zen-{}", zen::GetSessionIdString());
+
+ {
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (ComputeAuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = ComputeAuthConfig.OAuthUrl,
+ .ClientId = ComputeAuthConfig.OAuthClientId,
+ .ClientSecret = ComputeAuthConfig.OAuthClientSecret});
+ }
+ else if (ComputeAuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(ComputeAuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(ComputeAuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_Client = new CloudCacheClient(ComputeOptions, std::move(TokenProvider));
+ }
+
+ {
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (StorageAuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromOAuthClientCredentials({.Url = StorageAuthConfig.OAuthUrl,
+ .ClientId = StorageAuthConfig.OAuthClientId,
+ .ClientSecret = StorageAuthConfig.OAuthClientSecret});
+ }
+ else if (StorageAuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(StorageAuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(StorageAuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_StorageClient = new CloudCacheClient(StorageOptions, std::move(TokenProvider));
+ }
+ }
+
+ virtual ~HordeUpstreamApplyEndpoint() = default;
+
+ virtual UpstreamEndpointHealth Initialize() override { return CheckHealth(); }
+
+ virtual bool IsHealthy() const override { return m_HealthOk.load(); }
+
+ virtual UpstreamEndpointHealth CheckHealth() override
+ {
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ CloudCacheResult Result = Session.Authenticate();
+
+ m_HealthOk = Result.ErrorCode == 0;
+
+ return {.Reason = std::move(Result.Reason), .Ok = Result.Success};
+ }
+ catch (std::exception& Err)
+ {
+ return {.Reason = Err.what(), .Ok = false};
+ }
+ }
+
+ virtual std::string_view DisplayName() const override { return m_DisplayName; }
+
+ virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) override
+ {
+ PostUpstreamApplyResult ApplyResult{};
+ ApplyResult.Timepoints.merge(ApplyRecord.Timepoints);
+
+ try
+ {
+ UpstreamData UpstreamData;
+ if (!ProcessApplyKey(ApplyRecord, UpstreamData))
+ {
+ return {.Error{.ErrorCode = -1, .Reason = "Failed to generate task data"}};
+ }
+
+ {
+ ApplyResult.Timepoints["zen-storage-build-ref"] = DateTime::NowTicks();
+
+ bool AlreadyQueued;
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ AlreadyQueued = m_PendingTasks.contains(UpstreamData.TaskId);
+ }
+ if (AlreadyQueued)
+ {
+ // Pending task is already queued, return success
+ ApplyResult.Success = true;
+ return ApplyResult;
+ }
+ m_PendingTasks[UpstreamData.TaskId] = std::move(ApplyRecord);
+ }
+
+ CloudCacheSession ComputeSession(m_Client);
+ CloudCacheSession StorageSession(m_StorageClient);
+
+ {
+ CloudCacheResult Result = BatchPutBlobsIfMissing(StorageSession, UpstreamData.Blobs, UpstreamData.CasIds);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-blobs"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload blobs"};
+ return ApplyResult;
+ }
+ UpstreamData.Blobs.clear();
+ UpstreamData.CasIds.clear();
+ }
+
+ {
+ CloudCacheResult Result = BatchPutCompressedBlobsIfMissing(StorageSession, UpstreamData.Cids);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-compressed-blobs"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {
+ .ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload compressed blobs"};
+ return ApplyResult;
+ }
+ UpstreamData.Cids.clear();
+ }
+
+ {
+ CloudCacheResult Result = BatchPutObjectsIfMissing(StorageSession, UpstreamData.Objects);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-upload-objects"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to upload objects"};
+ return ApplyResult;
+ }
+ }
+
+ {
+ PutRefResult RefResult = StorageSession.PutRef(StorageSession.Client().DefaultBlobStoreNamespace(),
+ "requests"sv,
+ UpstreamData.TaskId,
+ UpstreamData.Objects[UpstreamData.TaskId].GetBuffer().AsIoBuffer(),
+ ZenContentType::kCbObject);
+ Log().debug("Put ref {} Need={} Bytes={} Duration={}s Result={}",
+ UpstreamData.TaskId,
+ RefResult.Needs.size(),
+ RefResult.Bytes,
+ RefResult.ElapsedSeconds,
+ RefResult.Success);
+ ApplyResult.Bytes += RefResult.Bytes;
+ ApplyResult.ElapsedSeconds += RefResult.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-put-ref"] = DateTime::NowTicks();
+
+ if (RefResult.Needs.size() > 0)
+ {
+ Log().error("Failed to add task ref {} due to {} missing blobs", UpstreamData.TaskId, RefResult.Needs.size());
+ for (const auto& Hash : RefResult.Needs)
+ {
+ Log().debug("Task ref {} missing blob {}", UpstreamData.TaskId, Hash);
+ }
+
+ ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode,
+ .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason)
+ : "Failed to add task ref due to missing blob"};
+ return ApplyResult;
+ }
+
+ if (!RefResult.Success)
+ {
+ ApplyResult.Error = {.ErrorCode = RefResult.ErrorCode,
+ .Reason = !RefResult.Reason.empty() ? std::move(RefResult.Reason) : "Failed to add task ref"};
+ return ApplyResult;
+ }
+ UpstreamData.Objects.clear();
+ }
+
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("c"sv, m_ChannelId);
+ Writer.AddObjectAttachment("r"sv, UpstreamData.RequirementsId);
+ Writer.BeginArray("t"sv);
+ Writer.AddObjectAttachment(UpstreamData.TaskId);
+ Writer.EndArray();
+ CbObject TasksObject = Writer.Save();
+ IoBuffer TasksData = TasksObject.GetBuffer().AsIoBuffer();
+
+ CloudCacheResult Result = ComputeSession.PostComputeTasks(TasksData);
+ Log().debug("Post compute task {} Bytes={} Duration={}s Result={}",
+ TasksObject.GetHash(),
+ Result.Bytes,
+ Result.ElapsedSeconds,
+ Result.Success);
+ ApplyResult.Bytes += Result.Bytes;
+ ApplyResult.ElapsedSeconds += Result.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-horde-post-task"] = DateTime::NowTicks();
+ if (!Result.Success)
+ {
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ m_PendingTasks.erase(UpstreamData.TaskId);
+ }
+
+ ApplyResult.Error = {.ErrorCode = Result.ErrorCode,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to post compute task"};
+ return ApplyResult;
+ }
+ }
+
+ Log().info("Task posted {}", UpstreamData.TaskId);
+ ApplyResult.Success = true;
+ return ApplyResult;
+ }
+ catch (std::exception& Err)
+ {
+ m_HealthOk = false;
+ return {.Error{.ErrorCode = -1, .Reason = Err.what()}};
+ }
+ }
+
+ [[nodiscard]] CloudCacheResult BatchPutBlobsIfMissing(CloudCacheSession& Session,
+ const std::map<IoHash, IoBuffer>& Blobs,
+ const std::set<IoHash>& CasIds)
+ {
+ if (Blobs.size() == 0 && CasIds.size() == 0)
+ {
+ return {.Success = true};
+ }
+
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ // Batch check for missing blobs
+ std::set<IoHash> Keys;
+ std::transform(Blobs.begin(), Blobs.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; });
+ Keys.insert(CasIds.begin(), CasIds.end());
+
+ CloudCacheExistsResult ExistsResult = Session.BlobExists(Session.Client().DefaultBlobStoreNamespace(), Keys);
+ Log().debug("Queried {} missing blobs Need={} Duration={}s Result={}",
+ Keys.size(),
+ ExistsResult.Needs.size(),
+ ExistsResult.ElapsedSeconds,
+ ExistsResult.Success);
+ ElapsedSeconds += ExistsResult.ElapsedSeconds;
+ if (!ExistsResult.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+ .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if blobs exist"};
+ }
+
+ for (const auto& Hash : ExistsResult.Needs)
+ {
+ IoBuffer DataBuffer;
+ if (Blobs.contains(Hash))
+ {
+ DataBuffer = Blobs.at(Hash);
+ }
+ else
+ {
+ DataBuffer = m_CidStore.FindChunkByCid(Hash);
+ if (!DataBuffer)
+ {
+ Log().warn("Put blob FAILED, input chunk '{}' missing", Hash);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put blobs"};
+ }
+ }
+
+ CloudCacheResult Result = Session.PutBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer);
+ Log().debug("Put blob {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success);
+ Bytes += Result.Bytes;
+ ElapsedSeconds += Result.ElapsedSeconds;
+ if (!Result.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put blobs"};
+ }
+ }
+
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ [[nodiscard]] CloudCacheResult BatchPutCompressedBlobsIfMissing(CloudCacheSession& Session, const std::set<IoHash>& Cids)
+ {
+ if (Cids.size() == 0)
+ {
+ return {.Success = true};
+ }
+
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ // Batch check for missing compressed blobs
+ CloudCacheExistsResult ExistsResult = Session.CompressedBlobExists(Session.Client().DefaultBlobStoreNamespace(), Cids);
+ Log().debug("Queried {} missing compressed blobs Need={} Duration={}s Result={}",
+ Cids.size(),
+ ExistsResult.Needs.size(),
+ ExistsResult.ElapsedSeconds,
+ ExistsResult.Success);
+ ElapsedSeconds += ExistsResult.ElapsedSeconds;
+ if (!ExistsResult.Success)
+ {
+ return {
+ .Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+ .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if compressed blobs exist"};
+ }
+
+ for (const auto& Hash : ExistsResult.Needs)
+ {
+ IoBuffer DataBuffer = m_CidStore.FindChunkByCid(Hash);
+ if (!DataBuffer)
+ {
+ Log().warn("Put compressed blob FAILED, input CID chunk '{}' missing", Hash);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .ErrorCode = -1, .Reason = "Failed to put compressed blobs"};
+ }
+
+ CloudCacheResult Result = Session.PutCompressedBlob(Session.Client().DefaultBlobStoreNamespace(), Hash, DataBuffer);
+ Log().debug("Put compressed blob {} Bytes={} Duration={}s Result={}",
+ Hash,
+ Result.Bytes,
+ Result.ElapsedSeconds,
+ Result.Success);
+ Bytes += Result.Bytes;
+ ElapsedSeconds += Result.ElapsedSeconds;
+ if (!Result.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put compressed blobs"};
+ }
+ }
+
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ [[nodiscard]] CloudCacheResult BatchPutObjectsIfMissing(CloudCacheSession& Session, const std::map<IoHash, CbObject>& Objects)
+ {
+ if (Objects.size() == 0)
+ {
+ return {.Success = true};
+ }
+
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ // Batch check for missing objects
+ std::set<IoHash> Keys;
+ std::transform(Objects.begin(), Objects.end(), std::inserter(Keys, Keys.end()), [](const auto& It) { return It.first; });
+
+ CloudCacheExistsResult ExistsResult = Session.ObjectExists(Session.Client().DefaultBlobStoreNamespace(), Keys);
+ Log().debug("Queried {} missing objects Need={} Duration={}s Result={}",
+ Keys.size(),
+ ExistsResult.Needs.size(),
+ ExistsResult.ElapsedSeconds,
+ ExistsResult.Success);
+ ElapsedSeconds += ExistsResult.ElapsedSeconds;
+ if (!ExistsResult.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = ExistsResult.ErrorCode ? ExistsResult.ErrorCode : -1,
+ .Reason = !ExistsResult.Reason.empty() ? std::move(ExistsResult.Reason) : "Failed to check if objects exist"};
+ }
+
+ for (const auto& Hash : ExistsResult.Needs)
+ {
+ CloudCacheResult Result =
+ Session.PutObject(Session.Client().DefaultBlobStoreNamespace(), Hash, Objects.at(Hash).GetBuffer().AsIoBuffer());
+ Log().debug("Put object {} Bytes={} Duration={}s Result={}", Hash, Result.Bytes, Result.ElapsedSeconds, Result.Success);
+ Bytes += Result.Bytes;
+ ElapsedSeconds += Result.ElapsedSeconds;
+ if (!Result.Success)
+ {
+ return {.Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ .ErrorCode = Result.ErrorCode ? Result.ErrorCode : -1,
+ .Reason = !Result.Reason.empty() ? std::move(Result.Reason) : "Failed to put objects"};
+ }
+ }
+
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+
+ enum class ComputeTaskState : int32_t
+ {
+ Queued = 0,
+ Executing = 1,
+ Complete = 2,
+ };
+
+ enum class ComputeTaskOutcome : int32_t
+ {
+ Success = 0,
+ Failed = 1,
+ Cancelled = 2,
+ NoResult = 3,
+ Exipred = 4,
+ BlobNotFound = 5,
+ Exception = 6,
+ };
+
+ [[nodiscard]] static std::string_view ComputeTaskStateToString(const ComputeTaskState Outcome)
+ {
+ switch (Outcome)
+ {
+ case ComputeTaskState::Queued:
+ return "Queued"sv;
+ case ComputeTaskState::Executing:
+ return "Executing"sv;
+ case ComputeTaskState::Complete:
+ return "Complete"sv;
+ };
+ return "Unknown"sv;
+ }
+
+ [[nodiscard]] static std::string_view ComputeTaskOutcomeToString(const ComputeTaskOutcome Outcome)
+ {
+ switch (Outcome)
+ {
+ case ComputeTaskOutcome::Success:
+ return "Success"sv;
+ case ComputeTaskOutcome::Failed:
+ return "Failed"sv;
+ case ComputeTaskOutcome::Cancelled:
+ return "Cancelled"sv;
+ case ComputeTaskOutcome::NoResult:
+ return "NoResult"sv;
+ case ComputeTaskOutcome::Exipred:
+ return "Exipred"sv;
+ case ComputeTaskOutcome::BlobNotFound:
+ return "BlobNotFound"sv;
+ case ComputeTaskOutcome::Exception:
+ return "Exception"sv;
+ };
+ return "Unknown"sv;
+ }
+
+ virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) override
+ {
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ if (m_PendingTasks.empty())
+ {
+ if (m_CompletedTasks.empty())
+ {
+ // Nothing to do.
+ return {.Success = true};
+ }
+
+ UpstreamApplyCompleted CompletedTasks;
+ std::swap(CompletedTasks, m_CompletedTasks);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true};
+ }
+ }
+
+ try
+ {
+ CloudCacheSession ComputeSession(m_Client);
+
+ CloudCacheResult UpdatesResult = ComputeSession.GetComputeUpdates(m_ChannelId);
+ Log().debug("Get compute updates Bytes={} Duration={}s Result={}",
+ UpdatesResult.Bytes,
+ UpdatesResult.ElapsedSeconds,
+ UpdatesResult.Success);
+ Bytes += UpdatesResult.Bytes;
+ ElapsedSeconds += UpdatesResult.ElapsedSeconds;
+ if (!UpdatesResult.Success)
+ {
+ return {.Error{.ErrorCode = UpdatesResult.ErrorCode, .Reason = std::move(UpdatesResult.Reason)},
+ .Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds};
+ }
+
+ if (!UpdatesResult.Success)
+ {
+ return {.Error{.ErrorCode = -1, .Reason = "Failed get task updates"}, .Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds};
+ }
+
+ CbObject TaskStatus = LoadCompactBinaryObject(std::move(UpdatesResult.Response));
+
+ for (auto& It : TaskStatus["u"sv])
+ {
+ CbObjectView Status = It.AsObjectView();
+ IoHash TaskId = Status["h"sv].AsHash();
+ const ComputeTaskState State = (ComputeTaskState)Status["s"sv].AsInt32();
+ const ComputeTaskOutcome Outcome = (ComputeTaskOutcome)Status["o"sv].AsInt32();
+
+ Log().info("Task {} State={}", TaskId, ComputeTaskStateToString(State));
+
+ // Only completed tasks need to be processed
+ if (State != ComputeTaskState::Complete)
+ {
+ continue;
+ }
+
+ IoHash WorkerId{};
+ IoHash ActionId{};
+ UpstreamApplyType ApplyType{};
+
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ auto TaskIt = m_PendingTasks.find(TaskId);
+ if (TaskIt != m_PendingTasks.end())
+ {
+ WorkerId = TaskIt->second.WorkerDescriptor.GetHash();
+ ActionId = TaskIt->second.Action.GetHash();
+ ApplyType = TaskIt->second.Type;
+ m_PendingTasks.erase(TaskIt);
+ }
+ }
+
+ if (WorkerId == IoHash::Zero)
+ {
+ Log().warn("Task {} missing from pending tasks", TaskId);
+ continue;
+ }
+
+ std::map<std::string, uint64_t> Timepoints;
+ ProcessQueueTimings(Status["qs"sv].AsObjectView(), Timepoints);
+ ProcessExecuteTimings(Status["es"sv].AsObjectView(), Timepoints);
+
+ if (Outcome != ComputeTaskOutcome::Success)
+ {
+ const std::string_view Detail = Status["d"sv].AsString();
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ m_CompletedTasks[WorkerId][ActionId] = {
+ .Error{.ErrorCode = -1, .Reason = fmt::format("Task {} {}", ComputeTaskOutcomeToString(Outcome), Detail)},
+ .Timepoints = std::move(Timepoints)};
+ }
+ continue;
+ }
+
+ Timepoints["zen-complete-queue-added"] = DateTime::NowTicks();
+ ThreadPool.ScheduleWork([this,
+ ApplyType,
+ ResultHash = Status["r"sv].AsHash(),
+ Timepoints = std::move(Timepoints),
+ TaskId = std::move(TaskId),
+ WorkerId = std::move(WorkerId),
+ ActionId = std::move(ActionId)]() mutable {
+ Timepoints["zen-complete-queue-dispatched"] = DateTime::NowTicks();
+ GetUpstreamApplyResult Result = ProcessTaskStatus(ApplyType, ResultHash);
+ Timepoints["zen-complete-queue-complete"] = DateTime::NowTicks();
+ Result.Timepoints.merge(Timepoints);
+
+ Log().debug("Task Processed {} Files={} Attachments={} ExitCode={}",
+ TaskId,
+ Result.OutputFiles.size(),
+ Result.OutputPackage.GetAttachments().size(),
+ Result.Error.ErrorCode);
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ m_CompletedTasks[WorkerId][ActionId] = std::move(Result);
+ }
+ });
+ }
+
+ {
+ std::scoped_lock Lock(m_TaskMutex);
+ if (m_CompletedTasks.empty())
+ {
+ // Nothing to do.
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Success = true};
+ }
+ UpstreamApplyCompleted CompletedTasks;
+ std::swap(CompletedTasks, m_CompletedTasks);
+ return {.Bytes = Bytes, .ElapsedSeconds = ElapsedSeconds, .Completed = std::move(CompletedTasks), .Success = true};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_HealthOk = false;
+ return {
+ .Error{.ErrorCode = -1, .Reason = Err.what()},
+ .Bytes = Bytes,
+ .ElapsedSeconds = ElapsedSeconds,
+ };
+ }
+ }
+
+ virtual UpstreamApplyEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ CidStore& m_CidStore;
+ AuthMgr& m_AuthMgr;
+ std::string m_DisplayName;
+ RefPtr<CloudCacheClient> m_Client;
+ RefPtr<CloudCacheClient> m_StorageClient;
+ UpstreamApplyEndpointStats m_Stats;
+ std::atomic_bool m_HealthOk{false};
+ std::string m_ChannelId;
+
+ std::mutex m_TaskMutex;
+ std::unordered_map<IoHash, UpstreamApplyRecord> m_PendingTasks;
+ UpstreamApplyCompleted m_CompletedTasks;
+
+ struct UpstreamData
+ {
+ std::map<IoHash, IoBuffer> Blobs;
+ std::map<IoHash, CbObject> Objects;
+ std::set<IoHash> CasIds;
+ std::set<IoHash> Cids;
+ IoHash TaskId;
+ IoHash RequirementsId;
+ };
+
+ struct UpstreamDirectory
+ {
+ std::filesystem::path Path;
+ std::map<std::string, UpstreamDirectory> Directories;
+ std::set<std::string> Files;
+ };
+
+ static void ProcessQueueTimings(CbObjectView QueueStats, std::map<std::string, uint64_t>& Timepoints)
+ {
+ uint64_t Ticks = QueueStats["t"sv].AsDateTimeTicks();
+ if (Ticks == 0)
+ {
+ return;
+ }
+
+ // Scope is an array of miliseconds after start time
+ // TODO: cleanup
+ Timepoints["horde-queue-added"] = Ticks;
+ int Index = 0;
+ for (auto& Item : QueueStats["s"sv].AsArrayView())
+ {
+ Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond;
+ switch (Index)
+ {
+ case 0:
+ Timepoints["horde-queue-dispatched"] = Ticks;
+ break;
+ case 1:
+ Timepoints["horde-queue-complete"] = Ticks;
+ break;
+ }
+ Index++;
+ }
+ }
+
+ static void ProcessExecuteTimings(CbObjectView ExecutionStats, std::map<std::string, uint64_t>& Timepoints)
+ {
+ uint64_t Ticks = ExecutionStats["t"sv].AsDateTimeTicks();
+ if (Ticks == 0)
+ {
+ return;
+ }
+
+ // Scope is an array of miliseconds after start time
+ // TODO: cleanup
+ Timepoints["horde-execution-start"] = Ticks;
+ int Index = 0;
+ for (auto& Item : ExecutionStats["s"sv].AsArrayView())
+ {
+ Ticks += Item.AsInt32() * TimeSpan::TicksPerMillisecond;
+ switch (Index)
+ {
+ case 0:
+ Timepoints["horde-execution-download-ref"] = Ticks;
+ break;
+ case 1:
+ Timepoints["horde-execution-download-input"] = Ticks;
+ break;
+ case 2:
+ Timepoints["horde-execution-execute"] = Ticks;
+ break;
+ case 3:
+ Timepoints["horde-execution-upload-log"] = Ticks;
+ break;
+ case 4:
+ Timepoints["horde-execution-upload-output"] = Ticks;
+ break;
+ case 5:
+ Timepoints["horde-execution-upload-ref"] = Ticks;
+ break;
+ }
+ Index++;
+ }
+ }
+
+ [[nodiscard]] GetUpstreamApplyResult ProcessTaskStatus(const UpstreamApplyType ApplyType, const IoHash& ResultHash)
+ {
+ try
+ {
+ CloudCacheSession Session(m_StorageClient);
+
+ GetUpstreamApplyResult ApplyResult{};
+
+ IoHash StdOutHash;
+ IoHash StdErrHash;
+ IoHash OutputHash;
+
+ std::map<IoHash, IoBuffer> BinaryData;
+
+ {
+ CloudCacheResult ObjectRefResult =
+ Session.GetRef(Session.Client().DefaultBlobStoreNamespace(), "responses"sv, ResultHash, ZenContentType::kCbObject);
+ Log().debug("Get ref {} Bytes={} Duration={}s Result={}",
+ ResultHash,
+ ObjectRefResult.Bytes,
+ ObjectRefResult.ElapsedSeconds,
+ ObjectRefResult.Success);
+ ApplyResult.Bytes += ObjectRefResult.Bytes;
+ ApplyResult.ElapsedSeconds += ObjectRefResult.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-get-ref"] = DateTime::NowTicks();
+
+ if (!ObjectRefResult.Success)
+ {
+ ApplyResult.Error.Reason = "Failed to get result object data";
+ return ApplyResult;
+ }
+
+ CbObject ResultObject = LoadCompactBinaryObject(ObjectRefResult.Response);
+ ApplyResult.Error.ErrorCode = ResultObject["e"sv].AsInt32();
+ StdOutHash = ResultObject["so"sv].AsBinaryAttachment();
+ StdErrHash = ResultObject["se"sv].AsBinaryAttachment();
+ OutputHash = ResultObject["o"sv].AsObjectAttachment();
+ }
+
+ {
+ std::set<IoHash> NeededData;
+ if (OutputHash != IoHash::Zero)
+ {
+ GetObjectReferencesResult ObjectReferenceResult =
+ Session.GetObjectReferences(Session.Client().DefaultBlobStoreNamespace(), OutputHash);
+ Log().debug("Get object references {} References={} Bytes={} Duration={}s Result={}",
+ ResultHash,
+ ObjectReferenceResult.References.size(),
+ ObjectReferenceResult.Bytes,
+ ObjectReferenceResult.ElapsedSeconds,
+ ObjectReferenceResult.Success);
+ ApplyResult.Bytes += ObjectReferenceResult.Bytes;
+ ApplyResult.ElapsedSeconds += ObjectReferenceResult.ElapsedSeconds;
+ ApplyResult.Timepoints["zen-storage-get-object-references"] = DateTime::NowTicks();
+
+ if (!ObjectReferenceResult.Success)
+ {
+ ApplyResult.Error.Reason = "Failed to get result object references";
+ return ApplyResult;
+ }
+
+ NeededData = std::move(ObjectReferenceResult.References);
+ }
+
+ NeededData.insert(OutputHash);
+ NeededData.insert(StdOutHash);
+ NeededData.insert(StdErrHash);
+
+ for (const auto& Hash : NeededData)
+ {
+ if (Hash == IoHash::Zero)
+ {
+ continue;
+ }
+ CloudCacheResult BlobResult = Session.GetBlob(Session.Client().DefaultBlobStoreNamespace(), Hash);
+ Log().debug("Get blob {} Bytes={} Duration={}s Result={}",
+ Hash,
+ BlobResult.Bytes,
+ BlobResult.ElapsedSeconds,
+ BlobResult.Success);
+ ApplyResult.Bytes += BlobResult.Bytes;
+ ApplyResult.ElapsedSeconds += BlobResult.ElapsedSeconds;
+ if (!BlobResult.Success)
+ {
+ ApplyResult.Error.Reason = "Failed to get blob";
+ return ApplyResult;
+ }
+ BinaryData[Hash] = std::move(BlobResult.Response);
+ }
+ ApplyResult.Timepoints["zen-storage-get-blobs"] = DateTime::NowTicks();
+ }
+
+ ApplyResult.StdOut = StdOutHash != IoHash::Zero
+ ? std::string((const char*)BinaryData[StdOutHash].GetData(), BinaryData[StdOutHash].GetSize())
+ : "";
+ ApplyResult.StdErr = StdErrHash != IoHash::Zero
+ ? std::string((const char*)BinaryData[StdErrHash].GetData(), BinaryData[StdErrHash].GetSize())
+ : "";
+
+ if (OutputHash == IoHash::Zero)
+ {
+ ApplyResult.Error.Reason = "Task completed with no output object";
+ return ApplyResult;
+ }
+
+ CbObject OutputObject = LoadCompactBinaryObject(BinaryData[OutputHash]);
+
+ switch (ApplyType)
+ {
+ case UpstreamApplyType::Simple:
+ {
+ ResolveMerkleTreeDirectory(""sv, OutputHash, BinaryData, ApplyResult.OutputFiles);
+ for (const auto& Pair : BinaryData)
+ {
+ ApplyResult.FileData[Pair.first] = std::move(BinaryData.at(Pair.first));
+ }
+
+ ApplyResult.Success = ApplyResult.Error.ErrorCode == 0;
+ return ApplyResult;
+ }
+ break;
+ case UpstreamApplyType::Asset:
+ {
+ if (ApplyResult.Error.ErrorCode != 0)
+ {
+ ApplyResult.Error.Reason = "Task completed with errors";
+ return ApplyResult;
+ }
+
+ // Get build.output
+ IoHash BuildOutputId;
+ IoBuffer BuildOutput;
+ for (auto& It : OutputObject["f"sv])
+ {
+ const CbObjectView FileObject = It.AsObjectView();
+ if (FileObject["n"sv].AsString() == "Build.output"sv)
+ {
+ BuildOutputId = FileObject["h"sv].AsBinaryAttachment();
+ BuildOutput = BinaryData[BuildOutputId];
+ break;
+ }
+ }
+
+ if (BuildOutput.GetSize() == 0)
+ {
+ ApplyResult.Error.Reason = "Build.output file not found in task results";
+ return ApplyResult;
+ }
+
+ // Get Output directory node
+ IoBuffer OutputDirectoryTree;
+ for (auto& It : OutputObject["d"sv])
+ {
+ const CbObjectView DirectoryObject = It.AsObjectView();
+ if (DirectoryObject["n"sv].AsString() == "Outputs"sv)
+ {
+ OutputDirectoryTree = BinaryData[DirectoryObject["h"sv].AsObjectAttachment()];
+ break;
+ }
+ }
+
+ if (OutputDirectoryTree.GetSize() == 0)
+ {
+ ApplyResult.Error.Reason = "Outputs directory not found in task results";
+ return ApplyResult;
+ }
+
+ // load build.output as CbObject
+
+ // Move Outputs from Horde to CbPackage
+
+ std::unordered_map<IoHash, IoHash> CidToCompressedId;
+ CbPackage OutputPackage;
+ CbObject OutputDirectoryTreeObject = LoadCompactBinaryObject(OutputDirectoryTree);
+
+ for (auto& It : OutputDirectoryTreeObject["f"sv])
+ {
+ CbObjectView FileObject = It.AsObjectView();
+ // Name is the uncompressed hash
+ IoHash DecompressedId = IoHash::FromHexString(FileObject["n"sv].AsString());
+ // Hash is the compressed data hash, and how it is stored in Horde
+ IoHash CompressedId = FileObject["h"sv].AsBinaryAttachment();
+
+ if (!BinaryData.contains(CompressedId))
+ {
+ Log().warn("Object attachment chunk not retrieved from Horde {}", CompressedId);
+ ApplyResult.Error.Reason = "Object attachment chunk not retrieved from Horde";
+ return ApplyResult;
+ }
+ CidToCompressedId[DecompressedId] = CompressedId;
+ }
+
+ // Iterate attachments, verify all chunks exist, and add to CbPackage
+ bool AnyErrors = false;
+ CbObject BuildOutputObject = LoadCompactBinaryObject(BuildOutput);
+ BuildOutputObject.IterateAttachments([&](CbFieldView Field) {
+ const IoHash DecompressedId = Field.AsHash();
+ if (!CidToCompressedId.contains(DecompressedId))
+ {
+ Log().warn("Attachment not found {}", DecompressedId);
+ AnyErrors = true;
+ return;
+ }
+ const IoHash& CompressedId = CidToCompressedId.at(DecompressedId);
+
+ if (!BinaryData.contains(CompressedId))
+ {
+ Log().warn("Missing output {} compressed {} uncompressed", CompressedId, DecompressedId);
+ AnyErrors = true;
+ return;
+ }
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer AttachmentBuffer =
+ CompressedBuffer::FromCompressed(SharedBuffer(BinaryData[CompressedId]), RawHash, RawSize);
+
+ if (!AttachmentBuffer || RawHash != DecompressedId)
+ {
+ Log().warn(
+ "Invalid output encountered (not valid CompressedBuffer format) {} compressed {} uncompressed",
+ CompressedId,
+ DecompressedId);
+ AnyErrors = true;
+ return;
+ }
+
+ ApplyResult.TotalAttachmentBytes += AttachmentBuffer.GetCompressedSize();
+ ApplyResult.TotalRawAttachmentBytes += RawSize;
+
+ CbAttachment Attachment(AttachmentBuffer, DecompressedId);
+ OutputPackage.AddAttachment(Attachment);
+ });
+
+ if (AnyErrors)
+ {
+ ApplyResult.Error.Reason = "Failed to get result object attachment data";
+ return ApplyResult;
+ }
+
+ OutputPackage.SetObject(BuildOutputObject);
+ ApplyResult.OutputPackage = std::move(OutputPackage);
+
+ ApplyResult.Success = ApplyResult.Error.ErrorCode == 0;
+ return ApplyResult;
+ }
+ break;
+ }
+
+ ApplyResult.Error.Reason = "Unknown apply type";
+ return ApplyResult;
+ }
+ catch (std::exception& Err)
+ {
+ return {.Error{.ErrorCode = -1, .Reason = Err.what()}};
+ }
+ }
+
+ [[nodiscard]] bool ProcessApplyKey(const UpstreamApplyRecord& ApplyRecord, UpstreamData& Data)
+ {
+ std::string ExecutablePath;
+ std::string WorkingDirectory;
+ std::vector<std::string> Arguments;
+ std::map<std::string, std::string> Environment;
+ std::set<std::filesystem::path> InputFiles;
+ std::set<std::string> Outputs;
+ std::map<std::filesystem::path, IoHash> InputFileHashes;
+
+ ExecutablePath = ApplyRecord.WorkerDescriptor["path"sv].AsString();
+ if (ExecutablePath.empty())
+ {
+ Log().warn("process apply upstream FAILED, '{}', path missing from worker descriptor",
+ ApplyRecord.WorkerDescriptor.GetHash());
+ return false;
+ }
+
+ WorkingDirectory = ApplyRecord.WorkerDescriptor["workdir"sv].AsString();
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["executables"sv])
+ {
+ CbObjectView FileEntry = It.AsObjectView();
+ if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds))
+ {
+ return false;
+ }
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["files"sv])
+ {
+ CbObjectView FileEntry = It.AsObjectView();
+ if (!ProcessFileEntry(FileEntry, InputFiles, InputFileHashes, Data.CasIds))
+ {
+ return false;
+ }
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["dirs"sv])
+ {
+ std::string_view Directory = It.AsString();
+ std::string DummyFile = fmt::format("{}/.zen_empty_file", Directory);
+ InputFiles.insert(DummyFile);
+ Data.Blobs[EmptyBufferId] = EmptyBuffer;
+ InputFileHashes[DummyFile] = EmptyBufferId;
+ }
+
+ if (!WorkingDirectory.empty())
+ {
+ std::string DummyFile = fmt::format("{}/.zen_empty_file", WorkingDirectory);
+ InputFiles.insert(DummyFile);
+ Data.Blobs[EmptyBufferId] = EmptyBuffer;
+ InputFileHashes[DummyFile] = EmptyBufferId;
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["environment"sv])
+ {
+ std::string_view Env = It.AsString();
+ auto Index = Env.find('=');
+ if (Index == std::string_view::npos)
+ {
+ Log().warn("process apply upstream FAILED, environment '{}' malformed", Env);
+ return false;
+ }
+
+ Environment[std::string(Env.substr(0, Index))] = Env.substr(Index + 1);
+ }
+
+ switch (ApplyRecord.Type)
+ {
+ case UpstreamApplyType::Simple:
+ {
+ for (auto& It : ApplyRecord.WorkerDescriptor["arguments"sv])
+ {
+ Arguments.push_back(std::string(It.AsString()));
+ }
+
+ for (auto& It : ApplyRecord.WorkerDescriptor["outputs"sv])
+ {
+ Outputs.insert(std::string(It.AsString()));
+ }
+ }
+ break;
+ case UpstreamApplyType::Asset:
+ {
+ static const std::filesystem::path BuildActionPath = "Build.action"sv;
+ static const std::filesystem::path InputPath = "Inputs"sv;
+ const IoHash ActionId = ApplyRecord.Action.GetHash();
+
+ Arguments.push_back("-Build=build.action");
+ Outputs.insert("Build.output");
+ Outputs.insert("Outputs");
+
+ InputFiles.insert(BuildActionPath);
+ InputFileHashes[BuildActionPath] = ActionId;
+ Data.Blobs[ActionId] = IoBufferBuilder::MakeCloneFromMemory(ApplyRecord.Action.GetBuffer().GetData(),
+ ApplyRecord.Action.GetBuffer().GetSize());
+
+ bool AnyErrors = false;
+ ApplyRecord.Action.IterateAttachments([&](CbFieldView Field) {
+ const IoHash Cid = Field.AsHash();
+ const std::filesystem::path FilePath = {InputPath / Cid.ToHexString()};
+
+ if (!m_CidStore.ContainsChunk(Cid))
+ {
+ Log().warn("process apply upstream FAILED, input CID chunk '{}' missing", Cid);
+ AnyErrors = true;
+ return;
+ }
+
+ if (InputFiles.contains(FilePath))
+ {
+ return;
+ }
+
+ InputFiles.insert(FilePath);
+ InputFileHashes[FilePath] = Cid;
+ Data.Cids.insert(Cid);
+ });
+
+ if (AnyErrors)
+ {
+ return false;
+ }
+ }
+ break;
+ }
+
+ const UpstreamDirectory RootDirectory = BuildDirectoryTree(InputFiles);
+
+ CbObject Sandbox = BuildMerkleTreeDirectory(RootDirectory, InputFileHashes, Data.Cids, Data.Objects);
+ const IoHash SandboxHash = Sandbox.GetHash();
+ Data.Objects[SandboxHash] = std::move(Sandbox);
+
+ {
+ std::string_view HostPlatform = ApplyRecord.WorkerDescriptor["host"sv].AsString();
+ if (HostPlatform.empty())
+ {
+ Log().warn("process apply upstream FAILED, 'host' platform not provided");
+ return false;
+ }
+
+ int32_t LogicalCores = ApplyRecord.WorkerDescriptor["cores"sv].AsInt32();
+ int64_t Memory = ApplyRecord.WorkerDescriptor["memory"sv].AsInt64();
+ bool Exclusive = ApplyRecord.WorkerDescriptor["exclusive"sv].AsBool();
+
+ std::string Condition = fmt::format("Platform == '{}'", HostPlatform);
+ if (HostPlatform == "Win64")
+ {
+ // TODO
+ // Condition += " && Pool == 'Win-RemoteExec'";
+ }
+
+ std::map<std::string_view, int64_t> Resources;
+ if (LogicalCores > 0)
+ {
+ Resources["LogicalCores"sv] = LogicalCores;
+ }
+ if (Memory > 0)
+ {
+ Resources["RAM"sv] = std::max(Memory / 1024LL / 1024LL / 1024LL, 1LL);
+ }
+
+ CbObject Requirements = BuildRequirements(Condition, Resources, Exclusive);
+ const IoHash RequirementsId = Requirements.GetHash();
+ Data.Objects[RequirementsId] = std::move(Requirements);
+ Data.RequirementsId = RequirementsId;
+ }
+
+ CbObject Task = BuildTask(ExecutablePath, Arguments, Environment, WorkingDirectory, SandboxHash, Data.RequirementsId, Outputs);
+
+ const IoHash TaskId = Task.GetHash();
+ Data.Objects[TaskId] = std::move(Task);
+ Data.TaskId = TaskId;
+
+ return true;
+ }
+
+ [[nodiscard]] bool ProcessFileEntry(const CbObjectView& FileEntry,
+ std::set<std::filesystem::path>& InputFiles,
+ std::map<std::filesystem::path, IoHash>& InputFileHashes,
+ std::set<IoHash>& CasIds)
+ {
+ const std::filesystem::path FilePath = FileEntry["name"sv].AsString();
+ const IoHash ChunkId = FileEntry["hash"sv].AsHash();
+ const uint64_t Size = FileEntry["size"sv].AsUInt64();
+
+ if (!m_CidStore.ContainsChunk(ChunkId))
+ {
+ Log().warn("process apply upstream FAILED, worker CAS chunk '{}' missing", ChunkId);
+ return false;
+ }
+
+ if (InputFiles.contains(FilePath))
+ {
+ Log().warn("process apply upstream FAILED, worker CAS chunk '{}' size: {} duplicate filename {}", ChunkId, Size, FilePath);
+ return false;
+ }
+
+ InputFiles.insert(FilePath);
+ InputFileHashes[FilePath] = ChunkId;
+ CasIds.insert(ChunkId);
+ return true;
+ }
+
+ [[nodiscard]] UpstreamDirectory BuildDirectoryTree(const std::set<std::filesystem::path>& InputFiles)
+ {
+ static const std::filesystem::path RootPath;
+ std::map<std::filesystem::path, UpstreamDirectory*> AllDirectories;
+ UpstreamDirectory RootDirectory = {.Path = RootPath};
+
+ AllDirectories[RootPath] = &RootDirectory;
+
+ // Build tree from flat list
+ for (const auto& Path : InputFiles)
+ {
+ if (Path.has_parent_path())
+ {
+ if (!AllDirectories.contains(Path.parent_path()))
+ {
+ std::stack<std::string> PathSplit;
+ {
+ std::filesystem::path ParentPath = Path.parent_path();
+ PathSplit.push(ParentPath.filename().string());
+ while (ParentPath.has_parent_path())
+ {
+ ParentPath = ParentPath.parent_path();
+ PathSplit.push(ParentPath.filename().string());
+ }
+ }
+ UpstreamDirectory* ParentPtr = &RootDirectory;
+ while (!PathSplit.empty())
+ {
+ if (!ParentPtr->Directories.contains(PathSplit.top()))
+ {
+ std::filesystem::path NewParentPath = {ParentPtr->Path / PathSplit.top()};
+ ParentPtr->Directories[PathSplit.top()] = {.Path = NewParentPath};
+ AllDirectories[NewParentPath] = &ParentPtr->Directories[PathSplit.top()];
+ }
+ ParentPtr = &ParentPtr->Directories[PathSplit.top()];
+ PathSplit.pop();
+ }
+ }
+
+ AllDirectories[Path.parent_path()]->Files.insert(Path.filename().string());
+ }
+ else
+ {
+ RootDirectory.Files.insert(Path.filename().string());
+ }
+ }
+
+ return RootDirectory;
+ }
+
+ [[nodiscard]] CbObject BuildMerkleTreeDirectory(const UpstreamDirectory& RootDirectory,
+ const std::map<std::filesystem::path, IoHash>& InputFileHashes,
+ const std::set<IoHash>& Cids,
+ std::map<IoHash, CbObject>& Objects)
+ {
+ CbObjectWriter DirectoryTreeWriter;
+
+ if (!RootDirectory.Files.empty())
+ {
+ DirectoryTreeWriter.BeginArray("f"sv);
+ for (const auto& File : RootDirectory.Files)
+ {
+ const std::filesystem::path FilePath = {RootDirectory.Path / File};
+ const IoHash& FileHash = InputFileHashes.at(FilePath);
+ const bool Compressed = Cids.contains(FileHash);
+ DirectoryTreeWriter.BeginObject();
+ DirectoryTreeWriter.AddString("n"sv, File);
+ DirectoryTreeWriter.AddBinaryAttachment("h"sv, FileHash);
+ DirectoryTreeWriter.AddBool("c"sv, Compressed);
+ DirectoryTreeWriter.EndObject();
+ }
+ DirectoryTreeWriter.EndArray();
+ }
+
+ if (!RootDirectory.Directories.empty())
+ {
+ DirectoryTreeWriter.BeginArray("d"sv);
+ for (const auto& Item : RootDirectory.Directories)
+ {
+ CbObject Directory = BuildMerkleTreeDirectory(Item.second, InputFileHashes, Cids, Objects);
+ const IoHash DirectoryHash = Directory.GetHash();
+ Objects[DirectoryHash] = std::move(Directory);
+
+ DirectoryTreeWriter.BeginObject();
+ DirectoryTreeWriter.AddString("n"sv, Item.first);
+ DirectoryTreeWriter.AddObjectAttachment("h"sv, DirectoryHash);
+ DirectoryTreeWriter.EndObject();
+ }
+ DirectoryTreeWriter.EndArray();
+ }
+
+ return DirectoryTreeWriter.Save();
+ }
+
+ void ResolveMerkleTreeDirectory(const std::filesystem::path& ParentDirectory,
+ const IoHash& DirectoryHash,
+ const std::map<IoHash, IoBuffer>& Objects,
+ std::map<std::filesystem::path, IoHash>& OutputFiles)
+ {
+ CbObject Directory = LoadCompactBinaryObject(Objects.at(DirectoryHash));
+
+ for (auto& It : Directory["f"sv])
+ {
+ const CbObjectView FileObject = It.AsObjectView();
+ const std::filesystem::path Path = ParentDirectory / FileObject["n"sv].AsString();
+
+ OutputFiles[Path] = FileObject["h"sv].AsBinaryAttachment();
+ }
+
+ for (auto& It : Directory["d"sv])
+ {
+ const CbObjectView DirectoryObject = It.AsObjectView();
+
+ ResolveMerkleTreeDirectory(ParentDirectory / DirectoryObject["n"sv].AsString(),
+ DirectoryObject["h"sv].AsObjectAttachment(),
+ Objects,
+ OutputFiles);
+ }
+ }
+
+ [[nodiscard]] CbObject BuildRequirements(const std::string_view Condition,
+ const std::map<std::string_view, int64_t>& Resources,
+ const bool Exclusive)
+ {
+ CbObjectWriter Writer;
+ Writer.AddString("c", Condition);
+ if (!Resources.empty())
+ {
+ Writer.BeginArray("r");
+ for (const auto& Resource : Resources)
+ {
+ Writer.BeginArray();
+ Writer.AddString(Resource.first);
+ Writer.AddInteger(Resource.second);
+ Writer.EndArray();
+ }
+ Writer.EndArray();
+ }
+ Writer.AddBool("e", Exclusive);
+ return Writer.Save();
+ }
+
+ [[nodiscard]] CbObject BuildTask(const std::string_view Executable,
+ const std::vector<std::string>& Arguments,
+ const std::map<std::string, std::string>& Environment,
+ const std::string_view WorkingDirectory,
+ const IoHash& SandboxHash,
+ const IoHash& RequirementsId,
+ const std::set<std::string>& Outputs)
+ {
+ CbObjectWriter TaskWriter;
+ TaskWriter.AddString("e"sv, Executable);
+
+ if (!Arguments.empty())
+ {
+ TaskWriter.BeginArray("a"sv);
+ for (const auto& Argument : Arguments)
+ {
+ TaskWriter.AddString(Argument);
+ }
+ TaskWriter.EndArray();
+ }
+
+ if (!Environment.empty())
+ {
+ TaskWriter.BeginArray("v"sv);
+ for (const auto& Env : Environment)
+ {
+ TaskWriter.BeginArray();
+ TaskWriter.AddString(Env.first);
+ TaskWriter.AddString(Env.second);
+ TaskWriter.EndArray();
+ }
+ TaskWriter.EndArray();
+ }
+
+ if (!WorkingDirectory.empty())
+ {
+ TaskWriter.AddString("w"sv, WorkingDirectory);
+ }
+
+ TaskWriter.AddObjectAttachment("s"sv, SandboxHash);
+ TaskWriter.AddObjectAttachment("r"sv, RequirementsId);
+
+ // Outputs
+ if (!Outputs.empty())
+ {
+ TaskWriter.BeginArray("o"sv);
+ for (const auto& Output : Outputs)
+ {
+ TaskWriter.AddString(Output);
+ }
+ TaskWriter.EndArray();
+ }
+
+ return TaskWriter.Save();
+ }
+ };
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+std::unique_ptr<UpstreamApplyEndpoint>
+UpstreamApplyEndpoint::CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ CidStore& CidStore,
+ AuthMgr& Mgr)
+{
+ return std::make_unique<detail::HordeUpstreamApplyEndpoint>(ComputeOptions,
+ ComputeAuthConfig,
+ StorageOptions,
+ StorageAuthConfig,
+ CidStore,
+ Mgr);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/jupiter.cpp b/src/zenserver/upstream/jupiter.cpp
new file mode 100644
index 000000000..dbb185bec
--- /dev/null
+++ b/src/zenserver/upstream/jupiter.cpp
@@ -0,0 +1,965 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "jupiter.h"
+
+#include "diag/formatters.h"
+#include "diag/logging.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compositebuffer.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+#include <fmt/format.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "Crypt32.lib")
+# pragma comment(lib, "Wldap32.lib")
+#endif
+
+#include <json11.hpp>
+
+using namespace std::literals;
+
+namespace zen {
+
+namespace detail {
+ struct CloudCacheSessionState
+ {
+ CloudCacheSessionState(CloudCacheClient& Client) : m_Client(Client) {}
+
+ const CloudCacheAccessToken& GetAccessToken(bool RefreshToken)
+ {
+ if (RefreshToken)
+ {
+ m_AccessToken = m_Client.AcquireAccessToken();
+ }
+
+ return m_AccessToken;
+ }
+
+ cpr::Session& GetSession() { return m_Session; }
+
+ void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout)
+ {
+ m_Session.SetBody({});
+ m_Session.SetHeader({});
+ m_Session.SetConnectTimeout(ConnectTimeout);
+ m_Session.SetTimeout(Timeout);
+ }
+
+ private:
+ friend class zen::CloudCacheClient;
+
+ CloudCacheClient& m_Client;
+ CloudCacheAccessToken m_AccessToken;
+ cpr::Session m_Session;
+ };
+
+} // namespace detail
+
+CloudCacheSession::CloudCacheSession(CloudCacheClient* CacheClient) : m_Log(CacheClient->Logger()), m_CacheClient(CacheClient)
+{
+ m_SessionState = m_CacheClient->AllocSessionState();
+}
+
+CloudCacheSession::~CloudCacheSession()
+{
+ m_CacheClient->FreeSessionState(m_SessionState);
+}
+
+CloudCacheResult
+CloudCacheSession::Authenticate()
+{
+ const bool RefreshToken = true;
+ const CloudCacheAccessToken& AccessToken = GetAccessToken(RefreshToken);
+
+ return {.Success = AccessToken.IsValid()};
+}
+
+CloudCacheResult
+CloudCacheSession::GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType)
+{
+ const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", ContentType}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetBlob(std::string_view Namespace, const IoHash& Key)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/octet-stream"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer =
+ Success && Response.text.size() > 0 ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetCompressedBlob(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-comp"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash)
+{
+ ZEN_TRACE_CPU("HordeClient::GetInlineBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-jupiter-inline"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+ if (auto It = Response.header.find("X-Jupiter-InlinePayloadHash"); It != Response.header.end())
+ {
+ const std::string& PayloadHashHeader = It->second;
+ if (PayloadHashHeader.length() == IoHash::StringLength)
+ {
+ OutPayloadHash = IoHash::FromHexString(PayloadHashHeader);
+ }
+ }
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+CloudCacheResult
+CloudCacheSession::GetObject(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetObject");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+PutRefResult
+CloudCacheSession::PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType)
+{
+ ZEN_TRACE_CPU("HordeClient::PutRef");
+
+ IoHash Hash = IoHash::HashBuffer(Ref.Data(), Ref.Size());
+
+ const std::string ContentType = RefType == ZenContentType::kCbObject ? "application/x-ue-cb" : "application/octet-stream";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(
+ cpr::Header{{"Authorization", AccessToken.Value}, {"X-Jupiter-IoHash", Hash.ToHexString()}, {"Content-Type", ContentType}});
+ Session.SetBody(cpr::Body{(const char*)Ref.Data(), Ref.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ PutRefResult Result;
+ Result.ErrorCode = static_cast<int32_t>(Response.error.code);
+ Result.Reason = std::move(Response.error.message);
+ return Result;
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ PutRefResult Result;
+ Result.ErrorCode = 401;
+ Result.Reason = "Invalid access token"sv;
+ return Result;
+ }
+
+ PutRefResult Result;
+ Result.Success = (Response.status_code == 200 || Response.status_code == 201);
+ Result.Bytes = Response.uploaded_bytes;
+ Result.ElapsedSeconds = Response.elapsed;
+
+ if (Result.Success)
+ {
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ if (JsonError.empty())
+ {
+ json11::Json::array Needs = Json["needs"].array_items();
+ for (const auto& Need : Needs)
+ {
+ Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value()));
+ }
+ }
+ }
+
+ return Result;
+}
+
+FinalizeRefResult
+CloudCacheSession::FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHash)
+{
+ ZEN_TRACE_CPU("HordeClient::FinalizeRef");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString() << "/finalize/"
+ << RefHash.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value},
+ {"X-Jupiter-IoHash", RefHash.ToHexString()},
+ {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ FinalizeRefResult Result;
+ Result.ErrorCode = static_cast<int32_t>(Response.error.code);
+ Result.Reason = std::move(Response.error.message);
+ return Result;
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ FinalizeRefResult Result;
+ Result.ErrorCode = 401;
+ Result.Reason = "Invalid access token"sv;
+ return Result;
+ }
+
+ FinalizeRefResult Result;
+ Result.Success = (Response.status_code == 200 || Response.status_code == 201);
+ Result.Bytes = Response.uploaded_bytes;
+ Result.ElapsedSeconds = Response.elapsed;
+
+ if (Result.Success)
+ {
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+ if (JsonError.empty())
+ {
+ json11::Json::array Needs = Json["needs"].array_items();
+ for (const auto& Need : Needs)
+ {
+ Result.Needs.emplace_back(IoHash::FromHexString(Need.string_value()));
+ }
+ }
+ }
+
+ return Result;
+}
+
+CloudCacheResult
+CloudCacheSession::PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob)
+{
+ ZEN_TRACE_CPU("HordeClient::PutBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/octet-stream"}});
+ Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob)
+{
+ ZEN_TRACE_CPU("HordeClient::PutCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}});
+ Session.SetBody(cpr::Body{(const char*)Blob.Data(), Blob.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Payload)
+{
+ ZEN_TRACE_CPU("HordeClient::PutCompressedBlob");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compressed-blobs/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-comp"}});
+ uint64_t SizeLeft = Payload.GetSize();
+ CompositeBuffer::Iterator BufferIt = Payload.GetIterator(0);
+ auto ReadCallback = [&Payload, &BufferIt, &SizeLeft](char* buffer, size_t& size, intptr_t) {
+ size = Min<size_t>(size, SizeLeft);
+ MutableMemoryView Data(buffer, size);
+ Payload.CopyTo(Data, BufferIt);
+ SizeLeft -= size;
+ return true;
+ };
+ Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(SizeLeft), ReadCallback));
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object)
+{
+ ZEN_TRACE_CPU("HordeClient::PutObject");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{(const char*)Object.Data(), Object.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Success = (Response.status_code == 200 || Response.status_code == 201)};
+}
+
+CloudCacheResult
+CloudCacheSession::RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::RefExists");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/refs/" << Namespace << "/" << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Head();
+ ZEN_DEBUG("HEAD {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+GetObjectReferencesResult
+CloudCacheSession::GetObjectReferences(std::string_view Namespace, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::GetObjectReferences");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/objects/" << Namespace << "/" << Key.ToHexString() << "/references";
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}};
+ }
+
+ GetObjectReferencesResult Result{
+ CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}};
+
+ if (Result.Success)
+ {
+ IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ const CbObject ReferencesResponse = LoadCompactBinaryObject(Buffer);
+ for (auto& Item : ReferencesResponse["references"sv])
+ {
+ Result.References.insert(Item.AsHash());
+ }
+ }
+
+ return Result;
+}
+
+CloudCacheResult
+CloudCacheSession::BlobExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "blobs"sv, Key);
+}
+
+CloudCacheResult
+CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "compressed-blobs"sv, Key);
+}
+
+CloudCacheResult
+CloudCacheSession::ObjectExists(std::string_view Namespace, const IoHash& Key)
+{
+ return CacheTypeExists(Namespace, "objects"sv, Key);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "blobs"sv, Keys);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "compressed-blobs"sv, Keys);
+}
+
+CloudCacheExistsResult
+CloudCacheSession::ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys)
+{
+ return CacheTypeExists(Namespace, "objects"sv, Keys);
+}
+
+CloudCacheResult
+CloudCacheSession::PostComputeTasks(IoBuffer TasksData)
+{
+ ZEN_TRACE_CPU("HordeClient::PostComputeTasks");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Content-Type", "application/x-ue-cb"}});
+ Session.SetBody(cpr::Body{(const char*)TasksData.Data(), TasksData.Size()});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+CloudCacheResult
+CloudCacheSession::GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds)
+{
+ ZEN_TRACE_CPU("HordeClient::GetComputeUpdates");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/compute/" << m_CacheClient->ComputeCluster() << "/updates/" << ChannelId
+ << "?wait=" << WaitSeconds;
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+std::vector<IoHash>
+CloudCacheSession::Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl();
+ Uri << "/api/v1/s/" << Namespace;
+
+ ZEN_UNUSED(BucketId, ChunkHashes);
+
+ return {};
+}
+
+cpr::Session&
+CloudCacheSession::GetSession()
+{
+ return m_SessionState->GetSession();
+}
+
+CloudCacheAccessToken
+CloudCacheSession::GetAccessToken(bool RefreshToken)
+{
+ return m_SessionState->GetAccessToken(RefreshToken);
+}
+
+bool
+CloudCacheSession::VerifyAccessToken(long StatusCode)
+{
+ return StatusCode != 401;
+}
+
+CloudCacheResult
+CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key)
+{
+ ZEN_TRACE_CPU("HordeClient::CacheTypeExists");
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/" << Key.ToHexString();
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(cpr::Header{{"Authorization", AccessToken.Value}});
+ Session.SetOption(cpr::Body{});
+
+ cpr::Response Response = Session.Head();
+ ZEN_DEBUG("HEAD {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {.ErrorCode = 401, .Reason = std::string("Invalid access token")};
+ }
+
+ return {.ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+CloudCacheExistsResult
+CloudCacheSession::CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys)
+{
+ ZEN_TRACE_CPU("HordeClient::CacheTypeExists");
+
+ ExtendableStringBuilder<256> Body;
+ Body << "[";
+ for (const auto& Key : Keys)
+ {
+ Body << (Body.Size() != 1 ? ",\"" : "\"") << Key.ToHexString() << "\"";
+ }
+ Body << "]";
+
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_CacheClient->ServiceUrl() << "/api/v1/" << TypeId << "/" << Namespace << "/exist";
+
+ cpr::Session& Session = GetSession();
+ const CloudCacheAccessToken& AccessToken = GetAccessToken();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetOption(
+ cpr::Header{{"Authorization", AccessToken.Value}, {"Accept", "application/x-ue-cb"}, {"Content-Type", "application/json"}});
+ Session.SetOption(cpr::Body(Body.ToString()));
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {CloudCacheResult{.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = Response.error.message}};
+ }
+ else if (!VerifyAccessToken(Response.status_code))
+ {
+ return {CloudCacheResult{.ErrorCode = 401, .Reason = std::string("Invalid access token")}};
+ }
+
+ CloudCacheExistsResult Result{
+ CloudCacheResult{.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200}};
+
+ if (Result.Success)
+ {
+ IoBuffer Buffer = IoBuffer(zen::IoBuffer::Wrap, Response.text.data(), Response.text.size());
+ const CbObject ExistsResponse = LoadCompactBinaryObject(Buffer);
+ for (auto& Item : ExistsResponse["needs"sv])
+ {
+ Result.Needs.insert(Item.AsHash());
+ }
+ }
+
+ return Result;
+}
+
+/**
+ * An access token provider that holds a token that will never change.
+ */
+class StaticTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ StaticTokenProvider(CloudCacheAccessToken Token) : m_Token(std::move(Token)) {}
+
+ virtual ~StaticTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Token; }
+
+private:
+ CloudCacheAccessToken m_Token;
+};
+
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromStaticToken(CloudCacheAccessToken Token)
+{
+ return std::make_unique<StaticTokenProvider>(std::move(Token));
+}
+
+class OAuthClientCredentialsTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ OAuthClientCredentialsTokenProvider(const CloudCacheTokenProvider::OAuthClientCredentialsParams& Params)
+ {
+ m_Url = std::string(Params.Url);
+ m_ClientId = std::string(Params.ClientId);
+ m_ClientSecret = std::string(Params.ClientSecret);
+ }
+
+ virtual ~OAuthClientCredentialsTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override
+ {
+ using namespace std::chrono;
+
+ std::string Body =
+ fmt::format("client_id={}&scope=cache_access&grant_type=client_credentials&client_secret={}", m_ClientId, m_ClientSecret);
+
+ cpr::Response Response =
+ cpr::Post(cpr::Url{m_Url}, cpr::Header{{"Content-Type", "application/x-www-form-urlencoded"}}, cpr::Body{std::move(Body)});
+
+ if (Response.error || Response.status_code != 200)
+ {
+ return {};
+ }
+
+ std::string JsonError;
+ json11::Json Json = json11::Json::parse(Response.text, JsonError);
+
+ if (JsonError.empty() == false)
+ {
+ return {};
+ }
+
+ std::string Token = Json["access_token"].string_value();
+ int64_t ExpiresInSeconds = static_cast<int64_t>(Json["expires_in"].int_value());
+ CloudCacheAccessToken::TimePoint ExpireTime = CloudCacheAccessToken::Clock::now() + seconds(ExpiresInSeconds);
+
+ return {.Value = fmt::format("Bearer {}", Token), .ExpireTime = ExpireTime};
+ }
+
+private:
+ std::string m_Url;
+ std::string m_ClientId;
+ std::string m_ClientSecret;
+};
+
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params)
+{
+ return std::make_unique<OAuthClientCredentialsTokenProvider>(Params);
+}
+
+class CallbackTokenProvider final : public CloudCacheTokenProvider
+{
+public:
+ CallbackTokenProvider(std::function<CloudCacheAccessToken()>&& Callback) : m_Callback(std::move(Callback)) {}
+
+ virtual ~CallbackTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() final override { return m_Callback(); }
+
+private:
+ std::function<CloudCacheAccessToken()> m_Callback;
+};
+
+std::unique_ptr<CloudCacheTokenProvider>
+CloudCacheTokenProvider::CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback)
+{
+ return std::make_unique<CallbackTokenProvider>(std::move(Callback));
+}
+
+CloudCacheClient::CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider)
+: m_Log(zen::logging::Get("jupiter"))
+, m_ServiceUrl(Options.ServiceUrl)
+, m_DefaultDdcNamespace(Options.DdcNamespace)
+, m_DefaultBlobStoreNamespace(Options.BlobStoreNamespace)
+, m_ComputeCluster(Options.ComputeCluster)
+, m_ConnectTimeout(Options.ConnectTimeout)
+, m_Timeout(Options.Timeout)
+, m_TokenProvider(std::move(TokenProvider))
+{
+ ZEN_ASSERT(m_TokenProvider.get() != nullptr);
+}
+
+CloudCacheClient::~CloudCacheClient()
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+
+ for (auto State : m_SessionStateCache)
+ {
+ delete State;
+ }
+}
+
+CloudCacheAccessToken
+CloudCacheClient::AcquireAccessToken()
+{
+ ZEN_TRACE_CPU("HordeClient::AcquireAccessToken");
+
+ return m_TokenProvider->AcquireAccessToken();
+}
+
+detail::CloudCacheSessionState*
+CloudCacheClient::AllocSessionState()
+{
+ detail::CloudCacheSessionState* State = nullptr;
+
+ bool IsTokenValid = false;
+
+ {
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+
+ if (m_SessionStateCache.empty() == false)
+ {
+ State = m_SessionStateCache.front();
+ IsTokenValid = State->m_AccessToken.IsValid();
+
+ m_SessionStateCache.pop_front();
+ }
+ }
+
+ if (State == nullptr)
+ {
+ State = new detail::CloudCacheSessionState(*this);
+ }
+
+ State->Reset(m_ConnectTimeout, m_Timeout);
+
+ if (IsTokenValid == false)
+ {
+ State->m_AccessToken = m_TokenProvider->AcquireAccessToken();
+ }
+
+ return State;
+}
+
+void
+CloudCacheClient::FreeSessionState(detail::CloudCacheSessionState* State)
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+ m_SessionStateCache.push_front(State);
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/jupiter.h b/src/zenserver/upstream/jupiter.h
new file mode 100644
index 000000000..99e5c530f
--- /dev/null
+++ b/src/zenserver/upstream/jupiter.h
@@ -0,0 +1,217 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/refcount.h>
+#include <zencore/thread.h>
+#include <zenhttp/httpserver.h>
+
+#include <atomic>
+#include <chrono>
+#include <list>
+#include <memory>
+#include <set>
+#include <vector>
+
+struct ZenCacheValue;
+
+namespace cpr {
+class Session;
+}
+
+namespace zen {
+namespace detail {
+ struct CloudCacheSessionState;
+}
+
+class CbObjectView;
+class CloudCacheClient;
+class IoBuffer;
+struct IoHash;
+
+/**
+ * Cached access token, for use with `Authorization:` header
+ */
+struct CloudCacheAccessToken
+{
+ using Clock = std::chrono::system_clock;
+ using TimePoint = Clock::time_point;
+
+ static constexpr int64_t ExpireMarginInSeconds = 30;
+
+ std::string Value;
+ TimePoint ExpireTime;
+
+ bool IsValid() const
+ {
+ return Value.empty() == false &&
+ ExpireMarginInSeconds < std::chrono::duration_cast<std::chrono::seconds>(ExpireTime - Clock::now()).count();
+ }
+};
+
+struct CloudCacheResult
+{
+ IoBuffer Response;
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ int32_t ErrorCode{};
+ std::string Reason;
+ bool Success = false;
+};
+
+struct PutRefResult : CloudCacheResult
+{
+ std::vector<IoHash> Needs;
+};
+
+struct FinalizeRefResult : CloudCacheResult
+{
+ std::vector<IoHash> Needs;
+};
+
+struct CloudCacheExistsResult : CloudCacheResult
+{
+ std::set<IoHash> Needs;
+};
+
+struct GetObjectReferencesResult : CloudCacheResult
+{
+ std::set<IoHash> References;
+};
+
+/**
+ * Context for performing Jupiter operations
+ *
+ * Maintains an HTTP connection so that subsequent operations don't need to go
+ * through the whole connection setup process
+ *
+ */
+class CloudCacheSession
+{
+public:
+ CloudCacheSession(CloudCacheClient* CacheClient);
+ ~CloudCacheSession();
+
+ CloudCacheResult Authenticate();
+ CloudCacheResult GetRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType RefType);
+ CloudCacheResult GetBlob(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetCompressedBlob(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetObject(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult GetInlineBlob(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoHash& OutPayloadHash);
+
+ PutRefResult PutRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, IoBuffer Ref, ZenContentType RefType);
+ CloudCacheResult PutBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob);
+ CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, IoBuffer Blob);
+ CloudCacheResult PutCompressedBlob(std::string_view Namespace, const IoHash& Key, const CompositeBuffer& Blob);
+ CloudCacheResult PutObject(std::string_view Namespace, const IoHash& Key, IoBuffer Object);
+
+ FinalizeRefResult FinalizeRef(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& RefHah);
+
+ CloudCacheResult RefExists(std::string_view Namespace, std::string_view BucketId, const IoHash& Key);
+
+ GetObjectReferencesResult GetObjectReferences(std::string_view Namespace, const IoHash& Key);
+
+ CloudCacheResult BlobExists(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult CompressedBlobExists(std::string_view Namespace, const IoHash& Key);
+ CloudCacheResult ObjectExists(std::string_view Namespace, const IoHash& Key);
+
+ CloudCacheExistsResult BlobExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+ CloudCacheExistsResult CompressedBlobExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+ CloudCacheExistsResult ObjectExists(std::string_view Namespace, const std::set<IoHash>& Keys);
+
+ CloudCacheResult PostComputeTasks(IoBuffer TasksData);
+ CloudCacheResult GetComputeUpdates(std::string_view ChannelId, const uint32_t WaitSeconds = 0);
+
+ std::vector<IoHash> Filter(std::string_view Namespace, std::string_view BucketId, const std::vector<IoHash>& ChunkHashes);
+
+ CloudCacheClient& Client() { return *m_CacheClient; };
+
+private:
+ inline spdlog::logger& Log() { return m_Log; }
+ cpr::Session& GetSession();
+ CloudCacheAccessToken GetAccessToken(bool RefreshToken = false);
+ bool VerifyAccessToken(long StatusCode);
+
+ CloudCacheResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const IoHash& Key);
+
+ CloudCacheExistsResult CacheTypeExists(std::string_view Namespace, std::string_view TypeId, const std::set<IoHash>& Keys);
+
+ spdlog::logger& m_Log;
+ RefPtr<CloudCacheClient> m_CacheClient;
+ detail::CloudCacheSessionState* m_SessionState;
+};
+
+/**
+ * Access token provider interface
+ */
+class CloudCacheTokenProvider
+{
+public:
+ virtual ~CloudCacheTokenProvider() = default;
+
+ virtual CloudCacheAccessToken AcquireAccessToken() = 0;
+
+ static std::unique_ptr<CloudCacheTokenProvider> CreateFromStaticToken(CloudCacheAccessToken Token);
+
+ struct OAuthClientCredentialsParams
+ {
+ std::string_view Url;
+ std::string_view ClientId;
+ std::string_view ClientSecret;
+ };
+
+ static std::unique_ptr<CloudCacheTokenProvider> CreateFromOAuthClientCredentials(const OAuthClientCredentialsParams& Params);
+
+ static std::unique_ptr<CloudCacheTokenProvider> CreateFromCallback(std::function<CloudCacheAccessToken()>&& Callback);
+};
+
+struct CloudCacheClientOptions
+{
+ std::string_view Name;
+ std::string_view ServiceUrl;
+ std::string_view DdcNamespace;
+ std::string_view BlobStoreNamespace;
+ std::string_view ComputeCluster;
+ std::chrono::milliseconds ConnectTimeout{5000};
+ std::chrono::milliseconds Timeout{};
+};
+
+/**
+ * Jupiter upstream cache client
+ */
+class CloudCacheClient : public RefCounted
+{
+public:
+ CloudCacheClient(const CloudCacheClientOptions& Options, std::unique_ptr<CloudCacheTokenProvider> TokenProvider);
+ ~CloudCacheClient();
+
+ CloudCacheAccessToken AcquireAccessToken();
+ std::string_view DefaultDdcNamespace() const { return m_DefaultDdcNamespace; }
+ std::string_view DefaultBlobStoreNamespace() const { return m_DefaultBlobStoreNamespace; }
+ std::string_view ComputeCluster() const { return m_ComputeCluster; }
+ std::string_view ServiceUrl() const { return m_ServiceUrl; }
+
+ spdlog::logger& Logger() { return m_Log; }
+
+private:
+ spdlog::logger& m_Log;
+ std::string m_ServiceUrl;
+ std::string m_DefaultDdcNamespace;
+ std::string m_DefaultBlobStoreNamespace;
+ std::string m_ComputeCluster;
+ std::chrono::milliseconds m_ConnectTimeout{};
+ std::chrono::milliseconds m_Timeout{};
+ std::unique_ptr<CloudCacheTokenProvider> m_TokenProvider;
+
+ RwLock m_SessionStateLock;
+ std::list<detail::CloudCacheSessionState*> m_SessionStateCache;
+
+ detail::CloudCacheSessionState* AllocSessionState();
+ void FreeSessionState(detail::CloudCacheSessionState*);
+
+ friend class CloudCacheSession;
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstream.h b/src/zenserver/upstream/upstream.h
new file mode 100644
index 000000000..a57301206
--- /dev/null
+++ b/src/zenserver/upstream/upstream.h
@@ -0,0 +1,8 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <upstream/jupiter.h>
+#include <upstream/upstreamcache.h>
+#include <upstream/upstreamservice.h>
+#include <upstream/zen.h>
diff --git a/src/zenserver/upstream/upstreamapply.cpp b/src/zenserver/upstream/upstreamapply.cpp
new file mode 100644
index 000000000..c719b225d
--- /dev/null
+++ b/src/zenserver/upstream/upstreamapply.cpp
@@ -0,0 +1,459 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamapply.h"
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinary.h>
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/fmtutils.h>
+# include <zencore/stream.h>
+# include <zencore/timer.h>
+# include <zencore/workthreadpool.h>
+
+# include <zenstore/cidstore.h>
+
+# include "diag/logging.h"
+
+# include <fmt/format.h>
+
+# include <atomic>
+
+namespace zen {
+
+using namespace std::literals;
+
+struct UpstreamApplyStats
+{
+ static constexpr uint64_t MaxSampleCount = 1000ull;
+
+ UpstreamApplyStats(bool Enabled) : m_Enabled(Enabled) {}
+
+ void Add(UpstreamApplyEndpoint& Endpoint, const PostUpstreamApplyResult& Result)
+ {
+ UpstreamApplyEndpointStats& Stats = Endpoint.Stats();
+
+ if (Result.Error)
+ {
+ Stats.ErrorCount.Increment(1);
+ }
+ else if (Result.Success)
+ {
+ Stats.PostCount.Increment(1);
+ Stats.UpBytes.Increment(Result.Bytes / 1024 / 1024);
+ }
+ }
+
+ void Add(UpstreamApplyEndpoint& Endpoint, const GetUpstreamApplyUpdatesResult& Result)
+ {
+ UpstreamApplyEndpointStats& Stats = Endpoint.Stats();
+
+ if (Result.Error)
+ {
+ Stats.ErrorCount.Increment(1);
+ }
+ else if (Result.Success)
+ {
+ Stats.UpdateCount.Increment(1);
+ Stats.DownBytes.Increment(Result.Bytes / 1024 / 1024);
+ if (!Result.Completed.empty())
+ {
+ uint64_t Completed = 0;
+ for (auto& It : Result.Completed)
+ {
+ Completed += It.second.size();
+ }
+ Stats.CompleteCount.Increment(Completed);
+ }
+ }
+ }
+
+ bool m_Enabled;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+class UpstreamApplyImpl final : public UpstreamApply
+{
+public:
+ UpstreamApplyImpl(const UpstreamApplyOptions& Options, CidStore& CidStore)
+ : m_Log(logging::Get("upstream-apply"))
+ , m_Options(Options)
+ , m_CidStore(CidStore)
+ , m_Stats(Options.StatsEnabled)
+ , m_UpstreamAsyncWorkPool(Options.UpstreamThreadCount)
+ , m_DownstreamAsyncWorkPool(Options.DownstreamThreadCount)
+ {
+ }
+
+ virtual ~UpstreamApplyImpl() { Shutdown(); }
+
+ virtual bool Initialize() override
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ const UpstreamEndpointHealth Health = Endpoint->Initialize();
+ if (Health.Ok)
+ {
+ Log().info("initialize endpoint '{}' OK", Endpoint->DisplayName());
+ }
+ else
+ {
+ Log().warn("initialize endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason);
+ }
+ }
+
+ m_RunState.IsRunning = !m_Endpoints.empty();
+
+ if (m_RunState.IsRunning)
+ {
+ m_ShutdownEvent.Reset();
+
+ m_UpstreamUpdatesThread = std::thread(&UpstreamApplyImpl::ProcessUpstreamUpdates, this);
+
+ m_EndpointMonitorThread = std::thread(&UpstreamApplyImpl::MonitorEndpoints, this);
+ }
+
+ return m_RunState.IsRunning;
+ }
+
+ virtual bool IsHealthy() const override
+ {
+ if (m_RunState.IsRunning)
+ {
+ for (const auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->IsHealthy())
+ {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) override
+ {
+ m_Endpoints.emplace_back(std::move(Endpoint));
+ }
+
+ virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) override
+ {
+ if (m_RunState.IsRunning)
+ {
+ const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash();
+ const IoHash ActionId = ApplyRecord.Action.GetHash();
+ const uint32_t TimeoutSeconds = ApplyRecord.WorkerDescriptor["timeout"sv].AsInt32(300);
+
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ // Already in progress
+ return {.ApplyId = ActionId, .Success = true};
+ }
+
+ std::chrono::steady_clock::time_point ExpireTime =
+ TimeoutSeconds > 0 ? std::chrono::steady_clock::now() + std::chrono::seconds(TimeoutSeconds)
+ : std::chrono::steady_clock::time_point::max();
+
+ m_ApplyTasks[WorkerId][ActionId] = {.State = UpstreamApplyState::Queued, .Result{}, .ExpireTime = std::move(ExpireTime)};
+ }
+
+ ApplyRecord.Timepoints["zen-queue-added"] = DateTime::NowTicks();
+ m_UpstreamAsyncWorkPool.ScheduleWork(
+ [this, ApplyRecord = std::move(ApplyRecord)]() { ProcessApplyRecord(std::move(ApplyRecord)); });
+
+ return {.ApplyId = ActionId, .Success = true};
+ }
+
+ return {};
+ }
+
+ virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) override
+ {
+ if (m_RunState.IsRunning)
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ return {.Status = *Status, .Success = true};
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetStatus(CbObjectWriter& Status) override
+ {
+ Status << "upstream_worker_threads" << m_Options.UpstreamThreadCount;
+ Status << "upstream_queue_count" << m_UpstreamAsyncWorkPool.PendingWork();
+ Status << "downstream_worker_threads" << m_Options.DownstreamThreadCount;
+ Status << "downstream_queue_count" << m_DownstreamAsyncWorkPool.PendingWork();
+
+ Status.BeginArray("endpoints");
+ for (const auto& Ep : m_Endpoints)
+ {
+ Status.BeginObject();
+ Status << "name" << Ep->DisplayName();
+ Status << "health" << (Ep->IsHealthy() ? "ok"sv : "inactive"sv);
+
+ UpstreamApplyEndpointStats& Stats = Ep->Stats();
+ const uint64_t PostCount = Stats.PostCount.Value();
+ const uint64_t CompleteCount = Stats.CompleteCount.Value();
+ // const uint64_t UpdateCount = Stats.UpdateCount;
+ const double CompleteRate = CompleteCount > 0 ? (double(PostCount) / double(CompleteCount)) : 0.0;
+
+ Status << "post_count" << PostCount;
+ Status << "complete_count" << PostCount;
+ Status << "update_count" << Stats.UpdateCount.Value();
+
+ Status << "complete_ratio" << CompleteRate;
+ Status << "downloaded_mb" << Stats.DownBytes.Value();
+ Status << "uploaded_mb" << Stats.UpBytes.Value();
+ Status << "error_count" << Stats.ErrorCount.Value();
+
+ Status.EndObject();
+ }
+ Status.EndArray();
+ }
+
+private:
+ // The caller is responsible for locking if required
+ UpstreamApplyStatus* FindStatus(const IoHash& WorkerId, const IoHash& ActionId)
+ {
+ if (auto It = m_ApplyTasks.find(WorkerId); It != m_ApplyTasks.end())
+ {
+ if (auto It2 = It->second.find(ActionId); It2 != It->second.end())
+ {
+ return &It2->second;
+ }
+ }
+ return nullptr;
+ }
+
+ void ProcessApplyRecord(UpstreamApplyRecord ApplyRecord)
+ {
+ const IoHash WorkerId = ApplyRecord.WorkerDescriptor.GetHash();
+ const IoHash ActionId = ApplyRecord.Action.GetHash();
+ try
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->IsHealthy())
+ {
+ ApplyRecord.Timepoints["zen-queue-dispatched"] = DateTime::NowTicks();
+ PostUpstreamApplyResult Result = Endpoint->PostApply(std::move(ApplyRecord));
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ Status->Timepoints.merge(Result.Timepoints);
+
+ if (Result.Success)
+ {
+ Status->State = UpstreamApplyState::Executing;
+ }
+ else
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = {.Error = std::move(Result.Error),
+ .Bytes = Result.Bytes,
+ .ElapsedSeconds = Result.ElapsedSeconds};
+ }
+ }
+ }
+ m_Stats.Add(*Endpoint, Result);
+ return;
+ }
+ }
+
+ Log().warn("process upstream apply ({}/{}) FAILED 'No available endpoint'", WorkerId, ActionId);
+
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = {.Error{.ErrorCode = -1, .Reason = "No available endpoint"}};
+ }
+ }
+ }
+ catch (std::exception& e)
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(WorkerId, ActionId); Status != nullptr)
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = {.Error{.ErrorCode = -1, .Reason = e.what()}};
+ }
+ Log().warn("process upstream apply ({}/{}) FAILED '{}'", WorkerId, ActionId, e.what());
+ }
+ }
+
+ void ProcessApplyUpdates()
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->IsHealthy())
+ {
+ GetUpstreamApplyUpdatesResult Result = Endpoint->GetUpdates(m_DownstreamAsyncWorkPool);
+ m_Stats.Add(*Endpoint, Result);
+
+ if (!Result.Success)
+ {
+ Log().warn("process upstream apply updates FAILED '{}'", Result.Error.Reason);
+ }
+
+ if (!Result.Completed.empty())
+ {
+ for (auto& It : Result.Completed)
+ {
+ for (auto& It2 : It.second)
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ if (auto Status = FindStatus(It.first, It2.first); Status != nullptr)
+ {
+ Status->State = UpstreamApplyState::Complete;
+ Status->Result = std::move(It2.second);
+ Status->Result.Timepoints.merge(Status->Timepoints);
+ Status->Result.Timepoints["zen-queue-complete"] = DateTime::NowTicks();
+ Status->Timepoints.clear();
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ void ProcessUpstreamUpdates()
+ {
+ const auto& UpdateSleep = std::chrono::milliseconds(m_Options.UpdatesInterval);
+ while (!m_ShutdownEvent.Wait(uint32_t(UpdateSleep.count())))
+ {
+ if (!m_RunState.IsRunning)
+ {
+ break;
+ }
+
+ ProcessApplyUpdates();
+
+ // Remove any expired tasks, regardless of state
+ {
+ std::scoped_lock Lock(m_ApplyTasksMutex);
+ for (auto& WorkerIt : m_ApplyTasks)
+ {
+ const auto Count = std::erase_if(WorkerIt.second, [](const auto& Item) {
+ return Item.second.ExpireTime < std::chrono::steady_clock::now();
+ });
+ if (Count > 0)
+ {
+ Log().debug("Removed '{}' expired tasks", Count);
+ }
+ }
+ const auto Count = std::erase_if(m_ApplyTasks, [](const auto& Item) { return Item.second.empty(); });
+ if (Count > 0)
+ {
+ Log().debug("Removed '{}' empty task lists", Count);
+ }
+ }
+ }
+ }
+
+ void MonitorEndpoints()
+ {
+ for (;;)
+ {
+ {
+ std::unique_lock Lock(m_RunState.Mutex);
+ if (m_RunState.ExitSignal.wait_for(Lock, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); }))
+ {
+ break;
+ }
+ }
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (!Endpoint->IsHealthy())
+ {
+ if (const UpstreamEndpointHealth Health = Endpoint->CheckHealth(); Health.Ok)
+ {
+ Log().warn("health check endpoint '{}' OK", Endpoint->DisplayName(), Health.Reason);
+ }
+ else
+ {
+ Log().warn("health check endpoint '{}' FAILED, reason '{}'", Endpoint->DisplayName(), Health.Reason);
+ }
+ }
+ }
+ }
+ }
+
+ void Shutdown()
+ {
+ if (m_RunState.Stop())
+ {
+ m_ShutdownEvent.Set();
+ m_EndpointMonitorThread.join();
+ m_UpstreamUpdatesThread.join();
+ m_Endpoints.clear();
+ }
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ struct RunState
+ {
+ std::mutex Mutex;
+ std::condition_variable ExitSignal;
+ std::atomic_bool IsRunning{false};
+
+ bool Stop()
+ {
+ bool Stopped = false;
+ {
+ std::scoped_lock Lock(Mutex);
+ Stopped = IsRunning.exchange(false);
+ }
+ if (Stopped)
+ {
+ ExitSignal.notify_all();
+ }
+ return Stopped;
+ }
+ };
+
+ spdlog::logger& m_Log;
+ UpstreamApplyOptions m_Options;
+ CidStore& m_CidStore;
+ UpstreamApplyStats m_Stats;
+ UpstreamApplyTasks m_ApplyTasks;
+ std::mutex m_ApplyTasksMutex;
+ std::vector<std::unique_ptr<UpstreamApplyEndpoint>> m_Endpoints;
+ Event m_ShutdownEvent;
+ WorkerThreadPool m_UpstreamAsyncWorkPool;
+ WorkerThreadPool m_DownstreamAsyncWorkPool;
+ std::thread m_UpstreamUpdatesThread;
+ std::thread m_EndpointMonitorThread;
+ RunState m_RunState;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+bool
+UpstreamApply::IsHealthy() const
+{
+ return false;
+}
+
+std::unique_ptr<UpstreamApply>
+UpstreamApply::Create(const UpstreamApplyOptions& Options, CidStore& CidStore)
+{
+ return std::make_unique<UpstreamApplyImpl>(Options, CidStore);
+}
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/upstreamapply.h b/src/zenserver/upstream/upstreamapply.h
new file mode 100644
index 000000000..4a095be6c
--- /dev/null
+++ b/src/zenserver/upstream/upstreamapply.h
@@ -0,0 +1,192 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#if ZEN_WITH_COMPUTE_SERVICES
+
+# include <zencore/compactbinarypackage.h>
+# include <zencore/iobuffer.h>
+# include <zencore/iohash.h>
+# include <zencore/stats.h>
+# include <zencore/zencore.h>
+
+# include <chrono>
+# include <map>
+# include <unordered_map>
+# include <unordered_set>
+
+namespace zen {
+
+class AuthMgr;
+class CbObjectWriter;
+class CidStore;
+class CloudCacheTokenProvider;
+class WorkerThreadPool;
+class ZenCacheNamespace;
+struct CloudCacheClientOptions;
+struct UpstreamAuthConfig;
+
+enum class UpstreamApplyState : int32_t
+{
+ Queued = 0,
+ Executing = 1,
+ Complete = 2,
+};
+
+enum class UpstreamApplyType
+{
+ Simple = 0,
+ Asset = 1,
+};
+
+struct UpstreamApplyRecord
+{
+ CbObject WorkerDescriptor;
+ CbObject Action;
+ UpstreamApplyType Type;
+ std::map<std::string, uint64_t> Timepoints{};
+};
+
+struct UpstreamApplyOptions
+{
+ std::chrono::seconds HealthCheckInterval{5};
+ std::chrono::seconds UpdatesInterval{5};
+ uint32_t UpstreamThreadCount = 4;
+ uint32_t DownstreamThreadCount = 4;
+ bool StatsEnabled = false;
+};
+
+struct UpstreamApplyError
+{
+ int32_t ErrorCode{};
+ std::string Reason{};
+
+ explicit operator bool() const { return ErrorCode != 0; }
+};
+
+struct PostUpstreamApplyResult
+{
+ UpstreamApplyError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ std::map<std::string, uint64_t> Timepoints{};
+ bool Success = false;
+};
+
+struct GetUpstreamApplyResult
+{
+ // UpstreamApplyType::Simple
+ std::map<std::filesystem::path, IoHash> OutputFiles{};
+ std::map<IoHash, IoBuffer> FileData{};
+
+ // UpstreamApplyType::Asset
+ CbPackage OutputPackage{};
+ int64_t TotalAttachmentBytes{};
+ int64_t TotalRawAttachmentBytes{};
+
+ UpstreamApplyError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ std::string StdOut{};
+ std::string StdErr{};
+ std::string Agent{};
+ std::string Detail{};
+ std::map<std::string, uint64_t> Timepoints{};
+ bool Success = false;
+};
+
+using UpstreamApplyCompleted = std::unordered_map<IoHash, std::unordered_map<IoHash, GetUpstreamApplyResult>>;
+
+struct GetUpstreamApplyUpdatesResult
+{
+ UpstreamApplyError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ UpstreamApplyCompleted Completed{};
+ bool Success = false;
+};
+
+struct UpstreamApplyStatus
+{
+ UpstreamApplyState State{};
+ GetUpstreamApplyResult Result{};
+ std::chrono::steady_clock::time_point ExpireTime{};
+ std::map<std::string, uint64_t> Timepoints{};
+};
+
+using UpstreamApplyTasks = std::unordered_map<IoHash, std::unordered_map<IoHash, UpstreamApplyStatus>>;
+
+struct UpstreamEndpointHealth
+{
+ std::string Reason;
+ bool Ok = false;
+};
+
+struct UpstreamApplyEndpointStats
+{
+ metrics::Counter PostCount;
+ metrics::Counter CompleteCount;
+ metrics::Counter UpdateCount;
+ metrics::Counter ErrorCount;
+ metrics::Counter UpBytes;
+ metrics::Counter DownBytes;
+};
+
+/**
+ * The upstream apply endpoint is responsible for handling remote execution.
+ */
+class UpstreamApplyEndpoint
+{
+public:
+ virtual ~UpstreamApplyEndpoint() = default;
+
+ virtual UpstreamEndpointHealth Initialize() = 0;
+ virtual bool IsHealthy() const = 0;
+ virtual UpstreamEndpointHealth CheckHealth() = 0;
+ virtual std::string_view DisplayName() const = 0;
+ virtual PostUpstreamApplyResult PostApply(UpstreamApplyRecord ApplyRecord) = 0;
+ virtual GetUpstreamApplyUpdatesResult GetUpdates(WorkerThreadPool& ThreadPool) = 0;
+ virtual UpstreamApplyEndpointStats& Stats() = 0;
+
+ static std::unique_ptr<UpstreamApplyEndpoint> CreateHordeEndpoint(const CloudCacheClientOptions& ComputeOptions,
+ const UpstreamAuthConfig& ComputeAuthConfig,
+ const CloudCacheClientOptions& StorageOptions,
+ const UpstreamAuthConfig& StorageAuthConfig,
+ CidStore& CidStore,
+ AuthMgr& Mgr);
+};
+
+/**
+ * Manages one or more upstream compute endpoints.
+ */
+class UpstreamApply
+{
+public:
+ virtual ~UpstreamApply() = default;
+
+ virtual bool Initialize() = 0;
+ virtual bool IsHealthy() const = 0;
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamApplyEndpoint> Endpoint) = 0;
+
+ struct EnqueueResult
+ {
+ IoHash ApplyId{};
+ bool Success = false;
+ };
+
+ struct StatusResult
+ {
+ UpstreamApplyStatus Status{};
+ bool Success = false;
+ };
+
+ virtual EnqueueResult EnqueueUpstream(UpstreamApplyRecord ApplyRecord) = 0;
+ virtual StatusResult GetStatus(const IoHash& WorkerId, const IoHash& ActionId) = 0;
+ virtual void GetStatus(CbObjectWriter& CbO) = 0;
+
+ static std::unique_ptr<UpstreamApply> Create(const UpstreamApplyOptions& Options, CidStore& CidStore);
+};
+
+} // namespace zen
+
+#endif // ZEN_WITH_COMPUTE_SERVICES
diff --git a/src/zenserver/upstream/upstreamcache.cpp b/src/zenserver/upstream/upstreamcache.cpp
new file mode 100644
index 000000000..e838b5fe2
--- /dev/null
+++ b/src/zenserver/upstream/upstreamcache.cpp
@@ -0,0 +1,2112 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "upstreamcache.h"
+#include "jupiter.h"
+#include "zen.h"
+
+#include <zencore/blockingqueue.h>
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/stats.h>
+#include <zencore/stream.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+
+#include <zenhttp/httpshared.h>
+
+#include <zenstore/cidstore.h>
+
+#include <auth/authmgr.h>
+#include "cache/structuredcache.h"
+#include "cache/structuredcachestore.h"
+#include "diag/logging.h"
+
+#include <fmt/format.h>
+
+#include <algorithm>
+#include <atomic>
+#include <shared_mutex>
+#include <thread>
+#include <unordered_map>
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace detail {
+
+ class UpstreamStatus
+ {
+ public:
+ UpstreamEndpointState EndpointState() const { return static_cast<UpstreamEndpointState>(m_State.load(std::memory_order_relaxed)); }
+
+ UpstreamEndpointStatus EndpointStatus() const
+ {
+ const UpstreamEndpointState State = EndpointState();
+ {
+ std::unique_lock _(m_Mutex);
+ return {.Reason = m_ErrorText, .State = State};
+ }
+ }
+
+ void Set(UpstreamEndpointState NewState)
+ {
+ m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
+ {
+ std::unique_lock _(m_Mutex);
+ m_ErrorText.clear();
+ }
+ }
+
+ void Set(UpstreamEndpointState NewState, std::string ErrorText)
+ {
+ m_State.store(static_cast<uint32_t>(NewState), std::memory_order_relaxed);
+ {
+ std::unique_lock _(m_Mutex);
+ m_ErrorText = std::move(ErrorText);
+ }
+ }
+
+ void SetFromErrorCode(int32_t ErrorCode, std::string_view ErrorText)
+ {
+ if (ErrorCode != 0)
+ {
+ Set(ErrorCode == 401 ? UpstreamEndpointState::kUnauthorized : UpstreamEndpointState::kError, std::string(ErrorText));
+ }
+ }
+
+ private:
+ mutable std::mutex m_Mutex;
+ std::string m_ErrorText;
+ std::atomic_uint32_t m_State;
+ };
+
+ class JupiterUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ public:
+ JupiterUpstreamEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
+ : m_AuthMgr(Mgr)
+ , m_Log(zen::logging::Get("upstream"))
+ {
+ ZEN_ASSERT(!Options.Name.empty());
+ m_Info.Name = Options.Name;
+ m_Info.Url = Options.ServiceUrl;
+
+ std::unique_ptr<CloudCacheTokenProvider> TokenProvider;
+
+ if (AuthConfig.OAuthUrl.empty() == false)
+ {
+ TokenProvider = CloudCacheTokenProvider::CreateFromOAuthClientCredentials(
+ {.Url = AuthConfig.OAuthUrl, .ClientId = AuthConfig.OAuthClientId, .ClientSecret = AuthConfig.OAuthClientSecret});
+ }
+ else if (AuthConfig.OpenIdProvider.empty() == false)
+ {
+ TokenProvider =
+ CloudCacheTokenProvider::CreateFromCallback([this, ProviderName = std::string(AuthConfig.OpenIdProvider)]() {
+ AuthMgr::OpenIdAccessToken Token = m_AuthMgr.GetOpenIdAccessToken(ProviderName);
+ return CloudCacheAccessToken{.Value = Token.AccessToken, .ExpireTime = Token.ExpireTime};
+ });
+ }
+ else
+ {
+ CloudCacheAccessToken AccessToken{.Value = std::string(AuthConfig.AccessToken),
+ .ExpireTime = CloudCacheAccessToken::TimePoint::max()};
+
+ TokenProvider = CloudCacheTokenProvider::CreateFromStaticToken(AccessToken);
+ }
+
+ m_Client = new CloudCacheClient(Options, std::move(TokenProvider));
+ }
+
+ virtual ~JupiterUpstreamEndpoint() = default;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }
+
+ virtual UpstreamEndpointStatus Initialize() override
+ {
+ try
+ {
+ if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
+ {
+ return {.State = UpstreamEndpointState::kOk};
+ }
+
+ CloudCacheSession Session(m_Client);
+ const CloudCacheResult Result = Session.Authenticate();
+
+ if (Result.Success)
+ {
+ m_Status.Set(UpstreamEndpointState::kOk);
+ }
+ else if (Result.ErrorCode != 0)
+ {
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+ }
+ else
+ {
+ m_Status.Set(UpstreamEndpointState::kUnauthorized);
+ }
+
+ return m_Status.EndpointStatus();
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = Err.what(), .State = GetState()};
+ }
+ }
+
+ std::string_view GetActualDdcNamespace(CloudCacheSession& Session, std::string_view Namespace)
+ {
+ if (Namespace == ZenCacheStore::DefaultNamespace)
+ {
+ return Session.Client().DefaultDdcNamespace();
+ }
+ return Namespace;
+ }
+
+ std::string_view GetActualBlobStoreNamespace(CloudCacheSession& Session, std::string_view Namespace)
+ {
+ if (Namespace == ZenCacheStore::DefaultNamespace)
+ {
+ return Session.Client().DefaultBlobStoreNamespace();
+ }
+ return Namespace;
+ }
+
+ virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }
+
+ virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheRecord");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ CloudCacheResult Result;
+
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+
+ if (Type == ZenContentType::kCompressedBinary)
+ {
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
+
+ if (Result.Success)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+ IoBuffer ContentBuffer;
+ int NumAttachments = 0;
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(AttachmentResult.Response, RawHash, RawSize))
+ {
+ Result.Response = AttachmentResult.Response;
+ ++NumAttachments;
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+ if (NumAttachments != 1)
+ {
+ Result.Success = false;
+ }
+ }
+ }
+ }
+ else
+ {
+ const ZenContentType AcceptType = Type == ZenContentType::kCbPackage ? ZenContentType::kCbObject : Type;
+ Result = Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, AcceptType);
+
+ if (Result.Success && Type == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+
+ const CbValidateError ValidationResult = ValidateCompactBinary(Result.Response, CbValidateMode::All);
+ if (Result.Success = ValidationResult == CbValidateError::None; Result.Success)
+ {
+ CbObject CacheRecord = LoadCompactBinaryObject(Result.Response);
+
+ CacheRecord.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult AttachmentResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ Result.Bytes += AttachmentResult.Bytes;
+ Result.ElapsedSeconds += AttachmentResult.ElapsedSeconds;
+ Result.ErrorCode = AttachmentResult.ErrorCode;
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Chunk =
+ CompressedBuffer::FromCompressed(SharedBuffer(AttachmentResult.Response), RawHash, RawSize))
+ {
+ Package.AddAttachment(CbAttachment(Chunk, AttachmentHash.AsHash()));
+ }
+ else
+ {
+ Result.Success = false;
+ }
+ });
+
+ Package.SetObject(CacheRecord);
+ }
+
+ if (Result.Success)
+ {
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+
+ Result.Response = IoBuffer(IoBuffer::Clone, MemStream.Data(), MemStream.Size());
+ }
+ }
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheRecords");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheKeyRequest* Request : Requests)
+ {
+ const CacheKey& CacheKey = Request->Key;
+ CbPackage Package;
+ CbObject Record;
+
+ double ElapsedSeconds = 0.0;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ CloudCacheResult RefResult =
+ Session.GetRef(BlobStoreNamespace, CacheKey.Bucket, CacheKey.Hash, ZenContentType::kCbObject);
+ AppendResult(RefResult, Result);
+ ElapsedSeconds = RefResult.ElapsedSeconds;
+
+ m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);
+
+ if (RefResult.ErrorCode == 0)
+ {
+ const CbValidateError ValidationResult = ValidateCompactBinary(RefResult.Response, CbValidateMode::All);
+ if (ValidationResult == CbValidateError::None)
+ {
+ Record = LoadCompactBinaryObject(RefResult.Response);
+ Record.IterateAttachments([&](CbFieldView AttachmentHash) {
+ CloudCacheResult BlobResult = Session.GetCompressedBlob(BlobStoreNamespace, AttachmentHash.AsHash());
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+
+ if (BlobResult.ErrorCode == 0)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer Chunk =
+ CompressedBuffer::FromCompressed(SharedBuffer(BlobResult.Response), RawHash, RawSize))
+ {
+ if (RawHash == AttachmentHash.AsHash())
+ {
+ Package.AddAttachment(CbAttachment(Chunk, RawHash));
+ }
+ }
+ }
+ });
+ }
+ }
+ }
+
+ OnComplete(
+ {.Request = *Request, .Record = Record, .Package = Package, .ElapsedSeconds = ElapsedSeconds, .Source = &m_Info});
+ }
+
+ return Result;
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey&,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetSingleCacheChunk");
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const CloudCacheResult Result = Session.GetCompressedBlob(BlobStoreNamespace, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheChunks");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ CacheChunkRequest& Request = *RequestPtr;
+ IoBuffer Payload;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+
+ double ElapsedSeconds = 0.0;
+ bool IsCompressed = false;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const CloudCacheResult BlobResult =
+ Request.ChunkId == IoHash::Zero
+ ? Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, Request.ChunkId)
+ : Session.GetCompressedBlob(BlobStoreNamespace, Request.ChunkId);
+ ElapsedSeconds = BlobResult.ElapsedSeconds;
+ Payload = BlobResult.Response;
+
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+ if (Payload && IsCompressedBinary(Payload.GetContentType()))
+ {
+ IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize);
+ }
+ }
+
+ if (IsCompressed)
+ {
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = Payload,
+ .ElapsedSeconds = ElapsedSeconds,
+ .Source = &m_Info});
+ }
+ else
+ {
+ OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ return Result;
+ }
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::GetCacheValues");
+
+ CloudCacheSession Session(m_Client);
+ GetUpstreamCacheResult Result;
+
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ CacheValueRequest& Request = *RequestPtr;
+ IoBuffer Payload;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+
+ double ElapsedSeconds = 0.0;
+ bool IsCompressed = false;
+ if (!Result.Error)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ IoHash PayloadHash;
+ const CloudCacheResult BlobResult =
+ Session.GetInlineBlob(BlobStoreNamespace, Request.Key.Bucket, Request.Key.Hash, PayloadHash);
+ ElapsedSeconds = BlobResult.ElapsedSeconds;
+ Payload = BlobResult.Response;
+
+ AppendResult(BlobResult, Result);
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+ if (Payload)
+ {
+ if (IsCompressedBinary(Payload.GetContentType()))
+ {
+ IsCompressed = CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize) && RawHash != PayloadHash;
+ }
+ else
+ {
+ CompressedBuffer Compressed = CompressedBuffer::Compress(SharedBuffer(Payload));
+ RawHash = Compressed.DecodeRawHash();
+ if (RawHash == PayloadHash)
+ {
+ IsCompressed = true;
+ }
+ else
+ {
+ ZEN_WARN("Horde request for inline payload of {}/{}/{} has hash {}, expected hash {} from header",
+ Namespace,
+ Request.Key.Bucket,
+ Request.Key.Hash.ToHexString(),
+ RawHash.ToHexString(),
+ PayloadHash.ToHexString());
+ }
+ }
+ }
+ }
+
+ if (IsCompressed)
+ {
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = Payload,
+ .ElapsedSeconds = ElapsedSeconds,
+ .Source = &m_Info});
+ }
+ else
+ {
+ OnComplete({.Request = Request, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ return Result;
+ }
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Values) override
+ {
+ ZEN_TRACE_CPU("Upstream::Horde::PutCacheRecord");
+
+ ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
+ const int32_t MaxAttempts = 3;
+
+ try
+ {
+ CloudCacheSession Session(m_Client);
+
+ if (CacheRecord.Type == ZenContentType::kBinary)
+ {
+ CloudCacheResult Result;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, CacheRecord.Namespace);
+ Result = Session.PutRef(BlobStoreNamespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ RecordValue,
+ ZenContentType::kBinary);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ return {.Reason = std::move(Result.Reason),
+ .Bytes = Result.Bytes,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Success = Result.Success};
+ }
+ else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (!CompressedBuffer::ValidateCompressedHeader(RecordValue, RawHash, RawSize))
+ {
+ return {.Reason = std::string("Invalid compressed value buffer"), .Success = false};
+ }
+
+ CbObjectWriter ReferencingObject;
+ ReferencingObject.AddBinaryAttachment("RawHash", RawHash);
+ ReferencingObject.AddInteger("RawSize", RawSize);
+
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ ReferencingObject.Save().GetBuffer().AsIoBuffer(),
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ if (ValueContentId != RawHash)
+ {
+ OutReason =
+ fmt::format("Value '{}' MISMATCHED from compressed buffer raw hash {}", ValueContentId, RawHash);
+ return false;
+ }
+
+ OutBuffer = RecordValue;
+ return true;
+ });
+ }
+ else
+ {
+ return PerformStructuredPut(
+ Session,
+ CacheRecord.Namespace,
+ CacheRecord.Key,
+ RecordValue,
+ MaxAttempts,
+ [&](const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason) {
+ const auto It =
+ std::find(std::begin(CacheRecord.ValueContentIds), std::end(CacheRecord.ValueContentIds), ValueContentId);
+
+ if (It == std::end(CacheRecord.ValueContentIds))
+ {
+ OutReason = fmt::format("value '{}' MISSING from local cache", ValueContentId);
+ return false;
+ }
+
+ const size_t Idx = std::distance(std::begin(CacheRecord.ValueContentIds), It);
+
+ OutBuffer = Values[Idx];
+ return true;
+ });
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = std::string(Err.what()), .Success = false};
+ }
+ }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ static void AppendResult(const CloudCacheResult& Result, GetUpstreamCacheResult& Out)
+ {
+ Out.Success &= Result.Success;
+ Out.Bytes += Result.Bytes;
+ Out.ElapsedSeconds += Result.ElapsedSeconds;
+
+ if (Result.ErrorCode)
+ {
+ Out.Error = {.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)};
+ }
+ };
+
+ PutUpstreamCacheResult PerformStructuredPut(
+ CloudCacheSession& Session,
+ std::string_view Namespace,
+ const CacheKey& Key,
+ IoBuffer ObjectBuffer,
+ const int32_t MaxAttempts,
+ std::function<bool(const IoHash& ValueContentId, IoBuffer& OutBuffer, std::string& OutReason)>&& BlobFetchFn)
+ {
+ int64_t TotalBytes = 0ull;
+ double TotalElapsedSeconds = 0.0;
+
+ std::string_view BlobStoreNamespace = GetActualBlobStoreNamespace(Session, Namespace);
+ const auto PutBlobs = [&](std::span<IoHash> ValueContentIds, std::string& OutReason) -> bool {
+ for (const IoHash& ValueContentId : ValueContentIds)
+ {
+ IoBuffer BlobBuffer;
+ if (!BlobFetchFn(ValueContentId, BlobBuffer, OutReason))
+ {
+ return false;
+ }
+
+ CloudCacheResult BlobResult;
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !BlobResult.Success; Attempt++)
+ {
+ BlobResult = Session.PutCompressedBlob(BlobStoreNamespace, ValueContentId, BlobBuffer);
+ }
+
+ m_Status.SetFromErrorCode(BlobResult.ErrorCode, BlobResult.Reason);
+
+ if (!BlobResult.Success)
+ {
+ OutReason = fmt::format("upload value '{}' FAILED, reason '{}'", ValueContentId, BlobResult.Reason);
+ return false;
+ }
+
+ TotalBytes += BlobResult.Bytes;
+ TotalElapsedSeconds += BlobResult.ElapsedSeconds;
+ }
+
+ return true;
+ };
+
+ PutRefResult RefResult;
+ for (int32_t Attempt = 0; Attempt < MaxAttempts && !RefResult.Success; Attempt++)
+ {
+ RefResult = Session.PutRef(BlobStoreNamespace, Key.Bucket, Key.Hash, ObjectBuffer, ZenContentType::kCbObject);
+ }
+
+ m_Status.SetFromErrorCode(RefResult.ErrorCode, RefResult.Reason);
+
+ if (!RefResult.Success)
+ {
+ return {.Reason = fmt::format("upload cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, RefResult.Reason),
+ .Success = false};
+ }
+
+ TotalBytes += RefResult.Bytes;
+ TotalElapsedSeconds += RefResult.ElapsedSeconds;
+
+ std::string Reason;
+ if (!PutBlobs(RefResult.Needs, Reason))
+ {
+ return {.Reason = std::move(Reason), .Success = false};
+ }
+
+ const IoHash RefHash = IoHash::HashBuffer(ObjectBuffer);
+ FinalizeRefResult FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);
+
+ m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);
+
+ if (!FinalizeResult.Success)
+ {
+ return {
+ .Reason = fmt::format("finalize cache record '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
+ .Success = false};
+ }
+
+ if (!FinalizeResult.Needs.empty())
+ {
+ if (!PutBlobs(FinalizeResult.Needs, Reason))
+ {
+ return {.Reason = std::move(Reason), .Success = false};
+ }
+
+ FinalizeResult = Session.FinalizeRef(BlobStoreNamespace, Key.Bucket, Key.Hash, RefHash);
+
+ m_Status.SetFromErrorCode(FinalizeResult.ErrorCode, FinalizeResult.Reason);
+
+ if (!FinalizeResult.Success)
+ {
+ return {.Reason = fmt::format("finalize '{}/{}' FAILED, reason '{}'", Key.Bucket, Key.Hash, FinalizeResult.Reason),
+ .Success = false};
+ }
+
+ if (!FinalizeResult.Needs.empty())
+ {
+ ExtendableStringBuilder<256> Sb;
+ for (const IoHash& MissingHash : FinalizeResult.Needs)
+ {
+ Sb << MissingHash.ToHexString() << ",";
+ }
+
+ return {
+ .Reason = fmt::format("finalize '{}/{}' FAILED, still needs value(s) '{}'", Key.Bucket, Key.Hash, Sb.ToString()),
+ .Success = false};
+ }
+ }
+
+ TotalBytes += FinalizeResult.Bytes;
+ TotalElapsedSeconds += FinalizeResult.ElapsedSeconds;
+
+ return {.Bytes = TotalBytes, .ElapsedSeconds = TotalElapsedSeconds, .Success = true};
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ AuthMgr& m_AuthMgr;
+ spdlog::logger& m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ RefPtr<CloudCacheClient> m_Client;
+ };
+
+ class ZenUpstreamEndpoint final : public UpstreamEndpoint
+ {
+ struct ZenEndpoint
+ {
+ std::string Url;
+ std::string Reason;
+ double Latency{};
+ bool Ok = false;
+
+ bool operator<(const ZenEndpoint& RHS) const { return Ok && RHS.Ok ? Latency < RHS.Latency : Ok; }
+ };
+
+ public:
+ ZenUpstreamEndpoint(const ZenStructuredCacheClientOptions& Options)
+ : m_Log(zen::logging::Get("upstream"))
+ , m_ConnectTimeout(Options.ConnectTimeout)
+ , m_Timeout(Options.Timeout)
+ {
+ ZEN_ASSERT(!Options.Name.empty());
+ m_Info.Name = Options.Name;
+
+ for (const auto& Url : Options.Urls)
+ {
+ m_Endpoints.push_back({.Url = Url});
+ }
+ }
+
+ ~ZenUpstreamEndpoint() = default;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const override { return m_Info; }
+
+ virtual UpstreamEndpointStatus Initialize() override
+ {
+ try
+ {
+ if (m_Status.EndpointState() == UpstreamEndpointState::kOk)
+ {
+ return {.State = UpstreamEndpointState::kOk};
+ }
+
+ const ZenEndpoint& Ep = GetEndpoint();
+
+ if (m_Info.Url != Ep.Url)
+ {
+ ZEN_INFO("Setting Zen upstream URL to '{}'", Ep.Url);
+ m_Info.Url = Ep.Url;
+ }
+
+ if (Ep.Ok)
+ {
+ RwLock::ExclusiveLockScope _(m_ClientLock);
+ m_Client = new ZenStructuredCacheClient({.Url = m_Info.Url, .ConnectTimeout = m_ConnectTimeout, .Timeout = m_Timeout});
+ m_Status.Set(UpstreamEndpointState::kOk);
+ }
+ else
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Ep.Reason);
+ }
+
+ return m_Status.EndpointStatus();
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = Err.what(), .State = GetState()};
+ }
+ }
+
+ virtual UpstreamEndpointState GetState() override { return m_Status.EndpointState(); }
+
+ virtual UpstreamEndpointStatus GetStatus() override { return m_Status.EndpointStatus(); }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetSingleCacheRecord");
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ const ZenCacheResult Result = Session.GetCacheRecord(Namespace, CacheKey.Bucket, CacheKey.Hash, Type);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheRecords");
+ ZEN_ASSERT(Requests.size() > 0);
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheRecords"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = Requests[0]->Policy.GetRecordPolicy();
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy);
+
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("Requests"sv);
+ for (CacheKeyRequest* Request : Requests)
+ {
+ BatchRequest.BeginObject();
+ {
+ const CacheKey& Key = Request->Key;
+ BatchRequest.BeginObject("Key"sv);
+ {
+ BatchRequest << "Bucket"sv << Key.Bucket;
+ BatchRequest << "Hash"sv << Key.Hash;
+ }
+ BatchRequest.EndObject();
+ if (!Request->Policy.IsUniform() || Request->Policy.GetRecordPolicy() != DefaultPolicy)
+ {
+ BatchRequest.SetName("Policy"sv);
+ Request->Policy.Save(BatchRequest);
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (Results.Num() != Requests.size())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheRecords invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t Index = 0; CbFieldView Record : Results)
+ {
+ CacheKeyRequest* Request = Requests[Index++];
+ OnComplete({.Request = *Request,
+ .Record = Record.AsObjectView(),
+ .Package = BatchResponse,
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheRecords invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheKeyRequest* Request : Requests)
+ {
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunk");
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ const ZenCacheResult Result = Session.GetCacheChunk(Namespace, CacheKey.Bucket, CacheKey.Hash, ValueContentId);
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.ErrorCode == 0)
+ {
+ return {.Status = {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = Result.Success},
+ .Value = Result.Response,
+ .Source = &m_Info};
+ }
+ else
+ {
+ return {.Status = {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}}};
+ }
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Status = {.Error{.ErrorCode = -1, .Reason = Err.what()}}};
+ }
+ }
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheValues");
+ ZEN_ASSERT(!CacheValueRequests.empty());
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheValues"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = CacheValueRequests[0]->Policy;
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("Requests"sv);
+ {
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ const CacheValueRequest& Request = *RequestPtr;
+
+ BatchRequest.BeginObject();
+ {
+ BatchRequest.BeginObject("Key"sv);
+ BatchRequest << "Bucket"sv << Request.Key.Bucket;
+ BatchRequest << "Hash"sv << Request.Key.Hash;
+ BatchRequest.EndObject();
+ if (Request.Policy != DefaultPolicy)
+ {
+ BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (CacheValueRequests.size() != Results.Num())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheValues invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
+ {
+ CacheValueRequest& Request = *CacheValueRequests[RequestIndex++];
+ CbObjectView ChunkObject = ChunkField.AsObjectView();
+ IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
+ IoBuffer Payload;
+ uint64_t RawSize = 0;
+ if (RawHash != IoHash::Zero)
+ {
+ bool Success = false;
+ const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
+ if (Attachment)
+ {
+ if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ {
+ Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ RawSize = Compressed.DecodeRawSize();
+ Success = true;
+ }
+ }
+ if (!Success)
+ {
+ CbFieldView RawSizeField = ChunkObject["RawSize"sv];
+ RawSize = RawSizeField.AsUInt64();
+ Success = !RawSizeField.HasError();
+ }
+ if (!Success)
+ {
+ RawHash = IoHash::Zero;
+ }
+ }
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = std::move(Payload),
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheValues invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheValueRequest* RequestPtr : CacheValueRequests)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::GetCacheChunks");
+ ZEN_ASSERT(!CacheChunkRequests.empty());
+
+ CbObjectWriter BatchRequest;
+ BatchRequest << "Method"sv
+ << "GetCacheChunks"sv;
+ BatchRequest << "Accept"sv << kCbPkgMagic;
+
+ BatchRequest.BeginObject("Params"sv);
+ {
+ CachePolicy DefaultPolicy = CacheChunkRequests[0]->Policy;
+ BatchRequest << "DefaultPolicy"sv << WriteToString<128>(DefaultPolicy).ToView();
+ BatchRequest << "Namespace"sv << Namespace;
+
+ BatchRequest.BeginArray("ChunkRequests"sv);
+ {
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ const CacheChunkRequest& Request = *RequestPtr;
+
+ BatchRequest.BeginObject();
+ {
+ BatchRequest.BeginObject("Key"sv);
+ BatchRequest << "Bucket"sv << Request.Key.Bucket;
+ BatchRequest << "Hash"sv << Request.Key.Hash;
+ BatchRequest.EndObject();
+ if (Request.ValueId)
+ {
+ BatchRequest.AddObjectId("ValueId"sv, Request.ValueId);
+ }
+ if (Request.ChunkId != Request.ChunkId.Zero)
+ {
+ BatchRequest << "ChunkId"sv << Request.ChunkId;
+ }
+ if (Request.RawOffset != 0)
+ {
+ BatchRequest << "RawOffset"sv << Request.RawOffset;
+ }
+ if (Request.RawSize != UINT64_MAX)
+ {
+ BatchRequest << "RawSize"sv << Request.RawSize;
+ }
+ if (Request.Policy != DefaultPolicy)
+ {
+ BatchRequest << "Policy"sv << WriteToString<128>(Request.Policy).ToView();
+ }
+ }
+ BatchRequest.EndObject();
+ }
+ }
+ BatchRequest.EndArray();
+ }
+ BatchRequest.EndObject();
+
+ ZenCacheResult Result;
+
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ Result = Session.InvokeRpc(BatchRequest.Save());
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ if (Result.Success)
+ {
+ CbPackage BatchResponse;
+ if (ParsePackageMessageWithLegacyFallback(Result.Response, BatchResponse))
+ {
+ CbArrayView Results = BatchResponse.GetObject()["Result"sv].AsArrayView();
+ if (CacheChunkRequests.size() != Results.Num())
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheChunks invalid number of Response results from Upstream.");
+ }
+ else
+ {
+ for (size_t RequestIndex = 0; CbFieldView ChunkField : Results)
+ {
+ CacheChunkRequest& Request = *CacheChunkRequests[RequestIndex++];
+ CbObjectView ChunkObject = ChunkField.AsObjectView();
+ IoHash RawHash = ChunkObject["RawHash"sv].AsHash();
+ IoBuffer Payload;
+ uint64_t RawSize = 0;
+ if (RawHash != IoHash::Zero)
+ {
+ bool Success = false;
+ const CbAttachment* Attachment = BatchResponse.FindAttachment(RawHash);
+ if (Attachment)
+ {
+ if (const CompressedBuffer& Compressed = Attachment->AsCompressedBinary())
+ {
+ Payload = Compressed.GetCompressed().Flatten().AsIoBuffer();
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+ RawSize = Compressed.DecodeRawSize();
+ Success = true;
+ }
+ }
+ if (!Success)
+ {
+ CbFieldView RawSizeField = ChunkObject["RawSize"sv];
+ RawSize = RawSizeField.AsUInt64();
+ Success = !RawSizeField.HasError();
+ }
+ if (!Success)
+ {
+ RawHash = IoHash::Zero;
+ }
+ }
+ OnComplete({.Request = Request,
+ .RawHash = RawHash,
+ .RawSize = RawSize,
+ .Value = std::move(Payload),
+ .ElapsedSeconds = Result.ElapsedSeconds,
+ .Source = &m_Info});
+ }
+
+ return {.Bytes = Result.Bytes, .ElapsedSeconds = Result.ElapsedSeconds, .Success = true};
+ }
+ }
+ else
+ {
+ ZEN_WARN("Upstream::Zen::GetCacheChunks invalid Response from Upstream.");
+ }
+ }
+
+ for (CacheChunkRequest* RequestPtr : CacheChunkRequests)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+
+ return {.Error{.ErrorCode = Result.ErrorCode, .Reason = std::move(Result.Reason)}};
+ }
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Values) override
+ {
+ ZEN_TRACE_CPU("Upstream::Zen::PutCacheRecord");
+
+ ZEN_ASSERT(CacheRecord.ValueContentIds.size() == Values.size());
+ const int32_t MaxAttempts = 3;
+
+ try
+ {
+ ZenStructuredCacheSession Session(GetClientRef());
+ ZenCacheResult Result;
+ int64_t TotalBytes = 0ull;
+ double TotalElapsedSeconds = 0.0;
+
+ if (CacheRecord.Type == ZenContentType::kCbPackage)
+ {
+ CbPackage Package;
+ Package.SetObject(CbObject(SharedBuffer(RecordValue)));
+
+ for (const IoBuffer& Value : Values)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer AttachmentBuffer = CompressedBuffer::FromCompressed(SharedBuffer(Value), RawHash, RawSize))
+ {
+ Package.AddAttachment(CbAttachment(AttachmentBuffer, RawHash));
+ }
+ else
+ {
+ return {.Reason = std::string("Invalid value buffer"), .Success = false};
+ }
+ }
+
+ BinaryWriter MemStream;
+ Package.Save(MemStream);
+ IoBuffer PackagePayload(IoBuffer::Wrap, MemStream.Data(), MemStream.Size());
+
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheRecord(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ PackagePayload,
+ CacheRecord.Type);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes = Result.Bytes;
+ TotalElapsedSeconds = Result.ElapsedSeconds;
+ }
+ else if (CacheRecord.Type == ZenContentType::kCompressedBinary)
+ {
+ IoHash RawHash;
+ uint64_t RawSize;
+ CompressedBuffer Compressed = CompressedBuffer::FromCompressed(SharedBuffer(RecordValue), RawHash, RawSize);
+ if (!Compressed)
+ {
+ return {.Reason = std::string("Invalid value compressed buffer"), .Success = false};
+ }
+
+ CbPackage BatchPackage;
+ CbObjectWriter BatchWriter;
+ BatchWriter << "Method"sv
+ << "PutCacheValues"sv;
+ BatchWriter << "Accept"sv << kCbPkgMagic;
+
+ BatchWriter.BeginObject("Params"sv);
+ {
+ // DefaultPolicy unspecified and expected to be Default
+
+ BatchWriter << "Namespace"sv << CacheRecord.Namespace;
+
+ BatchWriter.BeginArray("Requests"sv);
+ {
+ BatchWriter.BeginObject();
+ {
+ const CacheKey& Key = CacheRecord.Key;
+ BatchWriter.BeginObject("Key"sv);
+ {
+ BatchWriter << "Bucket"sv << Key.Bucket;
+ BatchWriter << "Hash"sv << Key.Hash;
+ }
+ BatchWriter.EndObject();
+ // Policy unspecified and expected to be Default
+ BatchWriter.AddBinaryAttachment("RawHash"sv, RawHash);
+ BatchPackage.AddAttachment(CbAttachment(Compressed, RawHash));
+ }
+ BatchWriter.EndObject();
+ }
+ BatchWriter.EndArray();
+ }
+ BatchWriter.EndObject();
+ BatchPackage.SetObject(BatchWriter.Save());
+
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.InvokeRpc(BatchPackage);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+ }
+ else
+ {
+ for (size_t Idx = 0, Count = Values.size(); Idx < Count; Idx++)
+ {
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheValue(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ CacheRecord.ValueContentIds[Idx],
+ Values[Idx]);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+
+ if (!Result.Success)
+ {
+ return {.Reason = "Failed to upload value",
+ .Bytes = TotalBytes,
+ .ElapsedSeconds = TotalElapsedSeconds,
+ .Success = false};
+ }
+ }
+
+ Result.Success = false;
+ for (uint32_t Attempt = 0; Attempt < MaxAttempts && !Result.Success; Attempt++)
+ {
+ Result = Session.PutCacheRecord(CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ RecordValue,
+ CacheRecord.Type);
+ }
+
+ m_Status.SetFromErrorCode(Result.ErrorCode, Result.Reason);
+
+ TotalBytes += Result.Bytes;
+ TotalElapsedSeconds += Result.ElapsedSeconds;
+ }
+
+ return {.Reason = std::move(Result.Reason),
+ .Bytes = TotalBytes,
+ .ElapsedSeconds = TotalElapsedSeconds,
+ .Success = Result.Success};
+ }
+ catch (std::exception& Err)
+ {
+ m_Status.Set(UpstreamEndpointState::kError, Err.what());
+
+ return {.Reason = std::string(Err.what()), .Success = false};
+ }
+ }
+
+ virtual UpstreamEndpointStats& Stats() override { return m_Stats; }
+
+ private:
+ Ref<ZenStructuredCacheClient> GetClientRef()
+ {
+ // m_Client can be modified at any time by a different thread.
+ // Make sure we safely bump the refcount inside a scope lock
+ RwLock::SharedLockScope _(m_ClientLock);
+ ZEN_ASSERT(m_Client);
+ Ref<ZenStructuredCacheClient> ClientRef(m_Client);
+ _.ReleaseNow();
+ return ClientRef;
+ }
+
+ const ZenEndpoint& GetEndpoint()
+ {
+ for (ZenEndpoint& Ep : m_Endpoints)
+ {
+ Ref<ZenStructuredCacheClient> Client(
+ new ZenStructuredCacheClient({.Url = Ep.Url, .ConnectTimeout = std::chrono::milliseconds(1000)}));
+ ZenStructuredCacheSession Session(std::move(Client));
+ const int32_t SampleCount = 2;
+
+ Ep.Ok = false;
+ Ep.Latency = {};
+
+ for (int32_t Sample = 0; Sample < SampleCount; ++Sample)
+ {
+ ZenCacheResult Result = Session.CheckHealth();
+ Ep.Ok = Result.Success;
+ Ep.Reason = std::move(Result.Reason);
+ Ep.Latency += Result.ElapsedSeconds;
+ }
+ Ep.Latency /= double(SampleCount);
+ }
+
+ std::sort(std::begin(m_Endpoints), std::end(m_Endpoints));
+
+ for (const auto& Ep : m_Endpoints)
+ {
+ ZEN_INFO("ping 'Zen' endpoint '{}' latency '{:.3}s' {}", Ep.Url, Ep.Latency, Ep.Ok ? "OK" : Ep.Reason);
+ }
+
+ return m_Endpoints.front();
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ UpstreamEndpointInfo m_Info;
+ UpstreamStatus m_Status;
+ UpstreamEndpointStats m_Stats;
+ std::vector<ZenEndpoint> m_Endpoints;
+ std::chrono::milliseconds m_ConnectTimeout;
+ std::chrono::milliseconds m_Timeout;
+ RwLock m_ClientLock;
+ RefPtr<ZenStructuredCacheClient> m_Client;
+ };
+
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+class UpstreamCacheImpl final : public UpstreamCache
+{
+public:
+ UpstreamCacheImpl(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
+ : m_Log(logging::Get("upstream"))
+ , m_Options(Options)
+ , m_CacheStore(CacheStore)
+ , m_CidStore(CidStore)
+ {
+ }
+
+ virtual ~UpstreamCacheImpl() { Shutdown(); }
+
+ virtual void Initialize() override
+ {
+ for (uint32_t Idx = 0; Idx < m_Options.ThreadCount; Idx++)
+ {
+ m_UpstreamThreads.emplace_back(&UpstreamCacheImpl::ProcessUpstreamQueue, this);
+ }
+
+ m_EndpointMonitorThread = std::thread(&UpstreamCacheImpl::MonitorEndpoints, this);
+ m_RunState.IsRunning = true;
+ }
+
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) override
+ {
+ const UpstreamEndpointStatus Status = Endpoint->Initialize();
+ const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();
+
+ if (Status.State == UpstreamEndpointState::kOk)
+ {
+ ZEN_INFO("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
+ }
+ else
+ {
+ ZEN_WARN("register endpoint '{} - {}' {}", Info.Name, Info.Url, ToString(Status.State));
+ }
+
+ // Register endpoint even if it fails, the health monitor thread will probe failing endpoint(s)
+ std::unique_lock<std::shared_mutex> _(m_EndpointsMutex);
+ m_Endpoints.emplace_back(std::move(Endpoint));
+ }
+
+ virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) override
+ {
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Ep : m_Endpoints)
+ {
+ if (!Fn(*Ep))
+ {
+ break;
+ }
+ }
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) override
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheRecord");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+ GetUpstreamCacheSingleResult Result = Endpoint->GetCacheRecord(Namespace, CacheKey, Type);
+ Scope.Stop();
+
+ Stats.CacheGetCount.Increment(1);
+ Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);
+
+ if (Result.Status.Success)
+ {
+ Stats.CacheHitCount.Increment(1);
+
+ return Result;
+ }
+
+ if (Result.Status.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache record FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Status.Error.Reason,
+ Result.Status.Error.ErrorCode);
+ }
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheRecords");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheKeyRequest*> RemainingKeys(Requests.begin(), Requests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheKeyRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheRecords(Namespace, RemainingKeys, [&](CacheRecordGetCompleteParams&& Params) {
+ if (Params.Record)
+ {
+ OnComplete(std::forward<CacheRecordGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache record(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheKeyRequest* Request : RemainingKeys)
+ {
+ OnComplete({.Request = *Request, .Record = CbObjectView(), .Package = CbPackage()});
+ }
+ }
+
+ virtual void GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunks");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheChunkRequest*> RemainingKeys(CacheChunkRequests.begin(), CacheChunkRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheChunkRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheChunks(Namespace, RemainingKeys, [&](CacheChunkGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheChunkGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunks(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheChunkRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) override
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheChunk");
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ metrics::OperationTiming::Scope Scope(Stats.CacheGetRequestTiming);
+ GetUpstreamCacheSingleResult Result = Endpoint->GetCacheChunk(Namespace, CacheKey, ValueContentId);
+ Scope.Stop();
+
+ Stats.CacheGetCount.Increment(1);
+ Stats.CacheGetTotalBytes.Increment(Result.Status.Bytes);
+
+ if (Result.Status.Success)
+ {
+ Stats.CacheHitCount.Increment(1);
+
+ return Result;
+ }
+
+ if (Result.Status.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache chunk FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Status.Error.Reason,
+ Result.Status.Error.ErrorCode);
+ }
+ }
+ }
+
+ return {};
+ }
+
+ virtual void GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) override final
+ {
+ ZEN_TRACE_CPU("Upstream::GetCacheValues");
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ std::vector<CacheValueRequest*> RemainingKeys(CacheValueRequests.begin(), CacheValueRequests.end());
+
+ if (m_Options.ReadUpstream)
+ {
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (RemainingKeys.empty())
+ {
+ break;
+ }
+
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ std::vector<CacheValueRequest*> Missing;
+ GetUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Endpoint->Stats().CacheGetRequestTiming);
+
+ Result = Endpoint->GetCacheValues(Namespace, RemainingKeys, [&](CacheValueGetCompleteParams&& Params) {
+ if (Params.RawHash != Params.RawHash.Zero)
+ {
+ OnComplete(std::forward<CacheValueGetCompleteParams>(Params));
+
+ Stats.CacheHitCount.Increment(1);
+ }
+ else
+ {
+ Missing.push_back(&Params.Request);
+ }
+ });
+ }
+
+ Stats.CacheGetCount.Increment(int64_t(RemainingKeys.size()));
+ Stats.CacheGetTotalBytes.Increment(Result.Bytes);
+
+ if (Result.Error)
+ {
+ Stats.CacheErrorCount.Increment(1);
+
+ ZEN_WARN("get cache values(s) (rpc) FAILED, endpoint '{}', reason '{}', error code '{}'",
+ Endpoint->GetEndpointInfo().Url,
+ Result.Error.Reason,
+ Result.Error.ErrorCode);
+ }
+
+ RemainingKeys = std::move(Missing);
+ }
+ }
+
+ const UpstreamEndpointInfo Info;
+ for (CacheValueRequest* RequestPtr : RemainingKeys)
+ {
+ OnComplete({.Request = *RequestPtr, .RawHash = IoHash::Zero, .RawSize = 0, .Value = IoBuffer()});
+ }
+ }
+
+ virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) override
+ {
+ if (m_RunState.IsRunning && m_Options.WriteUpstream && m_Endpoints.size() > 0)
+ {
+ if (!m_UpstreamThreads.empty())
+ {
+ m_UpstreamQueue.Enqueue(std::move(CacheRecord));
+ }
+ else
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ }
+ }
+
+ virtual void GetStatus(CbObjectWriter& Status) override
+ {
+ Status << "reading" << m_Options.ReadUpstream;
+ Status << "writing" << m_Options.WriteUpstream;
+ Status << "worker_threads" << m_Options.ThreadCount;
+ Status << "queue_count" << m_UpstreamQueue.Size();
+
+ Status.BeginArray("endpoints");
+ for (const auto& Ep : m_Endpoints)
+ {
+ const UpstreamEndpointInfo& EpInfo = Ep->GetEndpointInfo();
+ const UpstreamEndpointStatus EpStatus = Ep->GetStatus();
+ UpstreamEndpointStats& EpStats = Ep->Stats();
+
+ Status.BeginObject();
+ Status << "name" << EpInfo.Name;
+ Status << "url" << EpInfo.Url;
+ Status << "state" << ToString(EpStatus.State);
+ Status << "reason" << EpStatus.Reason;
+
+ Status.BeginObject("cache"sv);
+ {
+ const int64_t GetCount = EpStats.CacheGetCount.Value();
+ const int64_t HitCount = EpStats.CacheHitCount.Value();
+ const int64_t ErrorCount = EpStats.CacheErrorCount.Value();
+ const double HitRatio = GetCount > 0 ? double(HitCount) / double(GetCount) : 0.0;
+ const double ErrorRatio = GetCount > 0 ? double(ErrorCount) / double(GetCount) : 0.0;
+
+ metrics::EmitSnapshot("get_requests"sv, EpStats.CacheGetRequestTiming, Status);
+ Status << "get_bytes" << EpStats.CacheGetTotalBytes.Value();
+ Status << "get_count" << GetCount;
+ Status << "hit_count" << HitCount;
+ Status << "hit_ratio" << HitRatio;
+ Status << "error_count" << ErrorCount;
+ Status << "error_ratio" << ErrorRatio;
+ metrics::EmitSnapshot("put_requests"sv, EpStats.CachePutRequestTiming, Status);
+ Status << "put_bytes" << EpStats.CachePutTotalBytes.Value();
+ }
+ Status.EndObject();
+
+ Status.EndObject();
+ }
+ Status.EndArray();
+ }
+
+private:
+ void ProcessCacheRecord(UpstreamCacheRecord CacheRecord)
+ {
+ ZEN_TRACE_CPU("Upstream::ProcessCacheRecord");
+
+ ZenCacheValue CacheValue;
+ std::vector<IoBuffer> Payloads;
+
+ if (!m_CacheStore.Get(CacheRecord.Namespace, CacheRecord.Key.Bucket, CacheRecord.Key.Hash, CacheValue))
+ {
+ ZEN_WARN("process upstream FAILED, '{}/{}/{}', cache record doesn't exist",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash);
+ return;
+ }
+
+ for (const IoHash& ValueContentId : CacheRecord.ValueContentIds)
+ {
+ if (IoBuffer Payload = m_CidStore.FindChunkByCid(ValueContentId))
+ {
+ Payloads.push_back(Payload);
+ }
+ else
+ {
+ ZEN_WARN("process upstream FAILED, '{}/{}/{}/{}', ValueContentId doesn't exist in CAS",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ ValueContentId);
+ return;
+ }
+ }
+
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ if (Endpoint->GetState() != UpstreamEndpointState::kOk)
+ {
+ continue;
+ }
+
+ UpstreamEndpointStats& Stats = Endpoint->Stats();
+ PutUpstreamCacheResult Result;
+ {
+ metrics::OperationTiming::Scope Scope(Stats.CachePutRequestTiming);
+ Result = Endpoint->PutCacheRecord(CacheRecord, CacheValue.Value, std::span(Payloads));
+ }
+
+ Stats.CachePutTotalBytes.Increment(Result.Bytes);
+
+ if (!Result.Success)
+ {
+ ZEN_WARN("upload cache record '{}/{}/{}' FAILED, endpoint '{}', reason '{}'",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ Endpoint->GetEndpointInfo().Url,
+ Result.Reason);
+ }
+ }
+ }
+
+ void ProcessUpstreamQueue()
+ {
+ for (;;)
+ {
+ UpstreamCacheRecord CacheRecord;
+ if (m_UpstreamQueue.WaitAndDequeue(CacheRecord))
+ {
+ try
+ {
+ ProcessCacheRecord(std::move(CacheRecord));
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("upload cache record '{}/{}/{}' FAILED, reason '{}'",
+ CacheRecord.Namespace,
+ CacheRecord.Key.Bucket,
+ CacheRecord.Key.Hash,
+ Err.what());
+ }
+ }
+
+ if (!m_RunState.IsRunning)
+ {
+ break;
+ }
+ }
+ }
+
+ void MonitorEndpoints()
+ {
+ for (;;)
+ {
+ {
+ std::unique_lock lk(m_RunState.Mutex);
+ if (m_RunState.ExitSignal.wait_for(lk, m_Options.HealthCheckInterval, [this]() { return !m_RunState.IsRunning.load(); }))
+ {
+ break;
+ }
+ }
+
+ try
+ {
+ std::vector<UpstreamEndpoint*> Endpoints;
+
+ {
+ std::shared_lock<std::shared_mutex> _(m_EndpointsMutex);
+
+ for (auto& Endpoint : m_Endpoints)
+ {
+ UpstreamEndpointState State = Endpoint->GetState();
+ if (State == UpstreamEndpointState::kError)
+ {
+ Endpoints.push_back(Endpoint.get());
+ ZEN_WARN("HEALTH - endpoint '{} - {}' is in error state '{}'",
+ Endpoint->GetEndpointInfo().Name,
+ Endpoint->GetEndpointInfo().Url,
+ Endpoint->GetStatus().Reason);
+ }
+ if (State == UpstreamEndpointState::kUnauthorized)
+ {
+ Endpoints.push_back(Endpoint.get());
+ }
+ }
+ }
+
+ for (auto& Endpoint : Endpoints)
+ {
+ const UpstreamEndpointInfo& Info = Endpoint->GetEndpointInfo();
+ const UpstreamEndpointStatus Status = Endpoint->Initialize();
+
+ if (Status.State == UpstreamEndpointState::kOk)
+ {
+ ZEN_INFO("HEALTH - endpoint '{} - {}' Ok", Info.Name, Info.Url);
+ }
+ else
+ {
+ const std::string Reason = Status.Reason.empty() ? "" : fmt::format(", reason '{}'", Status.Reason);
+ ZEN_WARN("HEALTH - endpoint '{} - {}' {} {}", Info.Name, Info.Url, ToString(Status.State), Reason);
+ }
+ }
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("check endpoint(s) health FAILED, reason '{}'", Err.what());
+ }
+ }
+ }
+
+ void Shutdown()
+ {
+ if (m_RunState.Stop())
+ {
+ m_UpstreamQueue.CompleteAdding();
+ for (std::thread& Thread : m_UpstreamThreads)
+ {
+ Thread.join();
+ }
+
+ m_EndpointMonitorThread.join();
+ m_UpstreamThreads.clear();
+ m_Endpoints.clear();
+ }
+ }
+
+ spdlog::logger& Log() { return m_Log; }
+
+ using UpstreamQueue = BlockingQueue<UpstreamCacheRecord>;
+
+ struct RunState
+ {
+ std::mutex Mutex;
+ std::condition_variable ExitSignal;
+ std::atomic_bool IsRunning{false};
+
+ bool Stop()
+ {
+ bool Stopped = false;
+ {
+ std::lock_guard _(Mutex);
+ Stopped = IsRunning.exchange(false);
+ }
+ if (Stopped)
+ {
+ ExitSignal.notify_all();
+ }
+ return Stopped;
+ }
+ };
+
+ spdlog::logger& m_Log;
+ UpstreamCacheOptions m_Options;
+ ZenCacheStore& m_CacheStore;
+ CidStore& m_CidStore;
+ UpstreamQueue m_UpstreamQueue;
+ std::shared_mutex m_EndpointsMutex;
+ std::vector<std::unique_ptr<UpstreamEndpoint>> m_Endpoints;
+ std::vector<std::thread> m_UpstreamThreads;
+ std::thread m_EndpointMonitorThread;
+ RunState m_RunState;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+std::unique_ptr<UpstreamEndpoint>
+UpstreamEndpoint::CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options)
+{
+ return std::make_unique<detail::ZenUpstreamEndpoint>(Options);
+}
+
+std::unique_ptr<UpstreamEndpoint>
+UpstreamEndpoint::CreateJupiterEndpoint(const CloudCacheClientOptions& Options, const UpstreamAuthConfig& AuthConfig, AuthMgr& Mgr)
+{
+ return std::make_unique<detail::JupiterUpstreamEndpoint>(Options, AuthConfig, Mgr);
+}
+
+std::unique_ptr<UpstreamCache>
+UpstreamCache::Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore)
+{
+ return std::make_unique<UpstreamCacheImpl>(Options, CacheStore, CidStore);
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamcache.h b/src/zenserver/upstream/upstreamcache.h
new file mode 100644
index 000000000..695c06b32
--- /dev/null
+++ b/src/zenserver/upstream/upstreamcache.h
@@ -0,0 +1,252 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/compress.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/stats.h>
+#include <zencore/zencore.h>
+#include <zenutil/cache/cache.h>
+
+#include <atomic>
+#include <chrono>
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace zen {
+
+class CbObjectView;
+class AuthMgr;
+class CbObjectView;
+class CbPackage;
+class CbObjectWriter;
+class CidStore;
+class ZenCacheStore;
+struct CloudCacheClientOptions;
+class CloudCacheTokenProvider;
+struct ZenStructuredCacheClientOptions;
+
+struct UpstreamCacheRecord
+{
+ ZenContentType Type = ZenContentType::kBinary;
+ std::string Namespace;
+ CacheKey Key;
+ std::vector<IoHash> ValueContentIds;
+};
+
+struct UpstreamCacheOptions
+{
+ std::chrono::seconds HealthCheckInterval{5};
+ uint32_t ThreadCount = 4;
+ bool ReadUpstream = true;
+ bool WriteUpstream = true;
+};
+
+struct UpstreamError
+{
+ int32_t ErrorCode{};
+ std::string Reason{};
+
+ explicit operator bool() const { return ErrorCode != 0; }
+};
+
+struct UpstreamEndpointInfo
+{
+ std::string Name;
+ std::string Url;
+};
+
+struct GetUpstreamCacheResult
+{
+ UpstreamError Error{};
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ bool Success = false;
+};
+
+struct GetUpstreamCacheSingleResult
+{
+ GetUpstreamCacheResult Status;
+ IoBuffer Value;
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+struct PutUpstreamCacheResult
+{
+ std::string Reason;
+ int64_t Bytes{};
+ double ElapsedSeconds{};
+ bool Success = false;
+};
+
+struct CacheRecordGetCompleteParams
+{
+ CacheKeyRequest& Request;
+ const CbObjectView& Record;
+ const CbPackage& Package;
+ double ElapsedSeconds{};
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+using OnCacheRecordGetComplete = std::function<void(CacheRecordGetCompleteParams&&)>;
+
+struct CacheValueGetCompleteParams
+{
+ CacheValueRequest& Request;
+ IoHash RawHash;
+ uint64_t RawSize;
+ IoBuffer Value;
+ double ElapsedSeconds{};
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+using OnCacheValueGetComplete = std::function<void(CacheValueGetCompleteParams&&)>;
+
+struct CacheChunkGetCompleteParams
+{
+ CacheChunkRequest& Request;
+ IoHash RawHash;
+ uint64_t RawSize;
+ IoBuffer Value;
+ double ElapsedSeconds{};
+ const UpstreamEndpointInfo* Source = nullptr;
+};
+
+using OnCacheChunksGetComplete = std::function<void(CacheChunkGetCompleteParams&&)>;
+
+struct UpstreamEndpointStats
+{
+ metrics::OperationTiming CacheGetRequestTiming;
+ metrics::OperationTiming CachePutRequestTiming;
+ metrics::Counter CacheGetTotalBytes;
+ metrics::Counter CachePutTotalBytes;
+ metrics::Counter CacheGetCount;
+ metrics::Counter CacheHitCount;
+ metrics::Counter CacheErrorCount;
+};
+
+enum class UpstreamEndpointState : uint32_t
+{
+ kDisabled,
+ kUnauthorized,
+ kError,
+ kOk
+};
+
+inline std::string_view
+ToString(UpstreamEndpointState State)
+{
+ using namespace std::literals;
+
+ switch (State)
+ {
+ case UpstreamEndpointState::kDisabled:
+ return "Disabled"sv;
+ case UpstreamEndpointState::kUnauthorized:
+ return "Unauthorized"sv;
+ case UpstreamEndpointState::kError:
+ return "Error"sv;
+ case UpstreamEndpointState::kOk:
+ return "Ok"sv;
+ default:
+ return "Unknown"sv;
+ }
+}
+
+struct UpstreamAuthConfig
+{
+ std::string_view OAuthUrl;
+ std::string_view OAuthClientId;
+ std::string_view OAuthClientSecret;
+ std::string_view OpenIdProvider;
+ std::string_view AccessToken;
+};
+
+struct UpstreamEndpointStatus
+{
+ std::string Reason;
+ UpstreamEndpointState State;
+};
+
+/**
+ * The upstream endpoint is responsible for handling upload/downloading of cache records.
+ */
+class UpstreamEndpoint
+{
+public:
+ virtual ~UpstreamEndpoint() = default;
+
+ virtual UpstreamEndpointStatus Initialize() = 0;
+
+ virtual const UpstreamEndpointInfo& GetEndpointInfo() const = 0;
+
+ virtual UpstreamEndpointState GetState() = 0;
+ virtual UpstreamEndpointStatus GetStatus() = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0;
+ virtual GetUpstreamCacheResult GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) = 0;
+
+ virtual GetUpstreamCacheResult GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace, const CacheKey& CacheKey, const IoHash& PayloadId) = 0;
+ virtual GetUpstreamCacheResult GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) = 0;
+
+ virtual PutUpstreamCacheResult PutCacheRecord(const UpstreamCacheRecord& CacheRecord,
+ IoBuffer RecordValue,
+ std::span<IoBuffer const> Payloads) = 0;
+
+ virtual UpstreamEndpointStats& Stats() = 0;
+
+ static std::unique_ptr<UpstreamEndpoint> CreateZenEndpoint(const ZenStructuredCacheClientOptions& Options);
+
+ static std::unique_ptr<UpstreamEndpoint> CreateJupiterEndpoint(const CloudCacheClientOptions& Options,
+ const UpstreamAuthConfig& AuthConfig,
+ AuthMgr& Mgr);
+};
+
+/**
+ * Manages one or more upstream cache endpoints.
+ */
+class UpstreamCache
+{
+public:
+ virtual ~UpstreamCache() = default;
+
+ virtual void Initialize() = 0;
+
+ virtual void RegisterEndpoint(std::unique_ptr<UpstreamEndpoint> Endpoint) = 0;
+ virtual void IterateEndpoints(std::function<bool(UpstreamEndpoint&)>&& Fn) = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheRecord(std::string_view Namespace, const CacheKey& CacheKey, ZenContentType Type) = 0;
+ virtual void GetCacheRecords(std::string_view Namespace,
+ std::span<CacheKeyRequest*> Requests,
+ OnCacheRecordGetComplete&& OnComplete) = 0;
+
+ virtual void GetCacheValues(std::string_view Namespace,
+ std::span<CacheValueRequest*> CacheValueRequests,
+ OnCacheValueGetComplete&& OnComplete) = 0;
+
+ virtual GetUpstreamCacheSingleResult GetCacheChunk(std::string_view Namespace,
+ const CacheKey& CacheKey,
+ const IoHash& ValueContentId) = 0;
+ virtual void GetCacheChunks(std::string_view Namespace,
+ std::span<CacheChunkRequest*> CacheChunkRequests,
+ OnCacheChunksGetComplete&& OnComplete) = 0;
+
+ virtual void EnqueueUpstream(UpstreamCacheRecord CacheRecord) = 0;
+
+ virtual void GetStatus(CbObjectWriter& CbO) = 0;
+
+ static std::unique_ptr<UpstreamCache> Create(const UpstreamCacheOptions& Options, ZenCacheStore& CacheStore, CidStore& CidStore);
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamservice.cpp b/src/zenserver/upstream/upstreamservice.cpp
new file mode 100644
index 000000000..6db1357c5
--- /dev/null
+++ b/src/zenserver/upstream/upstreamservice.cpp
@@ -0,0 +1,56 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+#include <upstream/upstreamservice.h>
+
+#include <auth/authmgr.h>
+#include <upstream/upstreamcache.h>
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/string.h>
+
+namespace zen {
+
+using namespace std::literals;
+
+HttpUpstreamService::HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr) : m_Upstream(Upstream), m_AuthMgr(Mgr)
+{
+ m_Router.RegisterRoute(
+ "endpoints",
+ [this](HttpRouterRequest& Req) {
+ CbObjectWriter Writer;
+ Writer.BeginArray("Endpoints"sv);
+ m_Upstream.IterateEndpoints([&Writer](UpstreamEndpoint& Ep) {
+ UpstreamEndpointInfo Info = Ep.GetEndpointInfo();
+ UpstreamEndpointStatus Status = Ep.GetStatus();
+
+ Writer.BeginObject();
+ Writer << "Name"sv << Info.Name;
+ Writer << "Url"sv << Info.Url;
+ Writer << "State"sv << ToString(Status.State);
+ Writer << "Reason"sv << Status.Reason;
+ Writer.EndObject();
+
+ return true;
+ });
+ Writer.EndArray();
+ Req.ServerRequest().WriteResponse(HttpResponseCode::OK, Writer.Save());
+ },
+ HttpVerb::kGet);
+}
+
+HttpUpstreamService::~HttpUpstreamService()
+{
+}
+
+const char*
+HttpUpstreamService::BaseUri() const
+{
+ return "/upstream/";
+}
+
+void
+HttpUpstreamService::HandleRequest(zen::HttpServerRequest& Request)
+{
+ m_Router.HandleRequest(Request);
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/upstreamservice.h b/src/zenserver/upstream/upstreamservice.h
new file mode 100644
index 000000000..f1da03c8c
--- /dev/null
+++ b/src/zenserver/upstream/upstreamservice.h
@@ -0,0 +1,27 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenhttp/httpserver.h>
+
+namespace zen {
+
+class AuthMgr;
+class UpstreamCache;
+
+class HttpUpstreamService final : public zen::HttpService
+{
+public:
+ HttpUpstreamService(UpstreamCache& Upstream, AuthMgr& Mgr);
+ virtual ~HttpUpstreamService();
+
+ virtual const char* BaseUri() const override;
+ virtual void HandleRequest(zen::HttpServerRequest& Request) override;
+
+private:
+ UpstreamCache& m_Upstream;
+ AuthMgr& m_AuthMgr;
+ HttpRequestRouter m_Router;
+};
+
+} // namespace zen
diff --git a/src/zenserver/upstream/zen.cpp b/src/zenserver/upstream/zen.cpp
new file mode 100644
index 000000000..9e1212834
--- /dev/null
+++ b/src/zenserver/upstream/zen.cpp
@@ -0,0 +1,326 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zen.h"
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/fmtutils.h>
+#include <zencore/session.h>
+#include <zencore/stream.h>
+#include <zenhttp/httpcommon.h>
+#include <zenhttp/httpshared.h>
+
+#include "cache/structuredcachestore.h"
+#include "diag/formatters.h"
+#include "diag/logging.h"
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <cpr/cpr.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <xxhash.h>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+namespace detail {
+ struct ZenCacheSessionState
+ {
+ ZenCacheSessionState(ZenStructuredCacheClient& Client) : OwnerClient(Client) {}
+ ~ZenCacheSessionState() {}
+
+ void Reset(std::chrono::milliseconds ConnectTimeout, std::chrono::milliseconds Timeout)
+ {
+ Session.SetBody({});
+ Session.SetHeader({});
+ Session.SetConnectTimeout(ConnectTimeout);
+ Session.SetTimeout(Timeout);
+ }
+
+ cpr::Session& GetSession() { return Session; }
+
+ private:
+ ZenStructuredCacheClient& OwnerClient;
+ cpr::Session Session;
+ };
+
+} // namespace detail
+
+//////////////////////////////////////////////////////////////////////////
+
+ZenStructuredCacheClient::ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options)
+: m_Log(logging::Get(std::string_view("zenclient")))
+, m_ServiceUrl(Options.Url)
+, m_ConnectTimeout(Options.ConnectTimeout)
+, m_Timeout(Options.Timeout)
+{
+}
+
+ZenStructuredCacheClient::~ZenStructuredCacheClient()
+{
+}
+
+detail::ZenCacheSessionState*
+ZenStructuredCacheClient::AllocSessionState()
+{
+ detail::ZenCacheSessionState* State = nullptr;
+
+ if (RwLock::ExclusiveLockScope _(m_SessionStateLock); !m_SessionStateCache.empty())
+ {
+ State = m_SessionStateCache.front();
+ m_SessionStateCache.pop_front();
+ }
+
+ if (State == nullptr)
+ {
+ State = new detail::ZenCacheSessionState(*this);
+ }
+
+ State->Reset(m_ConnectTimeout, m_Timeout);
+
+ return State;
+}
+
+void
+ZenStructuredCacheClient::FreeSessionState(detail::ZenCacheSessionState* State)
+{
+ RwLock::ExclusiveLockScope _(m_SessionStateLock);
+ m_SessionStateCache.push_front(State);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+using namespace std::literals;
+
+ZenStructuredCacheSession::ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient)
+: m_Log(OuterClient->Log())
+, m_Client(std::move(OuterClient))
+{
+ m_SessionState = m_Client->AllocSessionState();
+}
+
+ZenStructuredCacheSession::~ZenStructuredCacheSession()
+{
+ m_Client->FreeSessionState(m_SessionState);
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::CheckHealth()
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/health/check";
+
+ cpr::Session& Session = m_SessionState->GetSession();
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ cpr::Response Response = Session.Get();
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ return {.Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Response.status_code == 200};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Accept", std::string{MapContentTypeToString(Type)}}});
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer, .Bytes = Response.downloaded_bytes, .ElapsedSeconds = Response.elapsed, .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::GetCacheChunk(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ const IoHash& ValueContentId)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Accept", "application/x-ue-comp"}});
+
+ cpr::Response Response = Session.Get();
+ ZEN_DEBUG("GET {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = Buffer,
+ .Bytes = Response.downloaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Reason = Response.reason,
+ .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::PutCacheRecord(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ IoBuffer Value,
+ ZenContentType Type)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type",
+ Type == ZenContentType::kCbPackage ? "application/x-ue-cbpkg"
+ : Type == ZenContentType::kCbObject ? "application/x-ue-cb"
+ : "application/octet-stream"}});
+ Session.SetBody(cpr::Body{static_cast<const char*>(Value.Data()), Value.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200 || Response.status_code == 201;
+ return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::PutCacheValue(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ const IoHash& ValueContentId,
+ IoBuffer Payload)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/";
+ if (Namespace != ZenCacheStore::DefaultNamespace)
+ {
+ Uri << Namespace << "/";
+ }
+ Uri << BucketId << "/" << Key.ToHexString() << "/" << ValueContentId.ToHexString();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-comp"}});
+ Session.SetBody(cpr::Body{static_cast<const char*>(Payload.Data()), Payload.Size()});
+
+ cpr::Response Response = Session.Put();
+ ZEN_DEBUG("PUT {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200 || Response.status_code == 201;
+ return {.Bytes = Response.uploaded_bytes, .ElapsedSeconds = Response.elapsed, .Reason = Response.reason, .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::InvokeRpc(const CbObjectView& Request)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/$rpc";
+
+ BinaryWriter Body;
+ Request.CopyTo(Body);
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cb"}, {"Accept", "application/x-ue-cbpkg"}});
+ Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Body.GetData()), Body.GetSize()});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = std::move(Buffer),
+ .Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Reason = Response.reason,
+ .Success = Success};
+}
+
+ZenCacheResult
+ZenStructuredCacheSession::InvokeRpc(const CbPackage& Request)
+{
+ ExtendableStringBuilder<256> Uri;
+ Uri << m_Client->ServiceUrl() << "/z$/$rpc";
+
+ SharedBuffer Message = FormatPackageMessageBuffer(Request).Flatten();
+
+ cpr::Session& Session = m_SessionState->GetSession();
+
+ Session.SetOption(cpr::Url{Uri.c_str()});
+ Session.SetHeader(cpr::Header{{"Content-Type", "application/x-ue-cbpkg"}, {"Accept", "application/x-ue-cbpkg"}});
+ Session.SetBody(cpr::Body{reinterpret_cast<const char*>(Message.GetData()), Message.GetSize()});
+
+ cpr::Response Response = Session.Post();
+ ZEN_DEBUG("POST {}", Response);
+
+ if (Response.error)
+ {
+ return {.ErrorCode = static_cast<int32_t>(Response.error.code), .Reason = std::move(Response.error.message)};
+ }
+
+ const bool Success = Response.status_code == 200;
+ const IoBuffer Buffer = Success ? IoBufferBuilder::MakeCloneFromMemory(Response.text.data(), Response.text.size()) : IoBuffer();
+
+ return {.Response = std::move(Buffer),
+ .Bytes = Response.uploaded_bytes,
+ .ElapsedSeconds = Response.elapsed,
+ .Reason = Response.reason,
+ .Success = Success};
+}
+
+} // namespace zen
diff --git a/src/zenserver/upstream/zen.h b/src/zenserver/upstream/zen.h
new file mode 100644
index 000000000..bfba8fa98
--- /dev/null
+++ b/src/zenserver/upstream/zen.h
@@ -0,0 +1,125 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+#include <zencore/zencore.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <chrono>
+#include <list>
+
+struct ZenCacheValue;
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+class CbObjectWriter;
+class CbObjectView;
+class CbPackage;
+class ZenStructuredCacheClient;
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+ struct ZenCacheSessionState;
+}
+
+struct ZenCacheResult
+{
+ IoBuffer Response;
+ int64_t Bytes = {};
+ double ElapsedSeconds = {};
+ int32_t ErrorCode = {};
+ std::string Reason;
+ bool Success = false;
+};
+
+struct ZenStructuredCacheClientOptions
+{
+ std::string_view Name;
+ std::string_view Url;
+ std::span<std::string const> Urls;
+ std::chrono::milliseconds ConnectTimeout{};
+ std::chrono::milliseconds Timeout{};
+};
+
+/** Zen Structured Cache session
+ *
+ * This provides a context in which cache queries can be performed
+ *
+ * These are currently all synchronous. Will need to be made asynchronous
+ */
+class ZenStructuredCacheSession
+{
+public:
+ ZenStructuredCacheSession(Ref<ZenStructuredCacheClient>&& OuterClient);
+ ~ZenStructuredCacheSession();
+
+ ZenCacheResult CheckHealth();
+ ZenCacheResult GetCacheRecord(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, ZenContentType Type);
+ ZenCacheResult GetCacheChunk(std::string_view Namespace, std::string_view BucketId, const IoHash& Key, const IoHash& ValueContentId);
+ ZenCacheResult PutCacheRecord(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ IoBuffer Value,
+ ZenContentType Type);
+ ZenCacheResult PutCacheValue(std::string_view Namespace,
+ std::string_view BucketId,
+ const IoHash& Key,
+ const IoHash& ValueContentId,
+ IoBuffer Payload);
+ ZenCacheResult InvokeRpc(const CbObjectView& Request);
+ ZenCacheResult InvokeRpc(const CbPackage& Package);
+
+private:
+ inline spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ Ref<ZenStructuredCacheClient> m_Client;
+ detail::ZenCacheSessionState* m_SessionState;
+};
+
+/** Zen Structured Cache client
+ *
+ * This represents an endpoint to query -- actual queries should be done via
+ * ZenStructuredCacheSession
+ */
+class ZenStructuredCacheClient : public RefCounted
+{
+public:
+ ZenStructuredCacheClient(const ZenStructuredCacheClientOptions& Options);
+ ~ZenStructuredCacheClient();
+
+ std::string_view ServiceUrl() const { return m_ServiceUrl; }
+
+ inline spdlog::logger& Log() { return m_Log; }
+
+private:
+ spdlog::logger& m_Log;
+ std::string m_ServiceUrl;
+ std::chrono::milliseconds m_ConnectTimeout;
+ std::chrono::milliseconds m_Timeout;
+
+ RwLock m_SessionStateLock;
+ std::list<detail::ZenCacheSessionState*> m_SessionStateCache;
+
+ detail::ZenCacheSessionState* AllocSessionState();
+ void FreeSessionState(detail::ZenCacheSessionState*);
+
+ friend class ZenStructuredCacheSession;
+};
+
+} // namespace zen
diff --git a/src/zenserver/windows/service.cpp b/src/zenserver/windows/service.cpp
new file mode 100644
index 000000000..89bacab0b
--- /dev/null
+++ b/src/zenserver/windows/service.cpp
@@ -0,0 +1,646 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "service.h"
+
+#include <zencore/zencore.h>
+
+#if ZEN_PLATFORM_WINDOWS
+
+# include <zencore/except.h>
+# include <zencore/zencore.h>
+
+# include <stdio.h>
+# include <tchar.h>
+# include <zencore/windows.h>
+
+# define SVCNAME L"Zen Store"
+
+SERVICE_STATUS gSvcStatus;
+SERVICE_STATUS_HANDLE gSvcStatusHandle;
+HANDLE ghSvcStopEvent = NULL;
+
+void SvcInstall(void);
+
+void ReportSvcStatus(DWORD, DWORD, DWORD);
+void SvcReportEvent(LPTSTR);
+
+WindowsService::WindowsService()
+{
+}
+
+WindowsService::~WindowsService()
+{
+}
+
+//
+// Purpose:
+// Installs a service in the SCM database
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+VOID
+WindowsService::Install()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+ TCHAR szPath[MAX_PATH];
+
+ if (!GetModuleFileName(NULL, szPath, MAX_PATH))
+ {
+ printf("Cannot install service (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Create the service
+
+ schService = CreateService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SVCNAME, // service name to display
+ SERVICE_ALL_ACCESS, // desired access
+ SERVICE_WIN32_OWN_PROCESS, // service type
+ SERVICE_DEMAND_START, // start type
+ SERVICE_ERROR_NORMAL, // error control type
+ szPath, // path to service's binary
+ NULL, // no load ordering group
+ NULL, // no tag identifier
+ NULL, // no dependencies
+ NULL, // LocalSystem account
+ NULL); // no password
+
+ if (schService == NULL)
+ {
+ printf("CreateService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+ else
+ printf("Service installed successfully\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+void
+WindowsService::Delete()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ DELETE); // need delete access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Delete the service.
+
+ if (!DeleteService(schService))
+ {
+ printf("DeleteService failed (%d)\n", GetLastError());
+ }
+ else
+ printf("Service deleted successfully\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+WindowsService* gSvc;
+
+void WINAPI
+CallMain(DWORD, LPSTR*)
+{
+ gSvc->SvcMain();
+}
+
+int
+WindowsService::ServiceMain()
+{
+ gSvc = this;
+
+ SERVICE_TABLE_ENTRY DispatchTable[] = {{(LPWSTR)SVCNAME, (LPSERVICE_MAIN_FUNCTION)&CallMain}, {NULL, NULL}};
+
+ // This call returns when the service has stopped.
+ // The process should simply terminate when the call returns.
+
+ if (!StartServiceCtrlDispatcher(DispatchTable))
+ {
+ const DWORD dwError = zen::GetLastError();
+
+ if (dwError == ERROR_FAILED_SERVICE_CONTROLLER_CONNECT)
+ {
+ // Not actually running as a service
+ gSvc = nullptr;
+
+ zen::SetIsInteractiveSession(true);
+
+ return Run();
+ }
+ else
+ {
+ zen::ThrowSystemError(dwError, "StartServiceCtrlDispatcher failed");
+ }
+ }
+
+ zen::SetIsInteractiveSession(false);
+
+ return 0;
+}
+
+int
+WindowsService::SvcMain()
+{
+ // Register the handler function for the service
+
+ gSvcStatusHandle = RegisterServiceCtrlHandler(SVCNAME, SvcCtrlHandler);
+
+ if (!gSvcStatusHandle)
+ {
+ SvcReportEvent((LPTSTR)TEXT("RegisterServiceCtrlHandler"));
+
+ return 1;
+ }
+
+ // These SERVICE_STATUS members remain as set here
+
+ gSvcStatus.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
+ gSvcStatus.dwServiceSpecificExitCode = 0;
+
+ // Report initial status to the SCM
+
+ ReportSvcStatus(SERVICE_START_PENDING, NO_ERROR, 3000);
+
+ // Create an event. The control handler function, SvcCtrlHandler,
+ // signals this event when it receives the stop control code.
+
+ ghSvcStopEvent = CreateEvent(NULL, // default security attributes
+ TRUE, // manual reset event
+ FALSE, // not signaled
+ NULL); // no name
+
+ if (ghSvcStopEvent == NULL)
+ {
+ ReportSvcStatus(SERVICE_STOPPED, GetLastError(), 0);
+
+ return 1;
+ }
+
+ // Report running status when initialization is complete.
+
+ ReportSvcStatus(SERVICE_RUNNING, NO_ERROR, 0);
+
+ int ReturnCode = Run();
+
+ ReportSvcStatus(SERVICE_STOPPED, NO_ERROR, 0);
+
+ return ReturnCode;
+}
+
+//
+// Purpose:
+// Retrieves and displays the current service configuration.
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+void
+DoQuerySvc()
+{
+ SC_HANDLE schSCManager{};
+ SC_HANDLE schService{};
+ LPQUERY_SERVICE_CONFIG lpsc{};
+ LPSERVICE_DESCRIPTION lpsd{};
+ DWORD dwBytesNeeded{}, cbBufSize{}, dwError{};
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_QUERY_CONFIG); // need query config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Get the configuration information.
+
+ if (!QueryServiceConfig(schService, NULL, 0, &dwBytesNeeded))
+ {
+ dwError = GetLastError();
+ if (ERROR_INSUFFICIENT_BUFFER == dwError)
+ {
+ cbBufSize = dwBytesNeeded;
+ lpsc = (LPQUERY_SERVICE_CONFIG)LocalAlloc(LMEM_FIXED, cbBufSize);
+ }
+ else
+ {
+ printf("QueryServiceConfig failed (%d)", dwError);
+ goto cleanup;
+ }
+ }
+
+ if (!QueryServiceConfig(schService, lpsc, cbBufSize, &dwBytesNeeded))
+ {
+ printf("QueryServiceConfig failed (%d)", GetLastError());
+ goto cleanup;
+ }
+
+ if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, NULL, 0, &dwBytesNeeded))
+ {
+ dwError = GetLastError();
+ if (ERROR_INSUFFICIENT_BUFFER == dwError)
+ {
+ cbBufSize = dwBytesNeeded;
+ lpsd = (LPSERVICE_DESCRIPTION)LocalAlloc(LMEM_FIXED, cbBufSize);
+ }
+ else
+ {
+ printf("QueryServiceConfig2 failed (%d)", dwError);
+ goto cleanup;
+ }
+ }
+
+ if (!QueryServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, (LPBYTE)lpsd, cbBufSize, &dwBytesNeeded))
+ {
+ printf("QueryServiceConfig2 failed (%d)", GetLastError());
+ goto cleanup;
+ }
+
+ // Print the configuration information.
+
+ _tprintf(TEXT("%s configuration: \n"), SVCNAME);
+ _tprintf(TEXT(" Type: 0x%x\n"), lpsc->dwServiceType);
+ _tprintf(TEXT(" Start Type: 0x%x\n"), lpsc->dwStartType);
+ _tprintf(TEXT(" Error Control: 0x%x\n"), lpsc->dwErrorControl);
+ _tprintf(TEXT(" Binary path: %s\n"), lpsc->lpBinaryPathName);
+ _tprintf(TEXT(" Account: %s\n"), lpsc->lpServiceStartName);
+
+ if (lpsd->lpDescription != NULL && lstrcmp(lpsd->lpDescription, TEXT("")) != 0)
+ _tprintf(TEXT(" Description: %s\n"), lpsd->lpDescription);
+ if (lpsc->lpLoadOrderGroup != NULL && lstrcmp(lpsc->lpLoadOrderGroup, TEXT("")) != 0)
+ _tprintf(TEXT(" Load order group: %s\n"), lpsc->lpLoadOrderGroup);
+ if (lpsc->dwTagId != 0)
+ _tprintf(TEXT(" Tag ID: %d\n"), lpsc->dwTagId);
+ if (lpsc->lpDependencies != NULL && lstrcmp(lpsc->lpDependencies, TEXT("")) != 0)
+ _tprintf(TEXT(" Dependencies: %s\n"), lpsc->lpDependencies);
+
+ LocalFree(lpsc);
+ LocalFree(lpsd);
+
+cleanup:
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+//
+// Purpose:
+// Disables the service.
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+void
+DoDisableSvc()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_CHANGE_CONFIG); // need change config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Change the service start type.
+
+ if (!ChangeServiceConfig(schService, // handle of service
+ SERVICE_NO_CHANGE, // service type: no change
+ SERVICE_DISABLED, // service start type
+ SERVICE_NO_CHANGE, // error control: no change
+ NULL, // binary path: no change
+ NULL, // load order group: no change
+ NULL, // tag ID: no change
+ NULL, // dependencies: no change
+ NULL, // account name: no change
+ NULL, // password: no change
+ NULL)) // display name: no change
+ {
+ printf("ChangeServiceConfig failed (%d)\n", GetLastError());
+ }
+ else
+ printf("Service disabled successfully.\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+//
+// Purpose:
+// Enables the service.
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+VOID __stdcall DoEnableSvc()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_CHANGE_CONFIG); // need change config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Change the service start type.
+
+ if (!ChangeServiceConfig(schService, // handle of service
+ SERVICE_NO_CHANGE, // service type: no change
+ SERVICE_DEMAND_START, // service start type
+ SERVICE_NO_CHANGE, // error control: no change
+ NULL, // binary path: no change
+ NULL, // load order group: no change
+ NULL, // tag ID: no change
+ NULL, // dependencies: no change
+ NULL, // account name: no change
+ NULL, // password: no change
+ NULL)) // display name: no change
+ {
+ printf("ChangeServiceConfig failed (%d)\n", GetLastError());
+ }
+ else
+ printf("Service enabled successfully.\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+//
+// Purpose:
+// Updates the service description to "This is a test description".
+//
+// Parameters:
+// None
+//
+// Return value:
+// None
+//
+void
+DoUpdateSvcDesc()
+{
+ SC_HANDLE schSCManager;
+ SC_HANDLE schService;
+ SERVICE_DESCRIPTION sd;
+ TCHAR szDesc[] = TEXT("This is a test description");
+
+ // Get a handle to the SCM database.
+
+ schSCManager = OpenSCManager(NULL, // local computer
+ NULL, // ServicesActive database
+ SC_MANAGER_ALL_ACCESS); // full access rights
+
+ if (NULL == schSCManager)
+ {
+ printf("OpenSCManager failed (%d)\n", GetLastError());
+ return;
+ }
+
+ // Get a handle to the service.
+
+ schService = OpenService(schSCManager, // SCM database
+ SVCNAME, // name of service
+ SERVICE_CHANGE_CONFIG); // need change config access
+
+ if (schService == NULL)
+ {
+ printf("OpenService failed (%d)\n", GetLastError());
+ CloseServiceHandle(schSCManager);
+ return;
+ }
+
+ // Change the service description.
+
+ sd.lpDescription = szDesc;
+
+ if (!ChangeServiceConfig2(schService, // handle to service
+ SERVICE_CONFIG_DESCRIPTION, // change: description
+ &sd)) // new description
+ {
+ printf("ChangeServiceConfig2 failed\n");
+ }
+ else
+ printf("Service description updated successfully.\n");
+
+ CloseServiceHandle(schService);
+ CloseServiceHandle(schSCManager);
+}
+
+//
+// Purpose:
+// Sets the current service status and reports it to the SCM.
+//
+// Parameters:
+// dwCurrentState - The current state (see SERVICE_STATUS)
+// dwWin32ExitCode - The system error code
+// dwWaitHint - Estimated time for pending operation,
+// in milliseconds
+//
+// Return value:
+// None
+//
+VOID
+ReportSvcStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode, DWORD dwWaitHint)
+{
+ static DWORD dwCheckPoint = 1;
+
+ // Fill in the SERVICE_STATUS structure.
+
+ gSvcStatus.dwCurrentState = dwCurrentState;
+ gSvcStatus.dwWin32ExitCode = dwWin32ExitCode;
+ gSvcStatus.dwWaitHint = dwWaitHint;
+
+ if (dwCurrentState == SERVICE_START_PENDING)
+ gSvcStatus.dwControlsAccepted = 0;
+ else
+ gSvcStatus.dwControlsAccepted = SERVICE_ACCEPT_STOP;
+
+ if ((dwCurrentState == SERVICE_RUNNING) || (dwCurrentState == SERVICE_STOPPED))
+ gSvcStatus.dwCheckPoint = 0;
+ else
+ gSvcStatus.dwCheckPoint = dwCheckPoint++;
+
+ // Report the status of the service to the SCM.
+ SetServiceStatus(gSvcStatusHandle, &gSvcStatus);
+}
+
+void
+WindowsService::SvcCtrlHandler(DWORD dwCtrl)
+{
+ // Handle the requested control code.
+ //
+ // Called by SCM whenever a control code is sent to the service
+ // using the ControlService function.
+
+ switch (dwCtrl)
+ {
+ case SERVICE_CONTROL_STOP:
+ ReportSvcStatus(SERVICE_STOP_PENDING, NO_ERROR, 0);
+
+ // Signal the service to stop.
+
+ SetEvent(ghSvcStopEvent);
+ zen::RequestApplicationExit(0);
+
+ ReportSvcStatus(gSvcStatus.dwCurrentState, NO_ERROR, 0);
+ return;
+
+ case SERVICE_CONTROL_INTERROGATE:
+ break;
+
+ default:
+ break;
+ }
+}
+
+//
+// Purpose:
+// Logs messages to the event log
+//
+// Parameters:
+// szFunction - name of function that failed
+//
+// Return value:
+// None
+//
+// Remarks:
+// The service must have an entry in the Application event log.
+//
+VOID
+SvcReportEvent(LPTSTR szFunction)
+{
+ ZEN_UNUSED(szFunction);
+
+ // HANDLE hEventSource;
+ // LPCTSTR lpszStrings[2];
+ // TCHAR Buffer[80];
+
+ // hEventSource = RegisterEventSource(NULL, SVCNAME);
+
+ // if (NULL != hEventSource)
+ //{
+ // StringCchPrintf(Buffer, 80, TEXT("%s failed with %d"), szFunction, GetLastError());
+
+ // lpszStrings[0] = SVCNAME;
+ // lpszStrings[1] = Buffer;
+
+ // ReportEvent(hEventSource, // event log handle
+ // EVENTLOG_ERROR_TYPE, // event type
+ // 0, // event category
+ // SVC_ERROR, // event identifier
+ // NULL, // no security identifier
+ // 2, // size of lpszStrings array
+ // 0, // no binary data
+ // lpszStrings, // array of strings
+ // NULL); // no binary data
+
+ // DeregisterEventSource(hEventSource);
+ //}
+}
+
+#endif // ZEN_PLATFORM_WINDOWS
diff --git a/src/zenserver/windows/service.h b/src/zenserver/windows/service.h
new file mode 100644
index 000000000..7c9610983
--- /dev/null
+++ b/src/zenserver/windows/service.h
@@ -0,0 +1,20 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+class WindowsService
+{
+public:
+ WindowsService();
+ ~WindowsService();
+
+ virtual int Run() = 0;
+
+ int ServiceMain();
+
+ static void Install();
+ static void Delete();
+
+ int SvcMain();
+ static void __stdcall SvcCtrlHandler(unsigned long);
+};
diff --git a/src/zenserver/xmake.lua b/src/zenserver/xmake.lua
new file mode 100644
index 000000000..23bfb9535
--- /dev/null
+++ b/src/zenserver/xmake.lua
@@ -0,0 +1,60 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("zenserver")
+ set_kind("binary")
+ add_deps("zencore", "zenhttp", "zenstore", "zenutil")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_files("zenserver.cpp", {unity_ignored = true })
+ add_includedirs(".")
+ set_symbols("debug")
+
+ if is_mode("release") then
+ set_optimize("fastest")
+ end
+
+ if is_plat("windows") then
+ add_ldflags("/subsystem:console,5.02")
+ add_ldflags("/MANIFEST:EMBED")
+ add_ldflags("/LTCG")
+ add_files("zenserver.rc")
+ add_cxxflags("/bigobj")
+ else
+ remove_files("windows/**")
+ end
+
+ if is_plat("macosx") then
+ add_ldflags("-framework CoreFoundation")
+ add_ldflags("-framework CoreGraphics")
+ add_ldflags("-framework CoreText")
+ add_ldflags("-framework Foundation")
+ add_ldflags("-framework Security")
+ add_ldflags("-framework SystemConfiguration")
+ add_syslinks("bsm")
+ end
+
+ add_options("compute")
+ add_options("exec")
+
+ add_packages(
+ "vcpkg::asio",
+ "vcpkg::cxxopts",
+ "vcpkg::http-parser",
+ "vcpkg::json11",
+ "vcpkg::lua",
+ "vcpkg::mimalloc",
+ "vcpkg::rocksdb",
+ "vcpkg::sentry-native",
+ "vcpkg::sol2"
+ )
+
+ -- Only applicable to later versions of sentry-native
+ --[[
+ if is_plat("linux") then
+ -- As sentry_native uses symbols from breakpad_client, the latter must
+ -- be specified after the former with GCC-like toolchains. xmake however
+ -- is unaware of this and simply globs files from vcpkg's output. The
+ -- line below forces breakpad_client to be to the right of sentry_native
+ add_syslinks("breakpad_client")
+ end
+ ]]--
diff --git a/src/zenserver/zenserver.cpp b/src/zenserver/zenserver.cpp
new file mode 100644
index 000000000..635fd04e0
--- /dev/null
+++ b/src/zenserver/zenserver.cpp
@@ -0,0 +1,1261 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/config.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/iobuffer.h>
+#include <zencore/logging.h>
+#include <zencore/refcount.h>
+#include <zencore/scopeguard.h>
+#include <zencore/session.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+#include <zencore/trace.h>
+#include <zenhttp/httpserver.h>
+#include <zenhttp/websocket.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/scrubcontext.h>
+#include <zenutil/basicfile.h>
+#include <zenutil/zenserverprocess.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#endif
+
+#if ZEN_USE_MIMALLOC
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <mimalloc-new-delete.h>
+# include <mimalloc.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+#endif
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <asio.hpp>
+#include <lua.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <exception>
+#include <list>
+#include <optional>
+#include <regex>
+#include <set>
+#include <unordered_map>
+
+//////////////////////////////////////////////////////////////////////////
+// We don't have any doctest code in this file but this is needed to bring
+// in some shared code into the executable
+
+#if ZEN_WITH_TESTS
+# define ZEN_TEST_WITH_RUNNER 1
+# include <zencore/testing.h>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+#include "config.h"
+#include "diag/logging.h"
+
+#if ZEN_PLATFORM_WINDOWS
+# include "windows/service.h"
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Sentry
+//
+
+#if !defined(ZEN_USE_SENTRY)
+# if ZEN_PLATFORM_MAC && ZEN_ARCH_ARM64
+// vcpkg's sentry-native port does not support Arm on Mac.
+# define ZEN_USE_SENTRY 0
+# else
+# define ZEN_USE_SENTRY 1
+# endif
+#endif
+
+#if ZEN_USE_SENTRY
+# define SENTRY_BUILD_STATIC 1
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <sentry.h>
+# include <spdlog/sinks/base_sink.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+// Sentry currently does not automatically add all required Windows
+// libraries to the linker when consumed via vcpkg
+
+# if ZEN_PLATFORM_WINDOWS
+# pragma comment(lib, "sentry.lib")
+# pragma comment(lib, "dbghelp.lib")
+# pragma comment(lib, "winhttp.lib")
+# pragma comment(lib, "version.lib")
+# endif
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+// Services
+//
+
+#include "admin/admin.h"
+#include "auth/authmgr.h"
+#include "auth/authservice.h"
+#include "cache/structuredcache.h"
+#include "cache/structuredcachestore.h"
+#include "cidstore.h"
+#include "compute/function.h"
+#include "diag/diagsvcs.h"
+#include "frontend/frontend.h"
+#include "monitoring/httpstats.h"
+#include "monitoring/httpstatus.h"
+#include "objectstore/objectstore.h"
+#include "projectstore/projectstore.h"
+#include "testing/httptest.h"
+#include "upstream/upstream.h"
+#include "zenstore/gc.h"
+
+#define ZEN_APP_NAME "Zen store"
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace utils {
+#if ZEN_USE_SENTRY
+ class sentry_sink final : public spdlog::sinks::base_sink<spdlog::details::null_mutex>
+ {
+ public:
+ sentry_sink() {}
+
+ protected:
+ static constexpr sentry_level_t MapToSentryLevel[spdlog::level::level_enum::n_levels] = {SENTRY_LEVEL_DEBUG,
+ SENTRY_LEVEL_DEBUG,
+ SENTRY_LEVEL_INFO,
+ SENTRY_LEVEL_WARNING,
+ SENTRY_LEVEL_ERROR,
+ SENTRY_LEVEL_FATAL,
+ SENTRY_LEVEL_DEBUG};
+
+ void sink_it_(const spdlog::details::log_msg& msg) override
+ {
+ std::string Message = fmt::format("{}\n{}({}) [{}]", msg.payload, msg.source.filename, msg.source.line, msg.source.funcname);
+ sentry_value_t event = sentry_value_new_message_event(
+ /* level */ MapToSentryLevel[msg.level],
+ /* logger */ nullptr,
+ /* message */ Message.c_str());
+ sentry_event_value_add_stacktrace(event, NULL, 0);
+ sentry_capture_event(event);
+ }
+ void flush_() override {}
+ };
+#endif
+
+ asio::error_code ResolveHostname(asio::io_context& Ctx,
+ std::string_view Host,
+ std::string_view DefaultPort,
+ std::vector<std::string>& OutEndpoints)
+ {
+ std::string_view Port = DefaultPort;
+
+ if (const size_t Idx = Host.find(":"); Idx != std::string_view::npos)
+ {
+ Port = Host.substr(Idx + 1);
+ Host = Host.substr(0, Idx);
+ }
+
+ asio::ip::tcp::resolver Resolver(Ctx);
+
+ asio::error_code ErrorCode;
+ asio::ip::tcp::resolver::results_type Endpoints = Resolver.resolve(Host, Port, ErrorCode);
+
+ if (!ErrorCode)
+ {
+ for (const asio::ip::tcp::endpoint Ep : Endpoints)
+ {
+ OutEndpoints.push_back(fmt::format("http://{}:{}", Ep.address().to_string(), Ep.port()));
+ }
+ }
+
+ return ErrorCode;
+ }
+} // namespace utils
+
+class ZenServer : public IHttpStatusProvider
+{
+public:
+ int Initialize(const ZenServerOptions& ServerOptions, ZenServerState::ZenServerEntry* ServerEntry)
+ {
+ m_UseSentry = ServerOptions.NoSentry == false;
+ m_ServerEntry = ServerEntry;
+ m_DebugOptionForcedCrash = ServerOptions.ShouldCrash;
+ const int ParentPid = ServerOptions.OwnerPid;
+
+ if (ParentPid)
+ {
+ zen::ProcessHandle OwnerProcess;
+ OwnerProcess.Initialize(ParentPid);
+
+ if (!OwnerProcess.IsValid())
+ {
+ ZEN_WARN("Unable to initialize process handle for specified parent pid #{}", ParentPid);
+
+ // If the pid is not reachable should we just shut down immediately? the intended owner process
+ // could have been killed or somehow crashed already
+ }
+ else
+ {
+ ZEN_INFO("Using parent pid #{} to control process lifetime", ParentPid);
+ }
+
+ m_ProcessMonitor.AddPid(ParentPid);
+ }
+
+ // Initialize/check mutex based on base port
+
+ std::string MutexName = fmt::format("zen_{}", ServerOptions.BasePort);
+
+ if (zen::NamedMutex::Exists(MutexName) || ((m_ServerMutex.Create(MutexName) == false)))
+ {
+ throw std::runtime_error(fmt::format("Failed to create mutex '{}' - is another instance already running?", MutexName).c_str());
+ }
+
+ InitializeState(ServerOptions);
+
+ m_HealthService.SetHealthInfo({.DataRoot = m_DataRoot,
+ .AbsLogPath = ServerOptions.AbsLogFile,
+ .HttpServerClass = std::string(ServerOptions.HttpServerClass),
+ .BuildVersion = std::string(ZEN_CFG_VERSION_BUILD_STRING_FULL)});
+
+ // Ok so now we're configured, let's kick things off
+
+ m_Http = zen::CreateHttpServer(ServerOptions.HttpServerClass);
+ int EffectiveBasePort = m_Http->Initialize(ServerOptions.BasePort);
+
+ if (ServerOptions.WebSocketPort != 0)
+ {
+ const uint32 ThreadCount =
+ ServerOptions.WebSocketThreads > 0 ? uint32_t(ServerOptions.WebSocketThreads) : std::thread::hardware_concurrency();
+
+ m_WebSocket = zen::WebSocketServer::Create(
+ {.Port = gsl::narrow<uint16_t>(ServerOptions.WebSocketPort), .ThreadCount = Max(ThreadCount, uint32_t(16))});
+ }
+
+ // Setup authentication manager
+ {
+ std::string EncryptionKey = ServerOptions.EncryptionKey;
+
+ if (EncryptionKey.empty())
+ {
+ EncryptionKey = "abcdefghijklmnopqrstuvxyz0123456";
+
+ ZEN_WARN("using default encryption key");
+ }
+
+ std::string EncryptionIV = ServerOptions.EncryptionIV;
+
+ if (EncryptionIV.empty())
+ {
+ EncryptionIV = "0123456789abcdef";
+
+ ZEN_WARN("using default encryption initialization vector");
+ }
+
+ m_AuthMgr = AuthMgr::Create({.RootDirectory = m_DataRoot / "auth",
+ .EncryptionKey = AesKey256Bit::FromString(EncryptionKey),
+ .EncryptionIV = AesIV128Bit::FromString(EncryptionIV)});
+
+ for (const ZenOpenIdProviderConfig& OpenIdProvider : ServerOptions.AuthConfig.OpenIdProviders)
+ {
+ m_AuthMgr->AddOpenIdProvider({.Name = OpenIdProvider.Name, .Url = OpenIdProvider.Url, .ClientId = OpenIdProvider.ClientId});
+ }
+ }
+
+ m_AuthService = std::make_unique<zen::HttpAuthService>(*m_AuthMgr);
+ m_Http->RegisterService(*m_AuthService);
+
+ m_Http->RegisterService(m_HealthService);
+ m_Http->RegisterService(m_StatsService);
+ m_Http->RegisterService(m_StatusService);
+ m_StatusService.RegisterHandler("status", *this);
+
+ // Initialize storage and services
+
+ ZEN_INFO("initializing storage");
+
+ zen::CidStoreConfiguration Config;
+ Config.RootDirectory = m_DataRoot / "cas";
+
+ m_CidStore = std::make_unique<zen::CidStore>(m_GcManager);
+ m_CidStore->Initialize(Config);
+ m_CidService.reset(new zen::HttpCidService{*m_CidStore});
+
+ ZEN_INFO("instantiating project service");
+
+ m_ProjectStore = new zen::ProjectStore(*m_CidStore, m_DataRoot / "projects", m_GcManager);
+ m_HttpProjectService.reset(new zen::HttpProjectService{*m_CidStore, m_ProjectStore, m_StatsService, *m_AuthMgr});
+
+#if ZEN_WITH_COMPUTE_SERVICES
+ if (ServerOptions.ComputeServiceEnabled)
+ {
+ InitializeCompute(ServerOptions);
+ }
+ else
+ {
+ ZEN_INFO("NOT instantiating compute services");
+ }
+#endif // ZEN_WITH_COMPUTE_SERVICES
+
+ if (ServerOptions.StructuredCacheEnabled)
+ {
+ InitializeStructuredCache(ServerOptions);
+ }
+ else
+ {
+ ZEN_INFO("NOT instantiating structured cache service");
+ }
+
+ m_Http->RegisterService(m_TestService); // NOTE: this is intentionally not limited to test mode as it's useful for diagnostics
+ m_Http->RegisterService(m_TestingService);
+ m_Http->RegisterService(m_AdminService);
+
+ if (m_WebSocket)
+ {
+ m_WebSocket->RegisterService(m_TestingService);
+ }
+
+ if (m_HttpProjectService)
+ {
+ m_Http->RegisterService(*m_HttpProjectService);
+ }
+
+ m_Http->RegisterService(*m_CidService);
+
+#if ZEN_WITH_COMPUTE_SERVICES
+ if (ServerOptions.ComputeServiceEnabled)
+ {
+ if (m_HttpFunctionService != nullptr)
+ {
+ m_Http->RegisterService(*m_HttpFunctionService);
+ }
+ }
+#endif // ZEN_WITH_COMPUTE_SERVICES
+
+ m_FrontendService = std::make_unique<HttpFrontendService>(m_ContentRoot);
+
+ if (m_FrontendService)
+ {
+ m_Http->RegisterService(*m_FrontendService);
+ }
+
+ if (ServerOptions.ObjectStoreEnabled)
+ {
+ ObjectStoreConfig ObjCfg;
+ ObjCfg.RootDirectory = m_DataRoot / "obj";
+ ObjCfg.ServerPort = static_cast<uint16_t>(EffectiveBasePort);
+
+ for (const auto& Bucket : ServerOptions.ObjectStoreConfig.Buckets)
+ {
+ ObjectStoreConfig::BucketConfig NewBucket{.Name = Bucket.Name};
+ NewBucket.Directory = Bucket.Directory.empty() ? (ObjCfg.RootDirectory / Bucket.Name) : Bucket.Directory;
+ ObjCfg.Buckets.push_back(std::move(NewBucket));
+ }
+
+ m_ObjStoreService = std::make_unique<HttpObjectStoreService>(std::move(ObjCfg));
+ m_Http->RegisterService(*m_ObjStoreService);
+ }
+
+ ZEN_INFO("initializing GC, enabled '{}', interval {}s", ServerOptions.GcConfig.Enabled, ServerOptions.GcConfig.IntervalSeconds);
+ zen::GcSchedulerConfig GcConfig{.RootDirectory = m_DataRoot / "gc",
+ .MonitorInterval = std::chrono::seconds(ServerOptions.GcConfig.MonitorIntervalSeconds),
+ .Interval = std::chrono::seconds(ServerOptions.GcConfig.IntervalSeconds),
+ .MaxCacheDuration = std::chrono::seconds(ServerOptions.GcConfig.Cache.MaxDurationSeconds),
+ .CollectSmallObjects = ServerOptions.GcConfig.CollectSmallObjects,
+ .Enabled = ServerOptions.GcConfig.Enabled,
+ .DiskReserveSize = ServerOptions.GcConfig.DiskReserveSize,
+ .DiskSizeSoftLimit = ServerOptions.GcConfig.Cache.DiskSizeSoftLimit};
+ m_GcScheduler.Initialize(GcConfig);
+
+ return EffectiveBasePort;
+ }
+
+ void InitializeState(const ZenServerOptions& ServerOptions);
+ void InitializeStructuredCache(const ZenServerOptions& ServerOptions);
+ void InitializeCompute(const ZenServerOptions& ServerOptions);
+
+ void Run()
+ {
+ // This is disabled for now, awaiting better scheduling
+ //
+ // Scrub();
+
+ if (m_ProcessMonitor.IsActive())
+ {
+ EnqueueTimer();
+ }
+
+ if (!m_TestMode)
+ {
+ ZEN_INFO("__________ _________ __ ");
+ ZEN_INFO("\\____ /____ ____ / _____// |_ ___________ ____ ");
+ ZEN_INFO(" / // __ \\ / \\ \\_____ \\\\ __\\/ _ \\_ __ \\_/ __ \\ ");
+ ZEN_INFO(" / /\\ ___/| | \\ / \\| | ( <_> ) | \\/\\ ___/ ");
+ ZEN_INFO("/_______ \\___ >___| / /_______ /|__| \\____/|__| \\___ >");
+ ZEN_INFO(" \\/ \\/ \\/ \\/ \\/ ");
+ }
+
+ ZEN_INFO(ZEN_APP_NAME " now running (pid: {})", zen::GetCurrentProcessId());
+
+#if ZEN_USE_SENTRY
+ ZEN_INFO("sentry crash handler {}", m_UseSentry ? "ENABLED" : "DISABLED");
+ if (m_UseSentry)
+ {
+ sentry_clear_modulecache();
+ }
+#endif
+
+ if (m_DebugOptionForcedCrash)
+ {
+ ZEN_DEBUG_BREAK();
+ }
+
+ const bool IsInteractiveMode = zen::IsInteractiveSession() && !m_TestMode;
+
+ SetNewState(kRunning);
+
+ OnReady();
+
+ if (m_WebSocket)
+ {
+ m_WebSocket->Run();
+ }
+
+ m_Http->Run(IsInteractiveMode);
+
+ SetNewState(kShuttingDown);
+
+ ZEN_INFO(ZEN_APP_NAME " exiting");
+
+ m_IoContext.stop();
+ if (m_IoRunner.joinable())
+ {
+ m_IoRunner.join();
+ }
+
+ Flush();
+ }
+
+ void RequestExit(int ExitCode)
+ {
+ RequestApplicationExit(ExitCode);
+ m_Http->RequestExit();
+ }
+
+ void Cleanup()
+ {
+ ZEN_INFO(ZEN_APP_NAME " cleaning up");
+ m_GcScheduler.Shutdown();
+ }
+
+ void SetDedicatedMode(bool State) { m_IsDedicatedMode = State; }
+ void SetTestMode(bool State) { m_TestMode = State; }
+ void SetDataRoot(std::filesystem::path Root) { m_DataRoot = Root; }
+ void SetContentRoot(std::filesystem::path Root) { m_ContentRoot = Root; }
+
+ std::function<void()> m_IsReadyFunc;
+ void SetIsReadyFunc(std::function<void()>&& IsReadyFunc) { m_IsReadyFunc = std::move(IsReadyFunc); }
+ void OnReady();
+
+ void EnsureIoRunner()
+ {
+ if (!m_IoRunner.joinable())
+ {
+ m_IoRunner = std::thread{[this] { m_IoContext.run(); }};
+ }
+ }
+
+ void EnqueueTimer()
+ {
+ m_PidCheckTimer.expires_after(std::chrono::seconds(1));
+ m_PidCheckTimer.async_wait([this](const asio::error_code&) { CheckOwnerPid(); });
+
+ EnsureIoRunner();
+ }
+
+ void CheckOwnerPid()
+ {
+ // Pick up any new "owner" processes
+
+ std::set<uint32_t> AddedPids;
+
+ for (auto& PidEntry : m_ServerEntry->SponsorPids)
+ {
+ if (uint32_t ThisPid = PidEntry.load(std::memory_order_relaxed))
+ {
+ if (PidEntry.compare_exchange_strong(ThisPid, 0))
+ {
+ if (AddedPids.insert(ThisPid).second)
+ {
+ m_ProcessMonitor.AddPid(ThisPid);
+
+ ZEN_INFO("added process with pid #{} as a sponsor process", ThisPid);
+ }
+ }
+ }
+ }
+
+ if (m_ProcessMonitor.IsRunning())
+ {
+ EnqueueTimer();
+ }
+ else
+ {
+ ZEN_INFO(ZEN_APP_NAME " exiting since sponsor processes are all gone");
+
+ RequestExit(0);
+ }
+ }
+
+ void Scrub()
+ {
+ Stopwatch Timer;
+ ZEN_INFO("Storage validation STARTING");
+
+ ScrubContext Ctx;
+ m_CidStore->Scrub(Ctx);
+ m_ProjectStore->Scrub(Ctx);
+ m_StructuredCacheService->Scrub(Ctx);
+
+ const uint64_t ElapsedTimeMs = Timer.GetElapsedTimeMs();
+
+ ZEN_INFO("Storage validation DONE in {}, ({} in {} chunks - {})",
+ NiceTimeSpanMs(ElapsedTimeMs),
+ NiceBytes(Ctx.ScrubbedBytes()),
+ Ctx.ScrubbedChunks(),
+ NiceByteRate(Ctx.ScrubbedBytes(), ElapsedTimeMs));
+ }
+
+ void Flush()
+ {
+ if (m_CidStore)
+ m_CidStore->Flush();
+
+ if (m_StructuredCacheService)
+ m_StructuredCacheService->Flush();
+
+ if (m_ProjectStore)
+ m_ProjectStore->Flush();
+ }
+
+ virtual void HandleStatusRequest(HttpServerRequest& Request) override
+ {
+ CbObjectWriter Cbo;
+ Cbo << "ok" << true;
+ Cbo << "state" << ToString(m_CurrentState);
+ Request.WriteResponse(HttpResponseCode::OK, Cbo.Save());
+ }
+
+private:
+ ZenServerState::ZenServerEntry* m_ServerEntry = nullptr;
+ bool m_IsDedicatedMode = false;
+ bool m_TestMode = false;
+ CbObject m_RootManifest;
+ std::filesystem::path m_DataRoot;
+ std::filesystem::path m_ContentRoot;
+ std::thread m_IoRunner;
+ asio::io_context m_IoContext;
+ asio::steady_timer m_PidCheckTimer{m_IoContext};
+ zen::ProcessMonitor m_ProcessMonitor;
+ zen::NamedMutex m_ServerMutex;
+
+ enum ServerState
+ {
+ kInitializing,
+ kRunning,
+ kShuttingDown
+ } m_CurrentState = kInitializing;
+
+ inline void SetNewState(ServerState NewState) { m_CurrentState = NewState; }
+
+ std::string_view ToString(ServerState Value)
+ {
+ switch (Value)
+ {
+ case kInitializing:
+ return "initializing"sv;
+ case kRunning:
+ return "running"sv;
+ case kShuttingDown:
+ return "shutdown"sv;
+ default:
+ return "unknown"sv;
+ }
+ }
+
+ zen::Ref<zen::HttpServer> m_Http;
+ std::unique_ptr<zen::WebSocketServer> m_WebSocket;
+ std::unique_ptr<zen::AuthMgr> m_AuthMgr;
+ std::unique_ptr<zen::HttpAuthService> m_AuthService;
+ zen::HttpStatusService m_StatusService;
+ zen::HttpStatsService m_StatsService;
+ zen::GcManager m_GcManager;
+ zen::GcScheduler m_GcScheduler{m_GcManager};
+ std::unique_ptr<zen::CidStore> m_CidStore;
+ std::unique_ptr<zen::ZenCacheStore> m_CacheStore;
+ zen::HttpTestService m_TestService;
+ zen::HttpTestingService m_TestingService;
+ std::unique_ptr<zen::HttpCidService> m_CidService;
+ zen::RefPtr<zen::ProjectStore> m_ProjectStore;
+ std::unique_ptr<zen::HttpProjectService> m_HttpProjectService;
+ std::unique_ptr<zen::UpstreamCache> m_UpstreamCache;
+ std::unique_ptr<zen::HttpUpstreamService> m_UpstreamService;
+ std::unique_ptr<zen::HttpStructuredCacheService> m_StructuredCacheService;
+ zen::HttpAdminService m_AdminService{m_GcScheduler};
+ zen::HttpHealthService m_HealthService;
+#if ZEN_WITH_COMPUTE_SERVICES
+ std::unique_ptr<zen::HttpFunctionService> m_HttpFunctionService;
+#endif // ZEN_WITH_COMPUTE_SERVICES
+ std::unique_ptr<zen::HttpFrontendService> m_FrontendService;
+ std::unique_ptr<zen::HttpObjectStoreService> m_ObjStoreService;
+
+ bool m_DebugOptionForcedCrash = false;
+ bool m_UseSentry = false;
+};
+
+void
+ZenServer::OnReady()
+{
+ m_ServerEntry->SignalReady();
+
+ if (m_IsReadyFunc)
+ {
+ m_IsReadyFunc();
+ }
+}
+
+void
+ZenServer::InitializeState(const ZenServerOptions& ServerOptions)
+{
+ // Check root manifest to deal with schema versioning
+
+ bool WipeState = false;
+ std::string WipeReason = "Unspecified";
+
+ bool UpdateManifest = false;
+ std::filesystem::path ManifestPath = m_DataRoot / "root_manifest";
+ FileContents ManifestData = zen::ReadFile(ManifestPath);
+
+ if (ManifestData.ErrorCode)
+ {
+ if (ServerOptions.IsFirstRun)
+ {
+ ZEN_INFO("Initializing state at '{}'", m_DataRoot);
+
+ UpdateManifest = true;
+ }
+ else
+ {
+ WipeState = true;
+ WipeReason = fmt::format("No manifest present at '{}'", ManifestPath);
+ }
+ }
+ else
+ {
+ IoBuffer Manifest = ManifestData.Flatten();
+
+ if (CbValidateError ValidationResult = ValidateCompactBinary(Manifest, CbValidateMode::All);
+ ValidationResult != CbValidateError::None)
+ {
+ ZEN_WARN("Manifest validation failed: {}, state will be wiped", uint32_t(ValidationResult));
+
+ WipeState = true;
+ WipeReason = fmt::format("Validation of manifest at '{}' failed: {}", ManifestPath, uint32_t(ValidationResult));
+ }
+ else
+ {
+ m_RootManifest = LoadCompactBinaryObject(Manifest);
+
+ const int32_t ManifestVersion = m_RootManifest["schema_version"].AsInt32(0);
+
+ if (ManifestVersion != ZEN_CFG_SCHEMA_VERSION)
+ {
+ WipeState = true;
+ WipeReason = fmt::format("Manifest schema version: {}, differs from required: {}", ManifestVersion, ZEN_CFG_SCHEMA_VERSION);
+ }
+ }
+ }
+
+ // Release any open handles so we can overwrite the manifest
+ ManifestData = {};
+
+ // Handle any state wipe
+
+ if (WipeState)
+ {
+ ZEN_WARN("Wiping state at '{}' - reason: '{}'", m_DataRoot, WipeReason);
+
+ std::error_code Ec;
+ for (const std::filesystem::directory_entry& DirEntry : std::filesystem::directory_iterator{m_DataRoot, Ec})
+ {
+ if (DirEntry.is_directory() && (DirEntry.path().filename() != "logs"))
+ {
+ ZEN_INFO("Deleting '{}'", DirEntry.path());
+
+ std::filesystem::remove_all(DirEntry.path(), Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("Delete of '{}' returned error: '{}'", DirEntry.path(), Ec.message());
+ }
+ }
+ }
+
+ ZEN_INFO("Wiped all directories in data root");
+
+ UpdateManifest = true;
+ }
+
+ if (UpdateManifest)
+ {
+ // Write new manifest
+
+ const DateTime Now = DateTime::Now();
+
+ CbObjectWriter Cbo;
+ Cbo << "schema_version" << ZEN_CFG_SCHEMA_VERSION << "created" << Now << "updated" << Now << "state_id" << Oid::NewOid();
+
+ m_RootManifest = Cbo.Save();
+
+ WriteFile(ManifestPath, m_RootManifest.GetBuffer().AsIoBuffer());
+ }
+}
+
+void
+ZenServer::InitializeStructuredCache(const ZenServerOptions& ServerOptions)
+{
+ using namespace std::literals;
+
+ ZEN_INFO("instantiating structured cache service");
+ m_CacheStore = std::make_unique<ZenCacheStore>(
+ m_GcManager,
+ ZenCacheStore::Configuration{.BasePath = m_DataRoot / "cache", .AllowAutomaticCreationOfNamespaces = true});
+
+ const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig;
+
+ zen::UpstreamCacheOptions UpstreamOptions;
+ UpstreamOptions.ReadUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Read)) != 0;
+ UpstreamOptions.WriteUpstream = (uint8_t(ServerOptions.UpstreamCacheConfig.CachePolicy) & uint8_t(UpstreamCachePolicy::Write)) != 0;
+
+ if (UpstreamConfig.UpstreamThreadCount < 32)
+ {
+ UpstreamOptions.ThreadCount = static_cast<uint32_t>(UpstreamConfig.UpstreamThreadCount);
+ }
+
+ m_UpstreamCache = zen::UpstreamCache::Create(UpstreamOptions, *m_CacheStore, *m_CidStore);
+ m_UpstreamService = std::make_unique<HttpUpstreamService>(*m_UpstreamCache, *m_AuthMgr);
+ m_UpstreamCache->Initialize();
+
+ if (ServerOptions.UpstreamCacheConfig.CachePolicy != UpstreamCachePolicy::Disabled)
+ {
+ // Zen upstream
+ {
+ std::vector<std::string> ZenUrls = UpstreamConfig.ZenConfig.Urls;
+ if (!UpstreamConfig.ZenConfig.Dns.empty())
+ {
+ for (const std::string& Dns : UpstreamConfig.ZenConfig.Dns)
+ {
+ if (!Dns.empty())
+ {
+ const asio::error_code Err = zen::utils::ResolveHostname(m_IoContext, Dns, "1337"sv, ZenUrls);
+ if (Err)
+ {
+ ZEN_ERROR("resolve FAILED, reason '{}'", Err.message());
+ }
+ }
+ }
+ }
+
+ std::erase_if(ZenUrls, [](const auto& Url) { return Url.empty(); });
+
+ if (!ZenUrls.empty())
+ {
+ const auto ZenEndpointName = UpstreamConfig.ZenConfig.Name.empty() ? "Zen"sv : UpstreamConfig.ZenConfig.Name;
+
+ std::unique_ptr<zen::UpstreamEndpoint> ZenEndpoint = zen::UpstreamEndpoint::CreateZenEndpoint(
+ {.Name = ZenEndpointName,
+ .Urls = ZenUrls,
+ .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+ .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)});
+
+ m_UpstreamCache->RegisterEndpoint(std::move(ZenEndpoint));
+ }
+ }
+
+ // Jupiter upstream
+ if (UpstreamConfig.JupiterConfig.Url.empty() == false)
+ {
+ std::string_view EndpointName = UpstreamConfig.JupiterConfig.Name.empty() ? "Jupiter"sv : UpstreamConfig.JupiterConfig.Name;
+
+ auto Options =
+ zen::CloudCacheClientOptions{.Name = EndpointName,
+ .ServiceUrl = UpstreamConfig.JupiterConfig.Url,
+ .DdcNamespace = UpstreamConfig.JupiterConfig.DdcNamespace,
+ .BlobStoreNamespace = UpstreamConfig.JupiterConfig.Namespace,
+ .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+ .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)};
+
+ auto AuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.JupiterConfig.OAuthUrl,
+ .OAuthClientId = UpstreamConfig.JupiterConfig.OAuthClientId,
+ .OAuthClientSecret = UpstreamConfig.JupiterConfig.OAuthClientSecret,
+ .OpenIdProvider = UpstreamConfig.JupiterConfig.OpenIdProvider,
+ .AccessToken = UpstreamConfig.JupiterConfig.AccessToken};
+
+ std::unique_ptr<zen::UpstreamEndpoint> JupiterEndpoint =
+ zen::UpstreamEndpoint::CreateJupiterEndpoint(Options, AuthConfig, *m_AuthMgr);
+
+ m_UpstreamCache->RegisterEndpoint(std::move(JupiterEndpoint));
+ }
+ }
+
+ m_StructuredCacheService =
+ std::make_unique<HttpStructuredCacheService>(*m_CacheStore, *m_CidStore, m_StatsService, m_StatusService, *m_UpstreamCache);
+
+ m_Http->RegisterService(*m_StructuredCacheService);
+ m_Http->RegisterService(*m_UpstreamService);
+}
+
+#if ZEN_WITH_COMPUTE_SERVICES
+void
+ZenServer::InitializeCompute(const ZenServerOptions& ServerOptions)
+{
+ ServerOptions;
+ const ZenUpstreamCacheConfig& UpstreamConfig = ServerOptions.UpstreamCacheConfig;
+
+ // Horde compute upstream
+ if (UpstreamConfig.HordeConfig.Url.empty() == false && UpstreamConfig.HordeConfig.StorageUrl.empty() == false)
+ {
+ ZEN_INFO("instantiating compute service");
+
+ std::string_view EndpointName = UpstreamConfig.HordeConfig.Name.empty() ? "Horde"sv : UpstreamConfig.HordeConfig.Name;
+
+ auto ComputeOptions =
+ zen::CloudCacheClientOptions{.Name = EndpointName,
+ .ServiceUrl = UpstreamConfig.HordeConfig.Url,
+ .ComputeCluster = UpstreamConfig.HordeConfig.Cluster,
+ .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+ .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)};
+
+ auto ComputeAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.OAuthUrl,
+ .OAuthClientId = UpstreamConfig.HordeConfig.OAuthClientId,
+ .OAuthClientSecret = UpstreamConfig.HordeConfig.OAuthClientSecret,
+ .OpenIdProvider = UpstreamConfig.HordeConfig.OpenIdProvider,
+ .AccessToken = UpstreamConfig.HordeConfig.AccessToken};
+
+ auto StorageOptions =
+ zen::CloudCacheClientOptions{.Name = EndpointName,
+ .ServiceUrl = UpstreamConfig.HordeConfig.StorageUrl,
+ .BlobStoreNamespace = UpstreamConfig.HordeConfig.Namespace,
+ .ConnectTimeout = std::chrono::milliseconds(UpstreamConfig.ConnectTimeoutMilliseconds),
+ .Timeout = std::chrono::milliseconds(UpstreamConfig.TimeoutMilliseconds)};
+
+ auto StorageAuthConfig = zen::UpstreamAuthConfig{.OAuthUrl = UpstreamConfig.HordeConfig.StorageOAuthUrl,
+ .OAuthClientId = UpstreamConfig.HordeConfig.StorageOAuthClientId,
+ .OAuthClientSecret = UpstreamConfig.HordeConfig.StorageOAuthClientSecret,
+ .OpenIdProvider = UpstreamConfig.HordeConfig.StorageOpenIdProvider,
+ .AccessToken = UpstreamConfig.HordeConfig.StorageAccessToken};
+
+ m_HttpFunctionService = std::make_unique<zen::HttpFunctionService>(*m_CidStore,
+ ComputeOptions,
+ StorageOptions,
+ ComputeAuthConfig,
+ StorageAuthConfig,
+ *m_AuthMgr);
+ }
+ else
+ {
+ ZEN_INFO("NOT instantiating compute service (missing Horde or Storage config)");
+ }
+}
+#endif // ZEN_WITH_COMPUTE_SERVICES
+
+////////////////////////////////////////////////////////////////////////////////
+
+class ZenEntryPoint
+{
+public:
+ ZenEntryPoint(ZenServerOptions& ServerOptions);
+ ZenEntryPoint(const ZenEntryPoint&) = delete;
+ ZenEntryPoint& operator=(const ZenEntryPoint&) = delete;
+ int Run();
+
+private:
+ ZenServerOptions& m_ServerOptions;
+ zen::LockFile m_LockFile;
+};
+
+ZenEntryPoint::ZenEntryPoint(ZenServerOptions& ServerOptions) : m_ServerOptions(ServerOptions)
+{
+}
+
+#if ZEN_USE_SENTRY
+static void
+SentryLogFunction(sentry_level_t Level, const char* Message, va_list Args, [[maybe_unused]] void* Userdata)
+{
+ char LogMessageBuffer[160];
+ std::string LogMessage;
+ const char* MessagePtr = LogMessageBuffer;
+
+ int n = vsnprintf(LogMessageBuffer, sizeof LogMessageBuffer, Message, Args);
+
+ if (n >= int(sizeof LogMessageBuffer))
+ {
+ LogMessage.resize(n + 1);
+
+ n = vsnprintf(LogMessage.data(), LogMessage.size(), Message, Args);
+
+ MessagePtr = LogMessage.c_str();
+ }
+
+ switch (Level)
+ {
+ case SENTRY_LEVEL_DEBUG:
+ ConsoleLog().debug("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_INFO:
+ ConsoleLog().info("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_WARNING:
+ ConsoleLog().warn("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_ERROR:
+ ConsoleLog().error("sentry: {}", MessagePtr);
+ break;
+
+ case SENTRY_LEVEL_FATAL:
+ ConsoleLog().critical("sentry: {}", MessagePtr);
+ break;
+ }
+}
+#endif
+
+int
+ZenEntryPoint::Run()
+{
+#if ZEN_USE_SENTRY
+ std::string SentryDatabasePath = PathToUtf8(m_ServerOptions.DataDir / ".sentry-native");
+ int SentryErrorCode = 0;
+ if (m_ServerOptions.NoSentry == false)
+ {
+ sentry_options_t* SentryOptions = sentry_options_new();
+ sentry_options_set_dsn(SentryOptions, "https://[email protected]/5919284");
+ if (SentryDatabasePath.starts_with("\\\\?\\"))
+ {
+ SentryDatabasePath = SentryDatabasePath.substr(4);
+ }
+ sentry_options_set_database_path(SentryOptions, SentryDatabasePath.c_str());
+ sentry_options_set_logger(SentryOptions, SentryLogFunction, this);
+ std::string SentryAttachmentPath = m_ServerOptions.AbsLogFile.string();
+ if (SentryAttachmentPath.starts_with("\\\\?\\"))
+ {
+ SentryAttachmentPath = SentryAttachmentPath.substr(4);
+ }
+ sentry_options_add_attachment(SentryOptions, SentryAttachmentPath.c_str());
+ sentry_options_set_release(SentryOptions, ZEN_CFG_VERSION);
+ // sentry_options_set_debug(SentryOptions, 1);
+
+ SentryErrorCode = sentry_init(SentryOptions);
+
+ auto SentrySink = spdlog::create<utils::sentry_sink>("sentry");
+ zen::logging::SetErrorLog(std::move(SentrySink));
+ }
+
+ auto _ = zen::MakeGuard([] {
+ zen::logging::SetErrorLog(std::shared_ptr<spdlog::logger>());
+ sentry_close();
+ });
+#endif
+
+ auto& ServerOptions = m_ServerOptions;
+
+ try
+ {
+ // Mutual exclusion and synchronization
+ ZenServerState ServerState;
+ ServerState.Initialize();
+ ServerState.Sweep();
+
+ ZenServerState::ZenServerEntry* Entry = ServerState.Lookup(ServerOptions.BasePort);
+
+ if (Entry)
+ {
+ if (ServerOptions.OwnerPid)
+ {
+ ConsoleLog().info(
+ "Looks like there is already a process listening to this port {} (pid: {}), attaching owner pid {} to running instance",
+ ServerOptions.BasePort,
+ Entry->Pid,
+ ServerOptions.OwnerPid);
+
+ Entry->AddSponsorProcess(ServerOptions.OwnerPid);
+
+ std::exit(0);
+ }
+ else
+ {
+ ConsoleLog().warn("Exiting since there is already a process listening to port {} (pid: {})",
+ ServerOptions.BasePort,
+ Entry->Pid);
+ std::exit(1);
+ }
+ }
+
+ std::error_code Ec;
+
+ std::filesystem::path LockFilePath = ServerOptions.DataDir / ".lock";
+
+ bool IsReady = false;
+
+ auto MakeLockData = [&] {
+ CbObjectWriter Cbo;
+ Cbo << "pid" << zen::GetCurrentProcessId() << "data" << PathToUtf8(ServerOptions.DataDir) << "port" << ServerOptions.BasePort
+ << "session_id" << GetSessionId() << "ready" << IsReady;
+ return Cbo.Save();
+ };
+
+ m_LockFile.Create(LockFilePath, MakeLockData(), Ec);
+
+ if (Ec)
+ {
+ ConsoleLog().warn("ERROR: Unable to grab lock at '{}' (error: '{}')", LockFilePath, Ec.message());
+
+ std::exit(99);
+ }
+
+ InitializeLogging(ServerOptions);
+
+#if ZEN_USE_SENTRY
+ if (m_ServerOptions.NoSentry == false)
+ {
+ if (SentryErrorCode == 0)
+ {
+ ZEN_INFO("sentry initialized");
+ }
+ else
+ {
+ ZEN_WARN("sentry_init returned failure! (error code: {})", SentryErrorCode);
+ }
+ }
+#endif
+
+ MaximizeOpenFileCount();
+
+ ZEN_INFO(ZEN_APP_NAME " - using lock file at '{}'", LockFilePath);
+
+ ZEN_INFO(ZEN_APP_NAME " - starting on port {}, version '{}'", ServerOptions.BasePort, ZEN_CFG_VERSION_BUILD_STRING_FULL);
+
+ Entry = ServerState.Register(ServerOptions.BasePort);
+
+ if (ServerOptions.OwnerPid)
+ {
+ Entry->AddSponsorProcess(ServerOptions.OwnerPid);
+ }
+
+ ZenServer Server;
+ Server.SetDataRoot(ServerOptions.DataDir);
+ Server.SetContentRoot(ServerOptions.ContentDir);
+ Server.SetTestMode(ServerOptions.IsTest);
+ Server.SetDedicatedMode(ServerOptions.IsDedicated);
+
+ int EffectiveBasePort = Server.Initialize(ServerOptions, Entry);
+
+ Entry->EffectiveListenPort = uint16_t(EffectiveBasePort);
+ if (EffectiveBasePort != ServerOptions.BasePort)
+ {
+ ZEN_INFO(ZEN_APP_NAME " - relocated to base port {}", EffectiveBasePort);
+ ServerOptions.BasePort = EffectiveBasePort;
+ }
+
+ std::unique_ptr<std::thread> ShutdownThread;
+ std::unique_ptr<zen::NamedEvent> ShutdownEvent;
+
+ zen::ExtendableStringBuilder<64> ShutdownEventName;
+ ShutdownEventName << "Zen_" << ServerOptions.BasePort << "_Shutdown";
+ ShutdownEvent.reset(new zen::NamedEvent{ShutdownEventName});
+
+ // Monitor shutdown signals
+
+ ShutdownThread.reset(new std::thread{[&] {
+ ZEN_INFO("shutdown monitor thread waiting for shutdown signal '{}'", ShutdownEventName);
+ if (ShutdownEvent->Wait())
+ {
+ ZEN_INFO("shutdown signal received");
+ Server.RequestExit(0);
+ }
+ else
+ {
+ ZEN_INFO("shutdown signal wait() failed");
+ }
+ }});
+
+ // If we have a parent process, establish the mechanisms we need
+ // to be able to communicate readiness with the parent
+
+ Server.SetIsReadyFunc([&] {
+ IsReady = true;
+
+ m_LockFile.Update(MakeLockData(), Ec);
+
+ if (!ServerOptions.ChildId.empty())
+ {
+ zen::NamedEvent ParentEvent{ServerOptions.ChildId};
+ ParentEvent.Set();
+ }
+ });
+
+ Server.Run();
+ Server.Cleanup();
+
+ ShutdownEvent->Set();
+ ShutdownThread->join();
+ }
+ catch (std::exception& e)
+ {
+ SPDLOG_CRITICAL("Caught exception in main: {}", e.what());
+ }
+
+ ShutdownLogging();
+
+ return 0;
+}
+
+} // namespace zen
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_PLATFORM_WINDOWS
+
+class ZenWindowsService : public WindowsService
+{
+public:
+ ZenWindowsService(ZenServerOptions& ServerOptions) : m_EntryPoint(ServerOptions) {}
+
+ ZenWindowsService(const ZenWindowsService&) = delete;
+ ZenWindowsService& operator=(const ZenWindowsService&) = delete;
+
+ virtual int Run() override;
+
+private:
+ zen::ZenEntryPoint m_EntryPoint;
+};
+
+int
+ZenWindowsService::Run()
+{
+ return m_EntryPoint.Run();
+}
+
+#endif // ZEN_PLATFORM_WINDOWS
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+int
+test_main(int argc, char** argv)
+{
+ zen::zencore_forcelinktests();
+ zen::zenhttp_forcelinktests();
+ zen::zenstore_forcelinktests();
+ zen::z$_forcelink();
+ zen::z$service_forcelink();
+
+ zen::logging::InitializeLogging();
+ spdlog::set_level(spdlog::level::debug);
+
+ zen::MaximizeOpenFileCount();
+
+ return ZEN_RUN_TESTS(argc, argv);
+}
+#endif
+
+int
+main(int argc, char* argv[])
+{
+ using namespace zen;
+
+#if ZEN_USE_MIMALLOC
+ mi_version();
+#endif
+
+#if ZEN_WITH_TESTS
+ if (argc >= 2)
+ {
+ if (argv[1] == "test"sv)
+ {
+ return test_main(argc, argv);
+ }
+ }
+#endif
+
+ try
+ {
+ ZenServerOptions ServerOptions;
+ ParseCliOptions(argc, argv, ServerOptions);
+
+ if (!std::filesystem::exists(ServerOptions.DataDir))
+ {
+ ServerOptions.IsFirstRun = true;
+ std::filesystem::create_directories(ServerOptions.DataDir);
+ }
+
+#if ZEN_WITH_TRACE
+ if (ServerOptions.TraceHost.size())
+ {
+ TraceInit(ServerOptions.TraceHost.c_str(), TraceType::Network);
+ }
+ else if (ServerOptions.TraceFile.size())
+ {
+ TraceInit(ServerOptions.TraceFile.c_str(), TraceType::File);
+ }
+ else
+ {
+ TraceInit(nullptr, TraceType::None);
+ }
+#endif // ZEN_WITH_TRACE
+
+#if ZEN_PLATFORM_WINDOWS
+ if (ServerOptions.InstallService)
+ {
+ WindowsService::Install();
+
+ std::exit(0);
+ }
+
+ if (ServerOptions.UninstallService)
+ {
+ WindowsService::Delete();
+
+ std::exit(0);
+ }
+
+ ZenWindowsService App(ServerOptions);
+ return App.ServiceMain();
+#else
+ if (ServerOptions.InstallService || ServerOptions.UninstallService)
+ {
+ throw std::runtime_error("Service mode is not supported on this platform");
+ }
+
+ ZenEntryPoint App(ServerOptions);
+ return App.Run();
+#endif // ZEN_PLATFORM_WINDOWS
+ }
+ catch (std::exception& Ex)
+ {
+ fprintf(stderr, "ERROR: Caught exception in main: '%s'", Ex.what());
+
+ return 1;
+ }
+}
diff --git a/src/zenserver/zenserver.rc b/src/zenserver/zenserver.rc
new file mode 100644
index 000000000..6d31e2c6e
--- /dev/null
+++ b/src/zenserver/zenserver.rc
@@ -0,0 +1,105 @@
+// Microsoft Visual C++ generated resource script.
+//
+#include "resource.h"
+
+#include "zencore/config.h"
+
+#define APSTUDIO_READONLY_SYMBOLS
+/////////////////////////////////////////////////////////////////////////////
+//
+// Generated from the TEXTINCLUDE 2 resource.
+//
+#include "winres.h"
+
+/////////////////////////////////////////////////////////////////////////////
+#undef APSTUDIO_READONLY_SYMBOLS
+
+/////////////////////////////////////////////////////////////////////////////
+// English (United States) resources
+
+#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENU)
+LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_US
+#pragma code_page(1252)
+
+/////////////////////////////////////////////////////////////////////////////
+//
+// Icon
+//
+
+// Icon with lowest ID value placed first to ensure application icon
+// remains consistent on all systems.
+IDI_ICON1 ICON "..\\UnrealEngine.ico"
+
+#endif // English (United States) resources
+/////////////////////////////////////////////////////////////////////////////
+
+
+/////////////////////////////////////////////////////////////////////////////
+// English (United Kingdom) resources
+
+#if !defined(AFX_RESOURCE_DLL) || defined(AFX_TARG_ENG)
+LANGUAGE LANG_ENGLISH, SUBLANG_ENGLISH_UK
+#pragma code_page(1252)
+
+#ifdef APSTUDIO_INVOKED
+/////////////////////////////////////////////////////////////////////////////
+//
+// TEXTINCLUDE
+//
+
+1 TEXTINCLUDE
+BEGIN
+ "resource.h\0"
+END
+
+2 TEXTINCLUDE
+BEGIN
+ "#include ""winres.h""\r\n"
+ "\0"
+END
+
+3 TEXTINCLUDE
+BEGIN
+ "\r\n"
+ "\0"
+END
+
+#endif // APSTUDIO_INVOKED
+
+#endif // English (United Kingdom) resources
+/////////////////////////////////////////////////////////////////////////////
+
+
+
+#ifndef APSTUDIO_INVOKED
+/////////////////////////////////////////////////////////////////////////////
+//
+// Generated from the TEXTINCLUDE 3 resource.
+//
+
+
+/////////////////////////////////////////////////////////////////////////////
+#endif // not APSTUDIO_INVOKED
+
+VS_VERSION_INFO VERSIONINFO
+FILEVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0
+PRODUCTVERSION ZEN_CFG_VERSION_MAJOR,ZEN_CFG_VERSION_MINOR,ZEN_CFG_VERSION_ALTER,0
+{
+ BLOCK "StringFileInfo"
+ {
+ BLOCK "040904b0"
+ {
+ VALUE "CompanyName", "Epic Games Inc\0"
+ VALUE "FileDescription", "Local Storage Service for Unreal Engine\0"
+ VALUE "FileVersion", ZEN_CFG_VERSION "\0"
+ VALUE "LegalCopyright", "Copyright Epic Games Inc. All Rights Reserved\0"
+ VALUE "OriginalFilename", "zenserver.exe\0"
+ VALUE "ProductName", "Zen Storage Server\0"
+ VALUE "ProductVersion", ZEN_CFG_VERSION_BUILD_STRING_FULL "\0"
+ }
+ }
+ BLOCK "VarFileInfo"
+ {
+ VALUE "Translation", 0x409, 1200
+ }
+}
diff --git a/src/zenstore-test/xmake.lua b/src/zenstore-test/xmake.lua
new file mode 100644
index 000000000..5dbcafa3c
--- /dev/null
+++ b/src/zenstore-test/xmake.lua
@@ -0,0 +1,8 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("zenstore-test")
+ set_kind("binary")
+ add_headerfiles("**.h")
+ add_files("*.cpp")
+ add_deps("zenstore", "zencore")
+ add_packages("vcpkg::doctest")
diff --git a/src/zenstore-test/zenstore-test.cpp b/src/zenstore-test/zenstore-test.cpp
new file mode 100644
index 000000000..00c1136b6
--- /dev/null
+++ b/src/zenstore-test/zenstore-test.cpp
@@ -0,0 +1,32 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zencore/filesystem.h>
+#include <zencore/logging.h>
+#include <zencore/zencore.h>
+#include <zenstore/zenstore.h>
+
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+# include <sys/time.h>
+# include <sys/resource.h>
+# include <zencore/except.h>
+#endif
+
+#if ZEN_WITH_TESTS
+# define ZEN_TEST_WITH_RUNNER 1
+# include <zencore/testing.h>
+#endif
+
+int
+main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[])
+{
+#if ZEN_WITH_TESTS
+ zen::zenstore_forcelinktests();
+
+ zen::logging::InitializeLogging();
+ zen::MaximizeOpenFileCount();
+
+ return ZEN_RUN_TESTS(argc, argv);
+#else
+ return 0;
+#endif
+}
diff --git a/src/zenstore/blockstore.cpp b/src/zenstore/blockstore.cpp
new file mode 100644
index 000000000..5dfa10c91
--- /dev/null
+++ b/src/zenstore/blockstore.cpp
@@ -0,0 +1,1312 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenstore/blockstore.h>
+
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/timer.h>
+
+#include <algorithm>
+
+#if ZEN_WITH_TESTS
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+# include <zencore/workthreadpool.h>
+# include <random>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+BlockStoreFile::BlockStoreFile(const std::filesystem::path& BlockPath) : m_Path(BlockPath)
+{
+}
+
+BlockStoreFile::~BlockStoreFile()
+{
+ m_IoBuffer = IoBuffer();
+ m_File.Detach();
+}
+
+const std::filesystem::path&
+BlockStoreFile::GetPath() const
+{
+ return m_Path;
+}
+
+void
+BlockStoreFile::Open()
+{
+ m_File.Open(m_Path, BasicFile::Mode::kDelete);
+ void* FileHandle = m_File.Handle();
+ m_IoBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, m_File.FileSize());
+}
+
+void
+BlockStoreFile::Create(uint64_t InitialSize)
+{
+ auto ParentPath = m_Path.parent_path();
+ if (!std::filesystem::is_directory(ParentPath))
+ {
+ CreateDirectories(ParentPath);
+ }
+
+ m_File.Open(m_Path, BasicFile::Mode::kTruncateDelete);
+ void* FileHandle = m_File.Handle();
+
+ // We map our m_IoBuffer beyond the file size as we will grow it over time and want
+ // to be able to create sub-buffers of all the written range later
+ m_IoBuffer = IoBuffer(IoBuffer::File, FileHandle, 0, InitialSize);
+}
+
+uint64_t
+BlockStoreFile::FileSize()
+{
+ return m_File.FileSize();
+}
+
+void
+BlockStoreFile::MarkAsDeleteOnClose()
+{
+ m_IoBuffer.MarkAsDeleteOnClose();
+}
+
+IoBuffer
+BlockStoreFile::GetChunk(uint64_t Offset, uint64_t Size)
+{
+ return IoBuffer(m_IoBuffer, Offset, Size);
+}
+
+void
+BlockStoreFile::Read(void* Data, uint64_t Size, uint64_t FileOffset)
+{
+ m_File.Read(Data, Size, FileOffset);
+}
+
+void
+BlockStoreFile::Write(const void* Data, uint64_t Size, uint64_t FileOffset)
+{
+ m_File.Write(Data, Size, FileOffset);
+}
+
+void
+BlockStoreFile::Flush()
+{
+ m_File.Flush();
+}
+
+BasicFile&
+BlockStoreFile::GetBasicFile()
+{
+ return m_File;
+}
+
+void
+BlockStoreFile::StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun)
+{
+ m_File.StreamByteRange(FileOffset, Size, std::move(ChunkFun));
+}
+
+constexpr uint64_t ScrubSmallChunkWindowSize = 4 * 1024 * 1024;
+
+void
+BlockStore::Initialize(const std::filesystem::path& BlocksBasePath,
+ uint64_t MaxBlockSize,
+ uint64_t MaxBlockCount,
+ const std::vector<BlockStoreLocation>& KnownLocations)
+{
+ ZEN_ASSERT(MaxBlockSize > 0);
+ ZEN_ASSERT(MaxBlockCount > 0);
+ ZEN_ASSERT(IsPow2(MaxBlockCount));
+
+ m_TotalSize = 0;
+ m_BlocksBasePath = BlocksBasePath;
+ m_MaxBlockSize = MaxBlockSize;
+
+ m_ChunkBlocks.clear();
+
+ std::unordered_set<uint32_t> KnownBlocks;
+ for (const auto& Entry : KnownLocations)
+ {
+ KnownBlocks.insert(Entry.BlockIndex);
+ }
+
+ if (std::filesystem::is_directory(m_BlocksBasePath))
+ {
+ std::vector<std::filesystem::path> FoldersToScan;
+ FoldersToScan.push_back(m_BlocksBasePath);
+ size_t FolderOffset = 0;
+ while (FolderOffset < FoldersToScan.size())
+ {
+ for (const std::filesystem::directory_entry& Entry : std::filesystem::directory_iterator(FoldersToScan[FolderOffset]))
+ {
+ if (Entry.is_directory())
+ {
+ FoldersToScan.push_back(Entry.path());
+ continue;
+ }
+ if (Entry.is_regular_file())
+ {
+ const std::filesystem::path Path = Entry.path();
+ if (Path.extension() != GetBlockFileExtension())
+ {
+ continue;
+ }
+ std::string FileName = PathToUtf8(Path.stem());
+ uint32_t BlockIndex;
+ bool OK = ParseHexNumber(FileName, BlockIndex);
+ if (!OK)
+ {
+ continue;
+ }
+ if (!KnownBlocks.contains(BlockIndex))
+ {
+ // Log removing unreferenced block
+ // Clear out unused blocks
+ ZEN_DEBUG("removing unused block at '{}'", Path);
+ std::error_code Ec;
+ std::filesystem::remove(Path, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("Failed to delete file '{}' reason: '{}'", Path, Ec.message());
+ }
+ continue;
+ }
+ Ref<BlockStoreFile> BlockFile{new BlockStoreFile(Path)};
+ BlockFile->Open();
+ m_TotalSize.fetch_add(BlockFile->FileSize(), std::memory_order::relaxed);
+ m_ChunkBlocks[BlockIndex] = BlockFile;
+ }
+ }
+ ++FolderOffset;
+ }
+ }
+ else
+ {
+ CreateDirectories(m_BlocksBasePath);
+ }
+}
+
+void
+BlockStore::Close()
+{
+ RwLock::ExclusiveLockScope InsertLock(m_InsertLock);
+ m_WriteBlock = nullptr;
+ m_CurrentInsertOffset = 0;
+ m_WriteBlockIndex = 0;
+
+ m_ChunkBlocks.clear();
+ m_BlocksBasePath.clear();
+}
+
+void
+BlockStore::WriteChunk(const void* Data, uint64_t Size, uint64_t Alignment, const WriteChunkCallback& Callback)
+{
+ ZEN_ASSERT(Data != nullptr);
+ ZEN_ASSERT(Size > 0u);
+ ZEN_ASSERT(Size <= m_MaxBlockSize);
+ ZEN_ASSERT(Alignment > 0u);
+
+ RwLock::ExclusiveLockScope InsertLock(m_InsertLock);
+
+ uint32_t WriteBlockIndex = m_WriteBlockIndex.load(std::memory_order_acquire);
+ bool IsWriting = !!m_WriteBlock;
+ if (!IsWriting || (m_CurrentInsertOffset + Size) > m_MaxBlockSize)
+ {
+ if (m_WriteBlock)
+ {
+ m_WriteBlock = nullptr;
+ }
+
+ if (m_ChunkBlocks.size() == m_MaxBlockCount)
+ {
+ throw std::runtime_error(fmt::format("unable to allocate a new block in '{}'", m_BlocksBasePath));
+ }
+
+ WriteBlockIndex += IsWriting ? 1 : 0;
+ while (m_ChunkBlocks.contains(WriteBlockIndex))
+ {
+ WriteBlockIndex = (WriteBlockIndex + 1) & (m_MaxBlockCount - 1);
+ }
+
+ std::filesystem::path BlockPath = GetBlockPath(m_BlocksBasePath, WriteBlockIndex);
+
+ Ref<BlockStoreFile> NewBlockFile(new BlockStoreFile(BlockPath));
+ NewBlockFile->Create(m_MaxBlockSize);
+
+ m_ChunkBlocks[WriteBlockIndex] = NewBlockFile;
+ m_WriteBlock = NewBlockFile;
+ m_WriteBlockIndex.store(WriteBlockIndex, std::memory_order_release);
+ m_CurrentInsertOffset = 0;
+ }
+ uint64_t InsertOffset = m_CurrentInsertOffset;
+ m_CurrentInsertOffset = RoundUp(InsertOffset + Size, Alignment);
+ uint64_t AlignedWriteSize = m_CurrentInsertOffset - InsertOffset;
+ Ref<BlockStoreFile> WriteBlock = m_WriteBlock;
+ m_ActiveWriteBlocks.push_back(WriteBlockIndex);
+ InsertLock.ReleaseNow();
+
+ WriteBlock->Write(Data, Size, InsertOffset);
+ m_TotalSize.fetch_add(AlignedWriteSize, std::memory_order::relaxed);
+
+ Callback({.BlockIndex = WriteBlockIndex, .Offset = InsertOffset, .Size = Size});
+
+ {
+ RwLock::ExclusiveLockScope _(m_InsertLock);
+ m_ActiveWriteBlocks.erase(std::find(m_ActiveWriteBlocks.begin(), m_ActiveWriteBlocks.end(), WriteBlockIndex));
+ }
+}
+
+BlockStore::ReclaimSnapshotState
+BlockStore::GetReclaimSnapshotState()
+{
+ ReclaimSnapshotState State;
+ RwLock::SharedLockScope _(m_InsertLock);
+ for (uint32_t BlockIndex : m_ActiveWriteBlocks)
+ {
+ State.m_ActiveWriteBlocks.insert(BlockIndex);
+ }
+ if (m_WriteBlock)
+ {
+ State.m_ActiveWriteBlocks.insert(m_WriteBlockIndex);
+ }
+ State.BlockCount = m_ChunkBlocks.size();
+ return State;
+}
+
+IoBuffer
+BlockStore::TryGetChunk(const BlockStoreLocation& Location) const
+{
+ RwLock::SharedLockScope InsertLock(m_InsertLock);
+ if (auto BlockIt = m_ChunkBlocks.find(Location.BlockIndex); BlockIt != m_ChunkBlocks.end())
+ {
+ if (const Ref<BlockStoreFile>& Block = BlockIt->second; Block)
+ {
+ return Block->GetChunk(Location.Offset, Location.Size);
+ }
+ }
+ return IoBuffer();
+}
+
+void
+BlockStore::Flush()
+{
+ RwLock::ExclusiveLockScope _(m_InsertLock);
+ if (m_CurrentInsertOffset > 0)
+ {
+ uint32_t WriteBlockIndex = m_WriteBlockIndex.load(std::memory_order_acquire);
+ WriteBlockIndex = (WriteBlockIndex + 1) & (m_MaxBlockCount - 1);
+ m_WriteBlock = nullptr;
+ m_WriteBlockIndex.store(WriteBlockIndex, std::memory_order_release);
+ m_CurrentInsertOffset = 0;
+ }
+}
+
+void
+BlockStore::ReclaimSpace(const ReclaimSnapshotState& Snapshot,
+ const std::vector<BlockStoreLocation>& ChunkLocations,
+ const ChunkIndexArray& KeepChunkIndexes,
+ uint64_t PayloadAlignment,
+ bool DryRun,
+ const ReclaimCallback& ChangeCallback,
+ const ClaimDiskReserveCallback& DiskReserveCallback)
+{
+ if (ChunkLocations.empty())
+ {
+ return;
+ }
+ uint64_t WriteBlockTimeUs = 0;
+ uint64_t WriteBlockLongestTimeUs = 0;
+ uint64_t ReadBlockTimeUs = 0;
+ uint64_t ReadBlockLongestTimeUs = 0;
+ uint64_t TotalChunkCount = ChunkLocations.size();
+ uint64_t DeletedSize = 0;
+ uint64_t OldTotalSize = 0;
+ uint64_t NewTotalSize = 0;
+
+ uint64_t MovedCount = 0;
+ uint64_t DeletedCount = 0;
+
+ Stopwatch TotalTimer;
+ const auto _ = MakeGuard([&] {
+ ZEN_DEBUG(
+ "reclaim space for '{}' DONE after {}, write lock: {} ({}), read lock: {} ({}), collected {} bytes, deleted {} and moved "
+ "{} "
+ "of {} "
+ "chunks ({}).",
+ m_BlocksBasePath,
+ NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()),
+ NiceLatencyNs(WriteBlockTimeUs),
+ NiceLatencyNs(WriteBlockLongestTimeUs),
+ NiceLatencyNs(ReadBlockTimeUs),
+ NiceLatencyNs(ReadBlockLongestTimeUs),
+ NiceBytes(DeletedSize),
+ DeletedCount,
+ MovedCount,
+ TotalChunkCount,
+ NiceBytes(OldTotalSize));
+ });
+
+ size_t BlockCount = Snapshot.BlockCount;
+ if (BlockCount == 0)
+ {
+ ZEN_DEBUG("garbage collect for '{}' SKIPPED, no blocks to process", m_BlocksBasePath);
+ return;
+ }
+
+ std::unordered_set<size_t> KeepChunkMap;
+ KeepChunkMap.reserve(KeepChunkIndexes.size());
+ for (size_t KeepChunkIndex : KeepChunkIndexes)
+ {
+ KeepChunkMap.insert(KeepChunkIndex);
+ }
+
+ std::unordered_map<uint32_t, size_t> BlockIndexToChunkMapIndex;
+ std::vector<ChunkIndexArray> BlockKeepChunks;
+ std::vector<ChunkIndexArray> BlockDeleteChunks;
+
+ BlockIndexToChunkMapIndex.reserve(BlockCount);
+ BlockKeepChunks.reserve(BlockCount);
+ BlockDeleteChunks.reserve(BlockCount);
+ size_t GuesstimateCountPerBlock = TotalChunkCount / BlockCount / 2;
+
+ size_t DeleteCount = 0;
+ for (size_t Index = 0; Index < TotalChunkCount; ++Index)
+ {
+ const BlockStoreLocation& Location = ChunkLocations[Index];
+ OldTotalSize += Location.Size;
+ if (Snapshot.m_ActiveWriteBlocks.contains(Location.BlockIndex))
+ {
+ continue;
+ }
+
+ auto BlockIndexPtr = BlockIndexToChunkMapIndex.find(Location.BlockIndex);
+ size_t ChunkMapIndex = 0;
+ if (BlockIndexPtr == BlockIndexToChunkMapIndex.end())
+ {
+ ChunkMapIndex = BlockKeepChunks.size();
+ BlockIndexToChunkMapIndex[Location.BlockIndex] = ChunkMapIndex;
+ BlockKeepChunks.resize(ChunkMapIndex + 1);
+ BlockKeepChunks.back().reserve(GuesstimateCountPerBlock);
+ BlockDeleteChunks.resize(ChunkMapIndex + 1);
+ BlockDeleteChunks.back().reserve(GuesstimateCountPerBlock);
+ }
+ else
+ {
+ ChunkMapIndex = BlockIndexPtr->second;
+ }
+
+ if (KeepChunkMap.contains(Index))
+ {
+ ChunkIndexArray& IndexMap = BlockKeepChunks[ChunkMapIndex];
+ IndexMap.push_back(Index);
+ NewTotalSize += Location.Size;
+ continue;
+ }
+ ChunkIndexArray& IndexMap = BlockDeleteChunks[ChunkMapIndex];
+ IndexMap.push_back(Index);
+ DeleteCount++;
+ }
+
+ std::unordered_set<uint32_t> BlocksToReWrite;
+ BlocksToReWrite.reserve(BlockIndexToChunkMapIndex.size());
+ for (const auto& Entry : BlockIndexToChunkMapIndex)
+ {
+ uint32_t BlockIndex = Entry.first;
+ size_t ChunkMapIndex = Entry.second;
+ const ChunkIndexArray& ChunkMap = BlockDeleteChunks[ChunkMapIndex];
+ if (ChunkMap.empty())
+ {
+ continue;
+ }
+ BlocksToReWrite.insert(BlockIndex);
+ }
+
+ if (DryRun)
+ {
+ ZEN_DEBUG("garbage collect for '{}' DISABLED, found {} {} chunks of total {} {}",
+ m_BlocksBasePath,
+ DeleteCount,
+ NiceBytes(OldTotalSize - NewTotalSize),
+ TotalChunkCount,
+ OldTotalSize);
+ return;
+ }
+
+ Ref<BlockStoreFile> NewBlockFile;
+ try
+ {
+ uint64_t WriteOffset = 0;
+ uint32_t NewBlockIndex = 0;
+ for (uint32_t BlockIndex : BlocksToReWrite)
+ {
+ const size_t ChunkMapIndex = BlockIndexToChunkMapIndex[BlockIndex];
+
+ Ref<BlockStoreFile> OldBlockFile;
+ {
+ RwLock::SharedLockScope _i(m_InsertLock);
+ Stopwatch Timer;
+ const auto __ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ OldBlockFile = m_ChunkBlocks[BlockIndex];
+ }
+
+ if (!OldBlockFile)
+ {
+ // If the block file pointed to does not exist, move them all to deleted list
+ BlockDeleteChunks[ChunkMapIndex].insert(BlockDeleteChunks[ChunkMapIndex].end(),
+ BlockKeepChunks[ChunkMapIndex].begin(),
+ BlockKeepChunks[ChunkMapIndex].end());
+ BlockKeepChunks[ChunkMapIndex].clear();
+ }
+
+ const ChunkIndexArray& KeepMap = BlockKeepChunks[ChunkMapIndex];
+ if (KeepMap.empty())
+ {
+ const ChunkIndexArray& DeleteMap = BlockDeleteChunks[ChunkMapIndex];
+ for (size_t DeleteIndex : DeleteMap)
+ {
+ DeletedSize += ChunkLocations[DeleteIndex].Size;
+ }
+ ChangeCallback({}, DeleteMap);
+ DeletedCount += DeleteMap.size();
+ {
+ RwLock::ExclusiveLockScope _i(m_InsertLock);
+ Stopwatch Timer;
+ const auto __ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ if (OldBlockFile)
+ {
+ m_ChunkBlocks[BlockIndex] = nullptr;
+ ZEN_DEBUG("marking cas block store file '{}' for delete, block #{}", OldBlockFile->GetPath(), BlockIndex);
+ m_TotalSize.fetch_sub(OldBlockFile->FileSize(), std::memory_order::relaxed);
+ OldBlockFile->MarkAsDeleteOnClose();
+ }
+ }
+ continue;
+ }
+
+ ZEN_ASSERT(OldBlockFile);
+
+ MovedChunksArray MovedChunks;
+ std::vector<uint8_t> Chunk;
+ for (const size_t& ChunkIndex : KeepMap)
+ {
+ const BlockStoreLocation ChunkLocation = ChunkLocations[ChunkIndex];
+ Chunk.resize(ChunkLocation.Size);
+ OldBlockFile->Read(Chunk.data(), Chunk.size(), ChunkLocation.Offset);
+
+ if (!NewBlockFile || (WriteOffset + Chunk.size() > m_MaxBlockSize))
+ {
+ uint32_t NextBlockIndex = m_WriteBlockIndex.load(std::memory_order_relaxed);
+
+ if (NewBlockFile)
+ {
+ NewBlockFile->Flush();
+ NewBlockFile = nullptr;
+ }
+ {
+ ChangeCallback(MovedChunks, {});
+ MovedCount += KeepMap.size();
+ MovedChunks.clear();
+ RwLock::ExclusiveLockScope __(m_InsertLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ if (m_ChunkBlocks.size() == m_MaxBlockCount)
+ {
+ ZEN_ERROR("unable to allocate a new block in '{}', count limit {} exeeded",
+ m_BlocksBasePath,
+ static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1);
+ return;
+ }
+ while (m_ChunkBlocks.contains(NextBlockIndex))
+ {
+ NextBlockIndex = (NextBlockIndex + 1) & (m_MaxBlockCount - 1);
+ }
+ std::filesystem::path NewBlockPath = GetBlockPath(m_BlocksBasePath, NextBlockIndex);
+ NewBlockFile = new BlockStoreFile(NewBlockPath);
+ m_ChunkBlocks[NextBlockIndex] = NewBlockFile;
+ }
+
+ std::error_code Error;
+ DiskSpace Space = DiskSpaceInfo(m_BlocksBasePath, Error);
+ if (Error)
+ {
+ ZEN_ERROR("get disk space in '{}' FAILED, reason: '{}'", m_BlocksBasePath, Error.message());
+ return;
+ }
+ if (Space.Free < m_MaxBlockSize)
+ {
+ uint64_t ReclaimedSpace = DiskReserveCallback();
+ if (Space.Free + ReclaimedSpace < m_MaxBlockSize)
+ {
+ ZEN_WARN("garbage collect for '{}' FAILED, required disk space {}, free {}",
+ m_BlocksBasePath,
+ m_MaxBlockSize,
+ NiceBytes(Space.Free + ReclaimedSpace));
+ RwLock::ExclusiveLockScope _l(m_InsertLock);
+ Stopwatch Timer;
+ const auto __ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ m_ChunkBlocks.erase(NextBlockIndex);
+ return;
+ }
+
+ ZEN_INFO("using gc reserve for '{}', reclaimed {}, disk free {}",
+ m_BlocksBasePath,
+ ReclaimedSpace,
+ NiceBytes(Space.Free + ReclaimedSpace));
+ }
+ NewBlockFile->Create(m_MaxBlockSize);
+ NewBlockIndex = NextBlockIndex;
+ WriteOffset = 0;
+ }
+
+ NewBlockFile->Write(Chunk.data(), Chunk.size(), WriteOffset);
+ MovedChunks.push_back({ChunkIndex, {.BlockIndex = NewBlockIndex, .Offset = WriteOffset, .Size = Chunk.size()}});
+ uint64_t OldOffset = WriteOffset;
+ WriteOffset = RoundUp(WriteOffset + Chunk.size(), PayloadAlignment);
+ m_TotalSize.fetch_add(WriteOffset - OldOffset, std::memory_order::relaxed);
+ }
+ Chunk.clear();
+ if (NewBlockFile)
+ {
+ NewBlockFile->Flush();
+ NewBlockFile = nullptr;
+ }
+
+ const ChunkIndexArray& DeleteMap = BlockDeleteChunks[ChunkMapIndex];
+ for (size_t DeleteIndex : DeleteMap)
+ {
+ DeletedSize += ChunkLocations[DeleteIndex].Size;
+ }
+
+ ChangeCallback(MovedChunks, DeleteMap);
+ MovedCount += KeepMap.size();
+ DeletedCount += DeleteMap.size();
+ MovedChunks.clear();
+ {
+ RwLock::ExclusiveLockScope __(m_InsertLock);
+ Stopwatch Timer;
+ const auto ___ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ m_ChunkBlocks[BlockIndex] = nullptr;
+ ZEN_DEBUG("marking cas block store file '{}' for delete, block #{}", OldBlockFile->GetPath(), BlockIndex);
+ m_TotalSize.fetch_sub(OldBlockFile->FileSize(), std::memory_order::relaxed);
+ OldBlockFile->MarkAsDeleteOnClose();
+ }
+ }
+ }
+ catch (std::exception& ex)
+ {
+ ZEN_ERROR("reclaiming space for '{}' failed with: '{}'", m_BlocksBasePath, ex.what());
+ if (NewBlockFile)
+ {
+ ZEN_DEBUG("dropping incomplete cas block store file '{}'", NewBlockFile->GetPath());
+ m_TotalSize.fetch_sub(NewBlockFile->FileSize(), std::memory_order::relaxed);
+ NewBlockFile->MarkAsDeleteOnClose();
+ }
+ }
+}
+
+void
+BlockStore::IterateChunks(const std::vector<BlockStoreLocation>& ChunkLocations,
+ const IterateChunksSmallSizeCallback& SmallSizeCallback,
+ const IterateChunksLargeSizeCallback& LargeSizeCallback)
+{
+ std::vector<size_t> LocationIndexes;
+ LocationIndexes.reserve(ChunkLocations.size());
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkLocations.size(); ++ChunkIndex)
+ {
+ LocationIndexes.push_back(ChunkIndex);
+ }
+ std::sort(LocationIndexes.begin(), LocationIndexes.end(), [&](size_t IndexA, size_t IndexB) -> bool {
+ const BlockStoreLocation& LocationA = ChunkLocations[IndexA];
+ const BlockStoreLocation& LocationB = ChunkLocations[IndexB];
+ if (LocationA.BlockIndex < LocationB.BlockIndex)
+ {
+ return true;
+ }
+ else if (LocationA.BlockIndex > LocationB.BlockIndex)
+ {
+ return false;
+ }
+ return LocationA.Offset < LocationB.Offset;
+ });
+
+ IoBuffer ReadBuffer{ScrubSmallChunkWindowSize};
+ void* BufferBase = ReadBuffer.MutableData();
+
+ RwLock::SharedLockScope _(m_InsertLock);
+
+ auto GetNextRange = [&](size_t StartIndexOffset) {
+ size_t ChunkCount = 0;
+ size_t StartIndex = LocationIndexes[StartIndexOffset];
+ const BlockStoreLocation& StartLocation = ChunkLocations[StartIndex];
+ uint64_t StartOffset = StartLocation.Offset;
+ while (StartIndexOffset + ChunkCount < LocationIndexes.size())
+ {
+ size_t NextIndex = LocationIndexes[StartIndexOffset + ChunkCount];
+ const BlockStoreLocation& Location = ChunkLocations[NextIndex];
+ if (Location.BlockIndex != StartLocation.BlockIndex)
+ {
+ break;
+ }
+ if ((Location.Offset + Location.Size) - StartOffset > ScrubSmallChunkWindowSize)
+ {
+ break;
+ }
+ ++ChunkCount;
+ }
+ return ChunkCount;
+ };
+
+ size_t LocationIndexOffset = 0;
+ while (LocationIndexOffset < LocationIndexes.size())
+ {
+ size_t ChunkIndex = LocationIndexes[LocationIndexOffset];
+ const BlockStoreLocation& FirstLocation = ChunkLocations[ChunkIndex];
+
+ const Ref<BlockStoreFile>& BlockFile = m_ChunkBlocks[FirstLocation.BlockIndex];
+ if (!BlockFile)
+ {
+ while (ChunkLocations[ChunkIndex].BlockIndex == FirstLocation.BlockIndex)
+ {
+ SmallSizeCallback(ChunkIndex, nullptr, 0);
+ LocationIndexOffset++;
+ if (LocationIndexOffset == LocationIndexes.size())
+ {
+ break;
+ }
+ ChunkIndex = LocationIndexes[LocationIndexOffset];
+ }
+ continue;
+ }
+ size_t BlockSize = BlockFile->FileSize();
+ size_t RangeCount = GetNextRange(LocationIndexOffset);
+ if (RangeCount > 0)
+ {
+ size_t LastChunkIndex = LocationIndexes[LocationIndexOffset + RangeCount - 1];
+ const BlockStoreLocation& LastLocation = ChunkLocations[LastChunkIndex];
+ uint64_t Size = LastLocation.Offset + LastLocation.Size - FirstLocation.Offset;
+ BlockFile->Read(BufferBase, Size, FirstLocation.Offset);
+ for (size_t RangeIndex = 0; RangeIndex < RangeCount; ++RangeIndex)
+ {
+ size_t NextChunkIndex = LocationIndexes[LocationIndexOffset + RangeIndex];
+ const BlockStoreLocation& ChunkLocation = ChunkLocations[NextChunkIndex];
+ if (ChunkLocation.Size == 0 || (ChunkLocation.Offset + ChunkLocation.Size > BlockSize))
+ {
+ SmallSizeCallback(NextChunkIndex, nullptr, 0);
+ continue;
+ }
+ void* BufferPtr = &((char*)BufferBase)[ChunkLocation.Offset - FirstLocation.Offset];
+ SmallSizeCallback(NextChunkIndex, BufferPtr, ChunkLocation.Size);
+ }
+ LocationIndexOffset += RangeCount;
+ continue;
+ }
+ if (FirstLocation.Size == 0 || (FirstLocation.Offset + FirstLocation.Size > BlockSize))
+ {
+ SmallSizeCallback(ChunkIndex, nullptr, 0);
+ LocationIndexOffset++;
+ continue;
+ }
+ LargeSizeCallback(ChunkIndex, *BlockFile.Get(), FirstLocation.Offset, FirstLocation.Size);
+ LocationIndexOffset++;
+ }
+}
+
+const char*
+BlockStore::GetBlockFileExtension()
+{
+ return ".ucas";
+}
+
+std::filesystem::path
+BlockStore::GetBlockPath(const std::filesystem::path& BlocksBasePath, const uint32_t BlockIndex)
+{
+ ExtendablePathBuilder<256> Path;
+
+ char BlockHexString[9];
+ ToHexNumber(BlockIndex, BlockHexString);
+
+ Path.Append(BlocksBasePath);
+ Path.AppendSeparator();
+ Path.AppendAsciiRange(BlockHexString, BlockHexString + 4);
+ Path.AppendSeparator();
+ Path.Append(BlockHexString);
+ Path.Append(GetBlockFileExtension());
+ return Path.ToPath();
+}
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("blockstore.blockstoredisklocation")
+{
+ BlockStoreLocation Zero = BlockStoreLocation{.BlockIndex = 0, .Offset = 0, .Size = 0};
+ CHECK(Zero == BlockStoreDiskLocation(Zero, 4).Get(4));
+
+ BlockStoreLocation MaxBlockIndex = BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex, .Offset = 0, .Size = 0};
+ CHECK(MaxBlockIndex == BlockStoreDiskLocation(MaxBlockIndex, 4).Get(4));
+
+ BlockStoreLocation MaxOffset = BlockStoreLocation{.BlockIndex = 0, .Offset = BlockStoreDiskLocation::MaxOffset * 4, .Size = 0};
+ CHECK(MaxOffset == BlockStoreDiskLocation(MaxOffset, 4).Get(4));
+
+ BlockStoreLocation MaxSize = BlockStoreLocation{.BlockIndex = 0, .Offset = 0, .Size = std::numeric_limits<uint32_t>::max()};
+ CHECK(MaxSize == BlockStoreDiskLocation(MaxSize, 4).Get(4));
+
+ BlockStoreLocation MaxBlockIndexAndOffset =
+ BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex, .Offset = BlockStoreDiskLocation::MaxOffset * 4, .Size = 0};
+ CHECK(MaxBlockIndexAndOffset == BlockStoreDiskLocation(MaxBlockIndexAndOffset, 4).Get(4));
+
+ BlockStoreLocation MaxAll = BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex,
+ .Offset = BlockStoreDiskLocation::MaxOffset * 4,
+ .Size = std::numeric_limits<uint32_t>::max()};
+ CHECK(MaxAll == BlockStoreDiskLocation(MaxAll, 4).Get(4));
+
+ BlockStoreLocation MaxAll4096 = BlockStoreLocation{.BlockIndex = BlockStoreDiskLocation::MaxBlockIndex,
+ .Offset = BlockStoreDiskLocation::MaxOffset * 4096,
+ .Size = std::numeric_limits<uint32_t>::max()};
+ CHECK(MaxAll4096 == BlockStoreDiskLocation(MaxAll4096, 4096).Get(4096));
+
+ BlockStoreLocation Middle = BlockStoreLocation{.BlockIndex = (BlockStoreDiskLocation::MaxBlockIndex) / 2,
+ .Offset = ((BlockStoreDiskLocation::MaxOffset) / 2) * 4,
+ .Size = std::numeric_limits<uint32_t>::max() / 2};
+ CHECK(Middle == BlockStoreDiskLocation(Middle, 4).Get(4));
+}
+
+TEST_CASE("blockstore.blockfile")
+{
+ ScopedTemporaryDirectory TempDir;
+ auto RootDirectory = TempDir.Path() / "blocks";
+ CreateDirectories(RootDirectory);
+
+ {
+ BlockStoreFile File1(RootDirectory / "1");
+ File1.Create(16384);
+ CHECK(File1.FileSize() == 0);
+ File1.Write("data", 5, 0);
+ IoBuffer DataChunk = File1.GetChunk(0, 5);
+ File1.Write("boop", 5, 5);
+ IoBuffer BoopChunk = File1.GetChunk(5, 5);
+ const char* Data = static_cast<const char*>(DataChunk.GetData());
+ CHECK(std::string(Data) == "data");
+ const char* Boop = static_cast<const char*>(BoopChunk.GetData());
+ CHECK(std::string(Boop) == "boop");
+ File1.Flush();
+ CHECK(File1.FileSize() == 10);
+ }
+ {
+ BlockStoreFile File1(RootDirectory / "1");
+ File1.Open();
+
+ char DataRaw[5];
+ File1.Read(DataRaw, 5, 0);
+ CHECK(std::string(DataRaw) == "data");
+ IoBuffer DataChunk = File1.GetChunk(0, 5);
+
+ char BoopRaw[5];
+ File1.Read(BoopRaw, 5, 5);
+ CHECK(std::string(BoopRaw) == "boop");
+
+ IoBuffer BoopChunk = File1.GetChunk(5, 5);
+ const char* Data = static_cast<const char*>(DataChunk.GetData());
+ CHECK(std::string(Data) == "data");
+ const char* Boop = static_cast<const char*>(BoopChunk.GetData());
+ CHECK(std::string(Boop) == "boop");
+ }
+
+ {
+ IoBuffer DataChunk;
+ IoBuffer BoopChunk;
+
+ {
+ BlockStoreFile File1(RootDirectory / "1");
+ File1.Open();
+ DataChunk = File1.GetChunk(0, 5);
+ BoopChunk = File1.GetChunk(5, 5);
+ }
+
+ CHECK(std::filesystem::exists(RootDirectory / "1"));
+
+ const char* Data = static_cast<const char*>(DataChunk.GetData());
+ CHECK(std::string(Data) == "data");
+ const char* Boop = static_cast<const char*>(BoopChunk.GetData());
+ CHECK(std::string(Boop) == "boop");
+ }
+ CHECK(std::filesystem::exists(RootDirectory / "1"));
+
+ {
+ IoBuffer DataChunk;
+ IoBuffer BoopChunk;
+
+ {
+ BlockStoreFile File1(RootDirectory / "1");
+ File1.Open();
+ File1.MarkAsDeleteOnClose();
+ DataChunk = File1.GetChunk(0, 5);
+ BoopChunk = File1.GetChunk(5, 5);
+ }
+
+ const char* Data = static_cast<const char*>(DataChunk.GetData());
+ CHECK(std::string(Data) == "data");
+ const char* Boop = static_cast<const char*>(BoopChunk.GetData());
+ CHECK(std::string(Boop) == "boop");
+ }
+ CHECK(!std::filesystem::exists(RootDirectory / "1"));
+}
+
+namespace blockstore::impl {
+ BlockStoreLocation WriteStringAsChunk(BlockStore& Store, std::string_view String, size_t PayloadAlignment)
+ {
+ BlockStoreLocation Location;
+ Store.WriteChunk(String.data(), String.length(), PayloadAlignment, [&](const BlockStoreLocation& L) { Location = L; });
+ CHECK(Location.Size == String.length());
+ return Location;
+ };
+
+ std::string ReadChunkAsString(BlockStore& Store, const BlockStoreLocation& Location)
+ {
+ IoBuffer ChunkData = Store.TryGetChunk(Location);
+ if (!ChunkData)
+ {
+ return "";
+ }
+ std::string AsString((const char*)ChunkData.Data(), ChunkData.Size());
+ return AsString;
+ };
+
+ std::vector<std::filesystem::path> GetDirectoryContent(std::filesystem::path RootDir, bool Files, bool Directories)
+ {
+ DirectoryContent DirectoryContent;
+ GetDirectoryContent(RootDir,
+ DirectoryContent::RecursiveFlag | (Files ? DirectoryContent::IncludeFilesFlag : 0) |
+ (Directories ? DirectoryContent::IncludeDirsFlag : 0),
+ DirectoryContent);
+ std::vector<std::filesystem::path> Result;
+ Result.insert(Result.end(), DirectoryContent.Directories.begin(), DirectoryContent.Directories.end());
+ Result.insert(Result.end(), DirectoryContent.Files.begin(), DirectoryContent.Files.end());
+ return Result;
+ };
+
+ static IoBuffer CreateChunk(uint64_t Size)
+ {
+ static std::random_device rd;
+ static std::mt19937 g(rd());
+
+ std::vector<uint8_t> Values;
+ Values.resize(Size);
+ for (size_t Idx = 0; Idx < Size; ++Idx)
+ {
+ Values[Idx] = static_cast<uint8_t>(Idx);
+ }
+ std::shuffle(Values.begin(), Values.end(), g);
+
+ return IoBufferBuilder::MakeCloneFromMemory(Values.data(), Values.size());
+ }
+} // namespace blockstore::impl
+
+TEST_CASE("blockstore.chunks")
+{
+ using namespace blockstore::impl;
+
+ ScopedTemporaryDirectory TempDir;
+ auto RootDirectory = TempDir.Path();
+
+ BlockStore Store;
+ Store.Initialize(RootDirectory, 128, 1024, {});
+ IoBuffer BadChunk = Store.TryGetChunk({.BlockIndex = 0, .Offset = 0, .Size = 512});
+ CHECK(!BadChunk);
+
+ std::string FirstChunkData = "This is the data of the first chunk that we will write";
+ BlockStoreLocation FirstChunkLocation = WriteStringAsChunk(Store, FirstChunkData, 4);
+ std::string SecondChunkData = "This is the data for the second chunk that we will write";
+ BlockStoreLocation SecondChunkLocation = WriteStringAsChunk(Store, SecondChunkData, 4);
+
+ CHECK(ReadChunkAsString(Store, FirstChunkLocation) == FirstChunkData);
+ CHECK(ReadChunkAsString(Store, SecondChunkLocation) == SecondChunkData);
+
+ std::string ThirdChunkData =
+ "This is a much longer string that will not fit in the first block so it should be placed in the second block";
+ BlockStoreLocation ThirdChunkLocation = WriteStringAsChunk(Store, ThirdChunkData, 4);
+ CHECK(ThirdChunkLocation.BlockIndex != FirstChunkLocation.BlockIndex);
+
+ CHECK(ReadChunkAsString(Store, FirstChunkLocation) == FirstChunkData);
+ CHECK(ReadChunkAsString(Store, SecondChunkLocation) == SecondChunkData);
+ CHECK(ReadChunkAsString(Store, ThirdChunkLocation) == ThirdChunkData);
+}
+
+TEST_CASE("blockstore.clean.stray.blocks")
+{
+ using namespace blockstore::impl;
+
+ ScopedTemporaryDirectory TempDir;
+ auto RootDirectory = TempDir.Path();
+
+ BlockStore Store;
+ Store.Initialize(RootDirectory / "store", 128, 1024, {});
+
+ std::string FirstChunkData = "This is the data of the first chunk that we will write";
+ BlockStoreLocation FirstChunkLocation = WriteStringAsChunk(Store, FirstChunkData, 4);
+ std::string SecondChunkData = "This is the data for the second chunk that we will write";
+ BlockStoreLocation SecondChunkLocation = WriteStringAsChunk(Store, SecondChunkData, 4);
+ std::string ThirdChunkData =
+ "This is a much longer string that will not fit in the first block so it should be placed in the second block";
+ WriteStringAsChunk(Store, ThirdChunkData, 4);
+
+ Store.Close();
+
+ // Not referencing the second block means that we should be deleted
+ Store.Initialize(RootDirectory / "store", 128, 1024, {FirstChunkLocation, SecondChunkLocation});
+
+ CHECK(GetDirectoryContent(RootDirectory / "store", true, false).size() == 1);
+}
+
+TEST_CASE("blockstore.flush.forces.new.block")
+{
+ using namespace blockstore::impl;
+
+ ScopedTemporaryDirectory TempDir;
+ auto RootDirectory = TempDir.Path();
+
+ BlockStore Store;
+ Store.Initialize(RootDirectory / "store", 128, 1024, {});
+
+ std::string FirstChunkData = "This is the data of the first chunk that we will write";
+ WriteStringAsChunk(Store, FirstChunkData, 4);
+ Store.Flush();
+ std::string SecondChunkData = "This is the data for the second chunk that we will write";
+ WriteStringAsChunk(Store, SecondChunkData, 4);
+ Store.Flush();
+ std::string ThirdChunkData =
+ "This is a much longer string that will not fit in the first block so it should be placed in the second block";
+ WriteStringAsChunk(Store, ThirdChunkData, 4);
+
+ CHECK(GetDirectoryContent(RootDirectory / "store", true, false).size() == 3);
+}
+
+TEST_CASE("blockstore.iterate.chunks")
+{
+ using namespace blockstore::impl;
+
+ ScopedTemporaryDirectory TempDir;
+ auto RootDirectory = TempDir.Path();
+
+ BlockStore Store;
+ Store.Initialize(RootDirectory / "store", ScrubSmallChunkWindowSize * 2, 1024, {});
+ IoBuffer BadChunk = Store.TryGetChunk({.BlockIndex = 0, .Offset = 0, .Size = 512});
+ CHECK(!BadChunk);
+
+ std::string FirstChunkData = "This is the data of the first chunk that we will write";
+ BlockStoreLocation FirstChunkLocation = WriteStringAsChunk(Store, FirstChunkData, 4);
+
+ std::string SecondChunkData = "This is the data for the second chunk that we will write";
+ BlockStoreLocation SecondChunkLocation = WriteStringAsChunk(Store, SecondChunkData, 4);
+ Store.Flush();
+
+ std::string VeryLargeChunk(ScrubSmallChunkWindowSize * 2, 'L');
+ BlockStoreLocation VeryLargeChunkLocation = WriteStringAsChunk(Store, VeryLargeChunk, 4);
+
+ BlockStoreLocation BadLocationZeroSize = {.BlockIndex = 0, .Offset = 0, .Size = 0};
+ BlockStoreLocation BadLocationOutOfRange = {.BlockIndex = 0,
+ .Offset = ScrubSmallChunkWindowSize,
+ .Size = ScrubSmallChunkWindowSize * 2};
+ BlockStoreLocation BadBlockIndex = {.BlockIndex = 0xfffff, .Offset = 1024, .Size = 1024};
+
+ Store.IterateChunks(
+ {FirstChunkLocation, SecondChunkLocation, VeryLargeChunkLocation, BadLocationZeroSize, BadLocationOutOfRange, BadBlockIndex},
+ [&](size_t ChunkIndex, const void* Data, uint64_t Size) {
+ switch (ChunkIndex)
+ {
+ case 0:
+ CHECK(Data);
+ CHECK(Size == FirstChunkData.size());
+ CHECK(std::string((const char*)Data, Size) == FirstChunkData);
+ break;
+ case 1:
+ CHECK(Data);
+ CHECK(Size == SecondChunkData.size());
+ CHECK(std::string((const char*)Data, Size) == SecondChunkData);
+ break;
+ case 2:
+ CHECK(false);
+ break;
+ case 3:
+ CHECK(!Data);
+ break;
+ case 4:
+ CHECK(!Data);
+ break;
+ case 5:
+ CHECK(!Data);
+ break;
+ default:
+ CHECK(false);
+ break;
+ }
+ },
+ [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) {
+ switch (ChunkIndex)
+ {
+ case 0:
+ case 1:
+ CHECK(false);
+ break;
+ case 2:
+ {
+ CHECK(Size == VeryLargeChunk.size());
+ char* Buffer = new char[Size];
+ size_t HashOffset = 0;
+ File.StreamByteRange(Offset, Size, [&](const void* Data, uint64_t Size) {
+ memcpy(&Buffer[HashOffset], Data, Size);
+ HashOffset += Size;
+ });
+ CHECK(memcmp(Buffer, VeryLargeChunk.data(), Size) == 0);
+ delete[] Buffer;
+ }
+ break;
+ case 3:
+ CHECK(false);
+ break;
+ case 4:
+ CHECK(false);
+ break;
+ case 5:
+ CHECK(false);
+ break;
+ default:
+ CHECK(false);
+ break;
+ }
+ });
+}
+
+TEST_CASE("blockstore.reclaim.space")
+{
+ using namespace blockstore::impl;
+
+ ScopedTemporaryDirectory TempDir;
+ auto RootDirectory = TempDir.Path();
+
+ BlockStore Store;
+ Store.Initialize(RootDirectory / "store", 512, 1024, {});
+
+ constexpr size_t ChunkCount = 200;
+ constexpr size_t Alignment = 8;
+ std::vector<BlockStoreLocation> ChunkLocations;
+ std::vector<IoHash> ChunkHashes;
+ ChunkLocations.reserve(ChunkCount);
+ ChunkHashes.reserve(ChunkCount);
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ IoBuffer Chunk = CreateChunk(57 + ChunkIndex);
+
+ Store.WriteChunk(Chunk.Data(), Chunk.Size(), Alignment, [&](const BlockStoreLocation& L) { ChunkLocations.push_back(L); });
+ ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size()));
+ }
+
+ std::vector<size_t> ChunksToKeep;
+ ChunksToKeep.reserve(ChunkLocations.size());
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ ChunksToKeep.push_back(ChunkIndex);
+ }
+
+ Store.Flush();
+ BlockStore::ReclaimSnapshotState State1 = Store.GetReclaimSnapshotState();
+ Store.ReclaimSpace(State1, ChunkLocations, ChunksToKeep, Alignment, true);
+
+ // If we keep all the chunks we should not get any callbacks on moved/deleted stuff
+ Store.ReclaimSpace(
+ State1,
+ ChunkLocations,
+ ChunksToKeep,
+ Alignment,
+ false,
+ [](const BlockStore::MovedChunksArray&, const BlockStore::ChunkIndexArray&) { CHECK(false); },
+ []() {
+ CHECK(false);
+ return 0;
+ });
+
+ size_t DeleteChunkCount = 38;
+ ChunksToKeep.clear();
+ for (size_t ChunkIndex = DeleteChunkCount; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ ChunksToKeep.push_back(ChunkIndex);
+ }
+
+ std::vector<BlockStoreLocation> NewChunkLocations = ChunkLocations;
+ size_t MovedChunkCount = 0;
+ size_t DeletedChunkCount = 0;
+ Store.ReclaimSpace(
+ State1,
+ ChunkLocations,
+ ChunksToKeep,
+ Alignment,
+ false,
+ [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& DeletedChunks) {
+ for (const auto& MovedChunk : MovedChunks)
+ {
+ CHECK(MovedChunk.first >= DeleteChunkCount);
+ NewChunkLocations[MovedChunk.first] = MovedChunk.second;
+ }
+ MovedChunkCount += MovedChunks.size();
+ for (size_t DeletedIndex : DeletedChunks)
+ {
+ CHECK(DeletedIndex < DeleteChunkCount);
+ }
+ DeletedChunkCount += DeletedChunks.size();
+ },
+ []() {
+ CHECK(false);
+ return 0;
+ });
+ CHECK(MovedChunkCount <= DeleteChunkCount);
+ CHECK(DeletedChunkCount == DeleteChunkCount);
+ ChunkLocations = std::vector<BlockStoreLocation>(NewChunkLocations.begin() + DeleteChunkCount, NewChunkLocations.end());
+
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ IoBuffer ChunkBlock = Store.TryGetChunk(NewChunkLocations[ChunkIndex]);
+ if (ChunkIndex >= DeleteChunkCount)
+ {
+ IoBuffer VerifyChunk = Store.TryGetChunk(NewChunkLocations[ChunkIndex]);
+ CHECK(VerifyChunk);
+ IoHash VerifyHash = IoHash::HashBuffer(VerifyChunk.Data(), VerifyChunk.Size());
+ CHECK(VerifyHash == ChunkHashes[ChunkIndex]);
+ }
+ }
+
+ NewChunkLocations = ChunkLocations;
+ MovedChunkCount = 0;
+ DeletedChunkCount = 0;
+ Store.ReclaimSpace(
+ State1,
+ ChunkLocations,
+ {},
+ Alignment,
+ false,
+ [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& DeletedChunks) {
+ CHECK(MovedChunks.empty());
+ DeletedChunkCount += DeletedChunks.size();
+ },
+ []() {
+ CHECK(false);
+ return 0;
+ });
+ CHECK(DeletedChunkCount == ChunkCount - DeleteChunkCount);
+}
+
+TEST_CASE("blockstore.thread.read.write")
+{
+ using namespace blockstore::impl;
+
+ ScopedTemporaryDirectory TempDir;
+ auto RootDirectory = TempDir.Path();
+
+ BlockStore Store;
+ Store.Initialize(RootDirectory / "store", 1088, 1024, {});
+
+ constexpr size_t ChunkCount = 1000;
+ constexpr size_t Alignment = 8;
+ std::vector<IoBuffer> Chunks;
+ std::vector<IoHash> ChunkHashes;
+ Chunks.reserve(ChunkCount);
+ ChunkHashes.reserve(ChunkCount);
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ IoBuffer Chunk = CreateChunk(57 + ChunkIndex / 2);
+ Chunks.push_back(Chunk);
+ ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size()));
+ }
+
+ std::vector<BlockStoreLocation> ChunkLocations;
+ ChunkLocations.resize(ChunkCount);
+
+ WorkerThreadPool WorkerPool(8);
+ std::atomic<size_t> WorkCompleted = 0;
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ WorkerPool.ScheduleWork([&Store, ChunkIndex, &Chunks, &ChunkLocations, &WorkCompleted]() {
+ IoBuffer& Chunk = Chunks[ChunkIndex];
+ Store.WriteChunk(Chunk.Data(), Chunk.Size(), Alignment, [&](const BlockStoreLocation& L) { ChunkLocations[ChunkIndex] = L; });
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+
+ WorkCompleted = 0;
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ WorkerPool.ScheduleWork([&Store, ChunkIndex, &ChunkLocations, &ChunkHashes, &WorkCompleted]() {
+ IoBuffer VerifyChunk = Store.TryGetChunk(ChunkLocations[ChunkIndex]);
+ CHECK(VerifyChunk);
+ IoHash VerifyHash = IoHash::HashBuffer(VerifyChunk.Data(), VerifyChunk.Size());
+ CHECK(VerifyHash == ChunkHashes[ChunkIndex]);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+
+ std::vector<BlockStoreLocation> SecondChunkLocations;
+ SecondChunkLocations.resize(ChunkCount);
+ WorkCompleted = 0;
+ for (size_t ChunkIndex = 0; ChunkIndex < ChunkCount; ++ChunkIndex)
+ {
+ WorkerPool.ScheduleWork([&Store, ChunkIndex, &Chunks, &SecondChunkLocations, &WorkCompleted]() {
+ IoBuffer& Chunk = Chunks[ChunkIndex];
+ Store.WriteChunk(Chunk.Data(), Chunk.Size(), Alignment, [&](const BlockStoreLocation& L) {
+ SecondChunkLocations[ChunkIndex] = L;
+ });
+ WorkCompleted.fetch_add(1);
+ });
+ WorkerPool.ScheduleWork([&Store, ChunkIndex, &ChunkLocations, &ChunkHashes, &WorkCompleted]() {
+ IoBuffer VerifyChunk = Store.TryGetChunk(ChunkLocations[ChunkIndex]);
+ CHECK(VerifyChunk);
+ IoHash VerifyHash = IoHash::HashBuffer(VerifyChunk.Data(), VerifyChunk.Size());
+ CHECK(VerifyHash == ChunkHashes[ChunkIndex]);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size() * 2)
+ {
+ Sleep(1);
+ }
+}
+
+#endif
+
+void
+blockstore_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenstore/cas.cpp b/src/zenstore/cas.cpp
new file mode 100644
index 000000000..fdec78c60
--- /dev/null
+++ b/src/zenstore/cas.cpp
@@ -0,0 +1,355 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "cas.h"
+
+#include "compactcas.h"
+#include "filecas.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/except.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+#include <zencore/thread.h>
+#include <zencore/trace.h>
+#include <zencore/uid.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/gc.h>
+#include <zenstore/scrubcontext.h>
+
+#include <gsl/gsl-lite.hpp>
+
+#include <filesystem>
+#include <functional>
+#include <unordered_map>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+/**
+ * CAS store implementation
+ *
+ * Uses a basic strategy of splitting payloads by size, to improve ability to reclaim space
+ * quickly for unused large chunks and to maintain locality for small chunks which are
+ * frequently accessed together.
+ *
+ */
+class CasImpl : public CasStore
+{
+public:
+ CasImpl(GcManager& Gc);
+ virtual ~CasImpl();
+
+ virtual void Initialize(const CidStoreConfiguration& InConfig) override;
+ virtual CasStore::InsertResult InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, InsertMode Mode) override;
+ virtual IoBuffer FindChunk(const IoHash& ChunkHash) override;
+ virtual bool ContainsChunk(const IoHash& ChunkHash) override;
+ virtual void FilterChunks(HashKeySet& InOutChunks) override;
+ virtual void Flush() override;
+ virtual void Scrub(ScrubContext& Ctx) override;
+ virtual void GarbageCollect(GcContext& GcCtx) override;
+ virtual CidStoreSize TotalSize() const override;
+
+private:
+ CasContainerStrategy m_TinyStrategy;
+ CasContainerStrategy m_SmallStrategy;
+ FileCasStrategy m_LargeStrategy;
+ CbObject m_ManifestObject;
+
+ enum class StorageScheme
+ {
+ Legacy = 0,
+ WithCbManifest = 1
+ };
+
+ StorageScheme m_StorageScheme = StorageScheme::Legacy;
+
+ bool OpenOrCreateManifest();
+ void UpdateManifest();
+};
+
+CasImpl::CasImpl(GcManager& Gc) : m_TinyStrategy(Gc), m_SmallStrategy(Gc), m_LargeStrategy(Gc)
+{
+}
+
+CasImpl::~CasImpl()
+{
+}
+
+void
+CasImpl::Initialize(const CidStoreConfiguration& InConfig)
+{
+ m_Config = InConfig;
+
+ ZEN_INFO("initializing CAS pool at '{}'", m_Config.RootDirectory);
+
+ // Ensure root directory exists - create if it doesn't exist already
+
+ std::filesystem::create_directories(m_Config.RootDirectory);
+
+ // Open or create manifest
+
+ const bool IsNewStore = OpenOrCreateManifest();
+
+ // Initialize payload storage
+
+ m_LargeStrategy.Initialize(m_Config.RootDirectory, IsNewStore);
+ m_TinyStrategy.Initialize(m_Config.RootDirectory, "tobs", 1u << 28, 16, IsNewStore); // 256 Mb per block
+ m_SmallStrategy.Initialize(m_Config.RootDirectory, "sobs", 1u << 30, 4096, IsNewStore); // 1 Gb per block
+}
+
+bool
+CasImpl::OpenOrCreateManifest()
+{
+ bool IsNewStore = false;
+
+ std::filesystem::path ManifestPath = m_Config.RootDirectory;
+ ManifestPath /= ".ucas_root";
+
+ std::error_code Ec;
+ BasicFile ManifestFile;
+ ManifestFile.Open(ManifestPath.c_str(), BasicFile::Mode::kRead, Ec);
+
+ bool ManifestIsOk = false;
+
+ if (Ec)
+ {
+ if (Ec == std::errc::no_such_file_or_directory)
+ {
+ IsNewStore = true;
+ }
+ }
+ else
+ {
+ IoBuffer ManifestBuffer = ManifestFile.ReadAll();
+ ManifestFile.Close();
+
+ if (ManifestBuffer.Size() > 0 && ManifestBuffer.Data<uint8_t>()[0] == '#')
+ {
+ // Old-style manifest, does not contain any useful information, so we may as well update it
+ }
+ else
+ {
+ CbObject Manifest{SharedBuffer(ManifestBuffer)};
+ CbValidateError ValidationResult = ValidateCompactBinary(ManifestBuffer, CbValidateMode::All);
+
+ if (ValidationResult == CbValidateError::None)
+ {
+ if (Manifest["id"])
+ {
+ ManifestIsOk = true;
+ }
+ }
+ else
+ {
+ ZEN_WARN("Store manifest validation failed: {:#x}, will generate new manifest to recover", uint32_t(ValidationResult));
+ }
+
+ if (ManifestIsOk)
+ {
+ m_ManifestObject = std::move(Manifest);
+ }
+ }
+ }
+
+ if (!ManifestIsOk)
+ {
+ UpdateManifest();
+ }
+
+ return IsNewStore;
+}
+
+void
+CasImpl::UpdateManifest()
+{
+ if (!m_ManifestObject)
+ {
+ CbObjectWriter Cbo;
+ Cbo << "id" << zen::Oid::NewOid() << "created" << DateTime::Now();
+ m_ManifestObject = Cbo.Save();
+ }
+
+ // Write manifest to file
+
+ std::filesystem::path ManifestPath = m_Config.RootDirectory;
+ ManifestPath /= ".ucas_root";
+
+ // This will throw on failure
+
+ ZEN_TRACE("Writing new manifest to '{}'", ManifestPath);
+
+ BasicFile Marker;
+ Marker.Open(ManifestPath.c_str(), BasicFile::Mode::kTruncate);
+ Marker.Write(m_ManifestObject.GetBuffer(), 0);
+}
+
+CasStore::InsertResult
+CasImpl::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, InsertMode Mode)
+{
+ ZEN_TRACE_CPU("CAS::InsertChunk");
+
+ const uint64_t ChunkSize = Chunk.Size();
+
+ if (ChunkSize < m_Config.TinyValueThreshold)
+ {
+ ZEN_ASSERT(ChunkSize);
+
+ return m_TinyStrategy.InsertChunk(Chunk, ChunkHash);
+ }
+ else if (ChunkSize < m_Config.HugeValueThreshold)
+ {
+ return m_SmallStrategy.InsertChunk(Chunk, ChunkHash);
+ }
+
+ return m_LargeStrategy.InsertChunk(Chunk, ChunkHash, Mode);
+}
+
+IoBuffer
+CasImpl::FindChunk(const IoHash& ChunkHash)
+{
+ ZEN_TRACE_CPU("CAS::FindChunk");
+
+ if (IoBuffer Found = m_SmallStrategy.FindChunk(ChunkHash))
+ {
+ return Found;
+ }
+
+ if (IoBuffer Found = m_TinyStrategy.FindChunk(ChunkHash))
+ {
+ return Found;
+ }
+
+ if (IoBuffer Found = m_LargeStrategy.FindChunk(ChunkHash))
+ {
+ return Found;
+ }
+
+ // Not found
+ return IoBuffer{};
+}
+
+bool
+CasImpl::ContainsChunk(const IoHash& ChunkHash)
+{
+ return m_SmallStrategy.HaveChunk(ChunkHash) || m_TinyStrategy.HaveChunk(ChunkHash) || m_LargeStrategy.HaveChunk(ChunkHash);
+}
+
+void
+CasImpl::FilterChunks(HashKeySet& InOutChunks)
+{
+ m_SmallStrategy.FilterChunks(InOutChunks);
+ m_TinyStrategy.FilterChunks(InOutChunks);
+ m_LargeStrategy.FilterChunks(InOutChunks);
+}
+
+void
+CasImpl::Flush()
+{
+ m_SmallStrategy.Flush();
+ m_TinyStrategy.Flush();
+ m_LargeStrategy.Flush();
+}
+
+void
+CasImpl::Scrub(ScrubContext& Ctx)
+{
+ if (m_LastScrubTime == Ctx.ScrubTimestamp())
+ {
+ return;
+ }
+
+ m_LastScrubTime = Ctx.ScrubTimestamp();
+
+ m_SmallStrategy.Scrub(Ctx);
+ m_TinyStrategy.Scrub(Ctx);
+ m_LargeStrategy.Scrub(Ctx);
+}
+
+void
+CasImpl::GarbageCollect(GcContext& GcCtx)
+{
+ m_SmallStrategy.CollectGarbage(GcCtx);
+ m_TinyStrategy.CollectGarbage(GcCtx);
+ m_LargeStrategy.CollectGarbage(GcCtx);
+}
+
+CidStoreSize
+CasImpl::TotalSize() const
+{
+ const uint64_t Tiny = m_TinyStrategy.StorageSize().DiskSize;
+ const uint64_t Small = m_SmallStrategy.StorageSize().DiskSize;
+ const uint64_t Large = m_LargeStrategy.StorageSize().DiskSize;
+
+ return {.TinySize = Tiny, .SmallSize = Small, .LargeSize = Large, .TotalSize = Tiny + Small + Large};
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+std::unique_ptr<CasStore>
+CreateCasStore(GcManager& Gc)
+{
+ return std::make_unique<CasImpl>(Gc);
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("CasStore")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ CidStoreConfiguration config;
+ config.RootDirectory = TempDir.Path();
+
+ GcManager Gc;
+
+ std::unique_ptr<CasStore> Store = CreateCasStore(Gc);
+ Store->Initialize(config);
+
+ ScrubContext Ctx;
+ Store->Scrub(Ctx);
+
+ IoBuffer Value1{16};
+ memcpy(Value1.MutableData(), "1234567890123456", 16);
+ IoHash Hash1 = IoHash::HashBuffer(Value1.Data(), Value1.Size());
+ CasStore::InsertResult Result1 = Store->InsertChunk(Value1, Hash1);
+ CHECK(Result1.New);
+
+ IoBuffer Value2{16};
+ memcpy(Value2.MutableData(), "ABCDEFGHIJKLMNOP", 16);
+ IoHash Hash2 = IoHash::HashBuffer(Value2.Data(), Value2.Size());
+ CasStore::InsertResult Result2 = Store->InsertChunk(Value2, Hash2);
+ CHECK(Result2.New);
+
+ HashKeySet ChunkSet;
+ ChunkSet.AddHashToSet(Hash1);
+ ChunkSet.AddHashToSet(Hash2);
+
+ Store->FilterChunks(ChunkSet);
+ CHECK(ChunkSet.IsEmpty());
+
+ IoBuffer Lookup1 = Store->FindChunk(Hash1);
+ CHECK(Lookup1);
+ IoBuffer Lookup2 = Store->FindChunk(Hash2);
+ CHECK(Lookup2);
+}
+
+void
+CAS_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenstore/cas.h b/src/zenstore/cas.h
new file mode 100644
index 000000000..9c48d4707
--- /dev/null
+++ b/src/zenstore/cas.h
@@ -0,0 +1,67 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/blake3.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/refcount.h>
+#include <zencore/timer.h>
+#include <zenstore/cidstore.h>
+#include <zenstore/hashkeyset.h>
+
+#include <atomic>
+#include <filesystem>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_set>
+
+namespace zen {
+
+class GcContext;
+class GcManager;
+class ScrubContext;
+
+/** Content Addressable Storage interface
+
+ */
+
+class CasStore
+{
+public:
+ virtual ~CasStore() = default;
+
+ const CidStoreConfiguration& Config() { return m_Config; }
+
+ struct InsertResult
+ {
+ bool New = false;
+ };
+
+ enum class InsertMode
+ {
+ kCopyOnly,
+ kMayBeMovedInPlace
+ };
+
+ virtual void Initialize(const CidStoreConfiguration& Config) = 0;
+ virtual InsertResult InsertChunk(IoBuffer Data, const IoHash& ChunkHash, InsertMode Mode = InsertMode::kMayBeMovedInPlace) = 0;
+ virtual IoBuffer FindChunk(const IoHash& ChunkHash) = 0;
+ virtual bool ContainsChunk(const IoHash& ChunkHash) = 0;
+ virtual void FilterChunks(HashKeySet& InOutChunks) = 0;
+ virtual void Flush() = 0;
+ virtual void Scrub(ScrubContext& Ctx) = 0;
+ virtual void GarbageCollect(GcContext& GcCtx) = 0;
+ virtual CidStoreSize TotalSize() const = 0;
+
+protected:
+ CidStoreConfiguration m_Config;
+ uint64_t m_LastScrubTime = 0;
+};
+
+ZENCORE_API std::unique_ptr<CasStore> CreateCasStore(GcManager& Gc);
+
+void CAS_forcelink();
+
+} // namespace zen
diff --git a/src/zenstore/caslog.cpp b/src/zenstore/caslog.cpp
new file mode 100644
index 000000000..2a978ae12
--- /dev/null
+++ b/src/zenstore/caslog.cpp
@@ -0,0 +1,236 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenstore/caslog.h>
+
+#include "compactcas.h"
+
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+
+#include <xxhash.h>
+
+#include <gsl/gsl-lite.hpp>
+
+#include <filesystem>
+#include <functional>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+uint32_t
+CasLogFile::FileHeader::ComputeChecksum()
+{
+ return XXH32(&this->Magic, sizeof(FileHeader) - 4, 0xC0C0'BABA);
+}
+
+CasLogFile::CasLogFile()
+{
+}
+
+CasLogFile::~CasLogFile()
+{
+}
+
+bool
+CasLogFile::IsValid(std::filesystem::path FileName, size_t RecordSize)
+{
+ if (!std::filesystem::is_regular_file(FileName))
+ {
+ return false;
+ }
+ BasicFile File;
+
+ std::error_code Ec;
+ File.Open(FileName, BasicFile::Mode::kRead, Ec);
+ if (Ec)
+ {
+ return false;
+ }
+
+ FileHeader Header;
+ if (File.FileSize() < sizeof(Header))
+ {
+ return false;
+ }
+
+ // Validate header and log contents and prepare for appending/replay
+ File.Read(&Header, sizeof Header, 0);
+
+ if ((0 != memcmp(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic)) || (Header.Checksum != Header.ComputeChecksum()))
+ {
+ return false;
+ }
+ if (Header.RecordSize != RecordSize)
+ {
+ return false;
+ }
+ return true;
+}
+
+void
+CasLogFile::Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode)
+{
+ m_RecordSize = RecordSize;
+
+ std::error_code Ec;
+ BasicFile::Mode FileMode = BasicFile::Mode::kRead;
+ switch (Mode)
+ {
+ case Mode::kWrite:
+ FileMode = BasicFile::Mode::kWrite;
+ break;
+ case Mode::kTruncate:
+ FileMode = BasicFile::Mode::kTruncate;
+ break;
+ }
+
+ m_File.Open(FileName, FileMode, Ec);
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to open log file '{}'", FileName));
+ }
+
+ uint64_t AppendOffset = 0;
+
+ if ((Mode == Mode::kTruncate) || (m_File.FileSize() < sizeof(FileHeader)))
+ {
+ if (Mode == Mode::kRead)
+ {
+ throw std::runtime_error(fmt::format("Mangled log header (file to small) in '{}'", FileName));
+ }
+ // Initialize log by writing header
+ FileHeader Header = {.RecordSize = gsl::narrow<uint32_t>(RecordSize), .LogId = Oid::NewOid(), .ValidatedTail = 0};
+ memcpy(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic);
+ Header.Finalize();
+
+ m_File.Write(&Header, sizeof Header, 0);
+
+ AppendOffset = sizeof(FileHeader);
+
+ m_Header = Header;
+ }
+ else
+ {
+ FileHeader Header;
+ m_File.Read(&Header, sizeof Header, 0);
+
+ if ((0 != memcmp(Header.Magic, FileHeader::MagicSequence, sizeof Header.Magic)) || (Header.Checksum != Header.ComputeChecksum()))
+ {
+ throw std::runtime_error(fmt::format("Mangled log header (invalid header magic) in '{}'", FileName));
+ }
+ if (Header.RecordSize != RecordSize)
+ {
+ throw std::runtime_error(fmt::format("Mangled log header (mismatch in record size, expected {}, found {}) in '{}'",
+ RecordSize,
+ Header.RecordSize,
+ FileName));
+ }
+
+ AppendOffset = m_File.FileSize();
+
+ // Adjust the offset to ensure we end up on a good boundary, in case there is some garbage appended
+
+ AppendOffset -= sizeof Header;
+ AppendOffset -= AppendOffset % RecordSize;
+ AppendOffset += sizeof Header;
+
+ m_Header = Header;
+ }
+
+ m_AppendOffset = AppendOffset;
+}
+
+void
+CasLogFile::Close()
+{
+ // TODO: update header and maybe add trailer
+ Flush();
+
+ m_File.Close();
+}
+
+uint64_t
+CasLogFile::GetLogSize()
+{
+ return m_File.FileSize();
+}
+
+uint64_t
+CasLogFile::GetLogCount()
+{
+ uint64_t LogFileSize = m_AppendOffset.load(std::memory_order_acquire);
+ if (LogFileSize < sizeof(FileHeader))
+ {
+ return 0;
+ }
+ const uint64_t LogBaseOffset = sizeof(FileHeader);
+ const size_t LogEntryCount = (LogFileSize - LogBaseOffset) / m_RecordSize;
+ return LogEntryCount;
+}
+
+void
+CasLogFile::Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntryCount)
+{
+ uint64_t LogFileSize = m_File.FileSize();
+
+ // Ensure we end up on a clean boundary
+ uint64_t LogBaseOffset = sizeof(FileHeader);
+ size_t LogEntryCount = (LogFileSize - LogBaseOffset) / m_RecordSize;
+
+ if (LogEntryCount <= SkipEntryCount)
+ {
+ return;
+ }
+
+ LogBaseOffset += SkipEntryCount * m_RecordSize;
+ LogEntryCount -= SkipEntryCount;
+
+ // This should really be streaming the data rather than just
+ // reading it into memory, though we don't tend to get very
+ // large logs so it may not matter
+
+ const uint64_t LogDataSize = LogEntryCount * m_RecordSize;
+
+ std::vector<uint8_t> ReadBuffer;
+ ReadBuffer.resize(LogDataSize);
+
+ m_File.Read(ReadBuffer.data(), LogDataSize, LogBaseOffset);
+
+ for (int i = 0; i < int(LogEntryCount); ++i)
+ {
+ Handler(ReadBuffer.data() + (i * m_RecordSize));
+ }
+
+ m_AppendOffset = LogBaseOffset + (m_RecordSize * LogEntryCount);
+}
+
+void
+CasLogFile::Append(const void* DataPointer, uint64_t DataSize)
+{
+ ZEN_ASSERT((DataSize % m_RecordSize) == 0);
+
+ uint64_t AppendOffset = m_AppendOffset.fetch_add(DataSize);
+
+ std::error_code Ec;
+ m_File.Write(DataPointer, gsl::narrow<uint32_t>(DataSize), AppendOffset, Ec);
+
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to write to log file '{}'", PathFromHandle(m_File.Handle())));
+ }
+}
+
+void
+CasLogFile::Flush()
+{
+ m_File.Flush();
+}
+
+} // namespace zen
diff --git a/src/zenstore/cidstore.cpp b/src/zenstore/cidstore.cpp
new file mode 100644
index 000000000..5a5116faf
--- /dev/null
+++ b/src/zenstore/cidstore.cpp
@@ -0,0 +1,125 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenstore/cidstore.h"
+
+#include <zencore/compress.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+#include <zenstore/scrubcontext.h>
+
+#include "cas.h"
+
+#include <filesystem>
+
+namespace zen {
+
+struct CidStore::Impl
+{
+ Impl(CasStore& InCasStore) : m_CasStore(InCasStore) {}
+
+ CasStore& m_CasStore;
+
+ void Initialize(const CidStoreConfiguration& Config) { m_CasStore.Initialize(Config); }
+
+ CidStore::InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, CidStore::InsertMode Mode)
+ {
+#ifndef NDEBUG
+ IoHash VerifyRawHash;
+ uint64_t _;
+ ZEN_ASSERT(CompressedBuffer::ValidateCompressedHeader(ChunkData, VerifyRawHash, _) && RawHash == VerifyRawHash);
+#endif // NDEBUG
+ IoBuffer Payload(ChunkData);
+ Payload.SetContentType(ZenContentType::kCompressedBinary);
+
+ CasStore::InsertResult Result = m_CasStore.InsertChunk(Payload, RawHash, static_cast<CasStore::InsertMode>(Mode));
+
+ return {.New = Result.New};
+ }
+
+ IoBuffer FindChunkByCid(const IoHash& DecompressedId) { return m_CasStore.FindChunk(DecompressedId); }
+
+ bool ContainsChunk(const IoHash& DecompressedId) { return m_CasStore.ContainsChunk(DecompressedId); }
+
+ void FilterChunks(HashKeySet& InOutChunks)
+ {
+ InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return ContainsChunk(Hash); });
+ }
+
+ void Flush() { m_CasStore.Flush(); }
+
+ void Scrub(ScrubContext& Ctx)
+ {
+ if (Ctx.ScrubTimestamp() == m_LastScrubTime)
+ {
+ return;
+ }
+
+ m_LastScrubTime = Ctx.ScrubTimestamp();
+
+ m_CasStore.Scrub(Ctx);
+ }
+
+ uint64_t m_LastScrubTime = 0;
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+CidStore::CidStore(GcManager& Gc) : m_CasStore(CreateCasStore(Gc)), m_Impl(std::make_unique<Impl>(*m_CasStore))
+{
+}
+
+CidStore::~CidStore()
+{
+}
+
+void
+CidStore::Initialize(const CidStoreConfiguration& Config)
+{
+ m_Impl->Initialize(Config);
+}
+
+CidStore::InsertResult
+CidStore::AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode)
+{
+ return m_Impl->AddChunk(ChunkData, RawHash, Mode);
+}
+
+IoBuffer
+CidStore::FindChunkByCid(const IoHash& DecompressedId)
+{
+ return m_Impl->FindChunkByCid(DecompressedId);
+}
+
+bool
+CidStore::ContainsChunk(const IoHash& DecompressedId)
+{
+ return m_Impl->ContainsChunk(DecompressedId);
+}
+
+void
+CidStore::FilterChunks(HashKeySet& InOutChunks)
+{
+ return m_Impl->FilterChunks(InOutChunks);
+}
+
+void
+CidStore::Flush()
+{
+ m_Impl->Flush();
+}
+
+void
+CidStore::Scrub(ScrubContext& Ctx)
+{
+ m_Impl->Scrub(Ctx);
+}
+
+CidStoreSize
+CidStore::TotalSize() const
+{
+ return m_Impl->m_CasStore.TotalSize();
+}
+
+} // namespace zen
diff --git a/src/zenstore/compactcas.cpp b/src/zenstore/compactcas.cpp
new file mode 100644
index 000000000..7b2c21b0f
--- /dev/null
+++ b/src/zenstore/compactcas.cpp
@@ -0,0 +1,1511 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "compactcas.h"
+
+#include "cas.h"
+
+#include <zencore/compress.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zenstore/scrubcontext.h>
+
+#include <gsl/gsl-lite.hpp>
+
+#include <xxhash.h>
+
+#if ZEN_WITH_TESTS
+# include <zencore/compactbinarybuilder.h>
+# include <zencore/testing.h>
+# include <zencore/testutils.h>
+# include <zencore/workthreadpool.h>
+# include <zenstore/cidstore.h>
+# include <algorithm>
+# include <random>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+struct CasDiskIndexHeader
+{
+ static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx';
+ static constexpr uint32_t CurrentVersion = 1;
+
+ uint32_t Magic = ExpectedMagic;
+ uint32_t Version = CurrentVersion;
+ uint64_t EntryCount = 0;
+ uint64_t LogPosition = 0;
+ uint32_t PayloadAlignment = 0;
+ uint32_t Checksum = 0;
+
+ static uint32_t ComputeChecksum(const CasDiskIndexHeader& Header)
+ {
+ return XXH32(&Header.Magic, sizeof(CasDiskIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA);
+ }
+};
+
+static_assert(sizeof(CasDiskIndexHeader) == 32);
+
+namespace {
+ const char* IndexExtension = ".uidx";
+ const char* LogExtension = ".ulog";
+
+ std::filesystem::path GetBasePath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName)
+ {
+ return RootPath / ContainerBaseName;
+ }
+
+ std::filesystem::path GetIndexPath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName)
+ {
+ return GetBasePath(RootPath, ContainerBaseName) / (ContainerBaseName + IndexExtension);
+ }
+
+ std::filesystem::path GetTempIndexPath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName)
+ {
+ return GetBasePath(RootPath, ContainerBaseName) / (ContainerBaseName + ".tmp" + LogExtension);
+ }
+
+ std::filesystem::path GetLogPath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName)
+ {
+ return GetBasePath(RootPath, ContainerBaseName) / (ContainerBaseName + LogExtension);
+ }
+
+ std::filesystem::path GetBlocksBasePath(const std::filesystem::path& RootPath, const std::string& ContainerBaseName)
+ {
+ return GetBasePath(RootPath, ContainerBaseName) / "blocks";
+ }
+
+ bool ValidateEntry(const CasDiskIndexEntry& Entry, std::string& OutReason)
+ {
+ if (Entry.Key == IoHash::Zero)
+ {
+ OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString());
+ return false;
+ }
+ if ((Entry.Flags & ~CasDiskIndexEntry::kTombstone) != 0)
+ {
+ OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Flags, Entry.Key.ToHexString());
+ return false;
+ }
+ if (Entry.Flags & CasDiskIndexEntry::kTombstone)
+ {
+ return true;
+ }
+ if (Entry.ContentType != ZenContentType::kUnknownContentType)
+ {
+ OutReason =
+ fmt::format("Invalid content type {} for entry {}", static_cast<uint8_t>(Entry.ContentType), Entry.Key.ToHexString());
+ return false;
+ }
+ uint64_t Size = Entry.Location.GetSize();
+ if (Size == 0)
+ {
+ OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString());
+ return false;
+ }
+ return true;
+ }
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+
+CasContainerStrategy::CasContainerStrategy(GcManager& Gc) : GcStorage(Gc), m_Log(logging::Get("containercas"))
+{
+}
+
+CasContainerStrategy::~CasContainerStrategy()
+{
+}
+
+void
+CasContainerStrategy::Initialize(const std::filesystem::path& RootDirectory,
+ const std::string_view ContainerBaseName,
+ uint32_t MaxBlockSize,
+ uint64_t Alignment,
+ bool IsNewStore)
+{
+ ZEN_ASSERT(IsPow2(Alignment));
+ ZEN_ASSERT(!m_IsInitialized);
+ ZEN_ASSERT(MaxBlockSize > 0);
+
+ m_RootDirectory = RootDirectory;
+ m_ContainerBaseName = ContainerBaseName;
+ m_PayloadAlignment = Alignment;
+ m_MaxBlockSize = MaxBlockSize;
+ m_BlocksBasePath = GetBlocksBasePath(m_RootDirectory, m_ContainerBaseName);
+
+ OpenContainer(IsNewStore);
+
+ m_IsInitialized = true;
+}
+
+CasStore::InsertResult
+CasContainerStrategy::InsertChunk(const void* ChunkData, size_t ChunkSize, const IoHash& ChunkHash)
+{
+ {
+ RwLock::SharedLockScope _(m_LocationMapLock);
+ if (m_LocationMap.contains(ChunkHash))
+ {
+ return CasStore::InsertResult{.New = false};
+ }
+ }
+
+ // We can end up in a situation that InsertChunk writes the same chunk data in
+ // different locations.
+ // We release the insert lock once we have the correct WriteBlock ready and we know
+ // where to write the data. If a new InsertChunk request for the same chunk hash/data
+ // comes in before we update m_LocationMap below we will have a race.
+ // The outcome of that is that we will write the chunk data in more than one location
+ // but the chunk hash will only point to one of the chunks.
+ // We will in that case waste space until the next GC operation.
+ //
+ // This should be a rare occasion and the current flow reduces the time we block for
+ // reads, insert and GC.
+
+ m_BlockStore.WriteChunk(ChunkData, ChunkSize, m_PayloadAlignment, [&](const BlockStoreLocation& Location) {
+ BlockStoreDiskLocation DiskLocation(Location, m_PayloadAlignment);
+ const CasDiskIndexEntry IndexEntry{.Key = ChunkHash, .Location = DiskLocation};
+ m_CasLog.Append(IndexEntry);
+ {
+ RwLock::ExclusiveLockScope _(m_LocationMapLock);
+ m_LocationMap.emplace(ChunkHash, DiskLocation);
+ }
+ });
+
+ return CasStore::InsertResult{.New = true};
+}
+
+CasStore::InsertResult
+CasContainerStrategy::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash)
+{
+#if !ZEN_WITH_TESTS
+ ZEN_ASSERT(Chunk.GetContentType() == ZenContentType::kCompressedBinary);
+#endif
+ return InsertChunk(Chunk.Data(), Chunk.Size(), ChunkHash);
+}
+
+IoBuffer
+CasContainerStrategy::FindChunk(const IoHash& ChunkHash)
+{
+ RwLock::SharedLockScope _(m_LocationMapLock);
+ auto KeyIt = m_LocationMap.find(ChunkHash);
+ if (KeyIt == m_LocationMap.end())
+ {
+ return IoBuffer();
+ }
+ const BlockStoreLocation& Location = KeyIt->second.Get(m_PayloadAlignment);
+
+ IoBuffer Chunk = m_BlockStore.TryGetChunk(Location);
+ return Chunk;
+}
+
+bool
+CasContainerStrategy::HaveChunk(const IoHash& ChunkHash)
+{
+ RwLock::SharedLockScope _(m_LocationMapLock);
+ return m_LocationMap.contains(ChunkHash);
+}
+
+void
+CasContainerStrategy::FilterChunks(HashKeySet& InOutChunks)
+{
+ // This implementation is good enough for relatively small
+ // chunk sets (in terms of chunk identifiers), but would
+ // benefit from a better implementation which removes
+ // items incrementally for large sets, especially when
+ // we're likely to already have a large proportion of the
+ // chunks in the set
+
+ InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return HaveChunk(Hash); });
+}
+
+void
+CasContainerStrategy::Flush()
+{
+ m_BlockStore.Flush();
+ m_CasLog.Flush();
+ MakeIndexSnapshot();
+}
+
+void
+CasContainerStrategy::Scrub(ScrubContext& Ctx)
+{
+ std::vector<IoHash> BadKeys;
+ uint64_t ChunkCount{0}, ChunkBytes{0};
+ std::vector<BlockStoreLocation> ChunkLocations;
+ std::vector<IoHash> ChunkIndexToChunkHash;
+
+ RwLock::SharedLockScope _(m_LocationMapLock);
+
+ uint64_t TotalChunkCount = m_LocationMap.size();
+ ChunkLocations.reserve(TotalChunkCount);
+ ChunkIndexToChunkHash.reserve(TotalChunkCount);
+ {
+ for (const auto& Entry : m_LocationMap)
+ {
+ const IoHash& ChunkHash = Entry.first;
+ const BlockStoreDiskLocation& DiskLocation = Entry.second;
+ BlockStoreLocation Location = DiskLocation.Get(m_PayloadAlignment);
+
+ ChunkLocations.push_back(Location);
+ ChunkIndexToChunkHash.push_back(ChunkHash);
+ }
+ }
+
+ const auto ValidateSmallChunk = [&](size_t ChunkIndex, const void* Data, uint64_t Size) {
+ ++ChunkCount;
+ ChunkBytes += Size;
+
+ const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex];
+ if (!Data)
+ {
+ // ChunkLocation out of range of stored blocks
+ BadKeys.push_back(Hash);
+ return;
+ }
+
+ IoBuffer Buffer(IoBuffer::Wrap, Data, Size);
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize))
+ {
+ if (RawHash != Hash)
+ {
+ // Hash mismatch
+ BadKeys.push_back(Hash);
+ return;
+ }
+ return;
+ }
+#if ZEN_WITH_TESTS
+ IoHash ComputedHash = IoHash::HashBuffer(Data, Size);
+ if (ComputedHash == Hash)
+ {
+ return;
+ }
+#endif
+ BadKeys.push_back(Hash);
+ };
+
+ const auto ValidateLargeChunk = [&](size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size) {
+ ++ChunkCount;
+ ChunkBytes += Size;
+
+ const IoHash& Hash = ChunkIndexToChunkHash[ChunkIndex];
+ IoBuffer Buffer(IoBuffer::BorrowedFile, File.GetBasicFile().Handle(), Offset, Size);
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ // TODO: Add API to verify compressed buffer without having to memorymap the whole file
+ if (CompressedBuffer::ValidateCompressedHeader(Buffer, RawHash, RawSize))
+ {
+ if (RawHash != Hash)
+ {
+ // Hash mismatch
+ BadKeys.push_back(Hash);
+ return;
+ }
+ return;
+ }
+#if ZEN_WITH_TESTS
+ IoHashStream Hasher;
+ File.StreamByteRange(Offset, Size, [&](const void* Data, size_t Size) { Hasher.Append(Data, Size); });
+ IoHash ComputedHash = Hasher.GetHash();
+ if (ComputedHash == Hash)
+ {
+ return;
+ }
+#endif
+ BadKeys.push_back(Hash);
+ };
+
+ m_BlockStore.IterateChunks(ChunkLocations, ValidateSmallChunk, ValidateLargeChunk);
+
+ _.ReleaseNow();
+
+ Ctx.ReportScrubbed(ChunkCount, ChunkBytes);
+
+ if (!BadKeys.empty())
+ {
+ ZEN_WARN("Scrubbing found {} bad chunks in '{}'", BadKeys.size(), m_RootDirectory / m_ContainerBaseName);
+
+ if (Ctx.RunRecovery())
+ {
+ // Deal with bad chunks by removing them from our lookup map
+
+ std::vector<CasDiskIndexEntry> LogEntries;
+ LogEntries.reserve(BadKeys.size());
+ {
+ RwLock::ExclusiveLockScope __(m_LocationMapLock);
+ for (const IoHash& ChunkHash : BadKeys)
+ {
+ const auto KeyIt = m_LocationMap.find(ChunkHash);
+ if (KeyIt == m_LocationMap.end())
+ {
+ // Might have been GC'd
+ continue;
+ }
+ LogEntries.push_back({.Key = KeyIt->first, .Location = KeyIt->second, .Flags = CasDiskIndexEntry::kTombstone});
+ m_LocationMap.erase(KeyIt);
+ }
+ }
+ m_CasLog.Append(LogEntries);
+ }
+ }
+
+ // Let whomever it concerns know about the bad chunks. This could
+ // be used to invalidate higher level data structures more efficiently
+ // than a full validation pass might be able to do
+ Ctx.ReportBadCidChunks(BadKeys);
+
+ ZEN_INFO("compact cas scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes));
+}
+
+void
+CasContainerStrategy::CollectGarbage(GcContext& GcCtx)
+{
+ // It collects all the blocks that we want to delete chunks from. For each such
+ // block we keep a list of chunks to retain and a list of chunks to delete.
+ //
+ // If there is a block that we are currently writing to, that block is omitted
+ // from the garbage collection.
+ //
+ // Next it will iterate over all blocks that we want to remove chunks from.
+ // If the block is empty after removal of chunks we mark the block as pending
+ // delete - we want to delete it as soon as there are no IoBuffers using the
+ // block file.
+ // Once complete we update the m_LocationMap by removing the chunks.
+ //
+ // If the block is non-empty we write out the chunks we want to keep to a new
+ // block file (creating new block files as needed).
+ //
+ // We update the index as we complete each new block file. This makes it possible
+ // to break the GC if we want to limit time for execution.
+ //
+ // GC can very parallell to regular operation - it will block while taking
+ // a snapshot of the current m_LocationMap state and while moving blocks it will
+ // do a blocking operation and update the m_LocationMap after each new block is
+ // written and figuring out the path to the next new block.
+
+ ZEN_DEBUG("collecting garbage from '{}'", m_RootDirectory / m_ContainerBaseName);
+
+ uint64_t WriteBlockTimeUs = 0;
+ uint64_t WriteBlockLongestTimeUs = 0;
+ uint64_t ReadBlockTimeUs = 0;
+ uint64_t ReadBlockLongestTimeUs = 0;
+
+ LocationMap_t LocationMap;
+ BlockStore::ReclaimSnapshotState BlockStoreState;
+ {
+ RwLock::SharedLockScope ___(m_LocationMapLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&Timer, &WriteBlockTimeUs, &WriteBlockLongestTimeUs] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ WriteBlockTimeUs += ElapsedUs;
+ WriteBlockLongestTimeUs = std::max(ElapsedUs, WriteBlockLongestTimeUs);
+ });
+ LocationMap = m_LocationMap;
+ BlockStoreState = m_BlockStore.GetReclaimSnapshotState();
+ }
+
+ uint64_t TotalChunkCount = LocationMap.size();
+
+ std::vector<IoHash> TotalChunkHashes;
+ TotalChunkHashes.reserve(TotalChunkCount);
+ for (const auto& Entry : LocationMap)
+ {
+ TotalChunkHashes.push_back(Entry.first);
+ }
+
+ std::vector<BlockStoreLocation> ChunkLocations;
+ BlockStore::ChunkIndexArray KeepChunkIndexes;
+ std::vector<IoHash> ChunkIndexToChunkHash;
+ ChunkLocations.reserve(TotalChunkCount);
+ ChunkIndexToChunkHash.reserve(TotalChunkCount);
+
+ GcCtx.FilterCids(TotalChunkHashes, [&](const IoHash& ChunkHash, bool Keep) {
+ auto KeyIt = LocationMap.find(ChunkHash);
+ const BlockStoreDiskLocation& DiskLocation = KeyIt->second;
+ BlockStoreLocation Location = DiskLocation.Get(m_PayloadAlignment);
+ size_t ChunkIndex = ChunkLocations.size();
+
+ ChunkLocations.push_back(Location);
+ ChunkIndexToChunkHash[ChunkIndex] = ChunkHash;
+ if (Keep)
+ {
+ KeepChunkIndexes.push_back(ChunkIndex);
+ }
+ });
+
+ const bool PerformDelete = GcCtx.IsDeletionMode() && GcCtx.CollectSmallObjects();
+ if (!PerformDelete)
+ {
+ m_BlockStore.ReclaimSpace(BlockStoreState, ChunkLocations, KeepChunkIndexes, m_PayloadAlignment, true);
+ return;
+ }
+
+ std::vector<IoHash> DeletedChunks;
+ m_BlockStore.ReclaimSpace(
+ BlockStoreState,
+ ChunkLocations,
+ KeepChunkIndexes,
+ m_PayloadAlignment,
+ false,
+ [&](const BlockStore::MovedChunksArray& MovedChunks, const BlockStore::ChunkIndexArray& RemovedChunks) {
+ std::vector<CasDiskIndexEntry> LogEntries;
+ LogEntries.reserve(MovedChunks.size() + RemovedChunks.size());
+ for (const auto& Entry : MovedChunks)
+ {
+ size_t ChunkIndex = Entry.first;
+ const BlockStoreLocation& NewLocation = Entry.second;
+ const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex];
+ LogEntries.push_back({.Key = ChunkHash, .Location = {NewLocation, m_PayloadAlignment}});
+ }
+ for (const size_t ChunkIndex : RemovedChunks)
+ {
+ const IoHash& ChunkHash = ChunkIndexToChunkHash[ChunkIndex];
+ const BlockStoreDiskLocation& OldDiskLocation = LocationMap[ChunkHash];
+ LogEntries.push_back({.Key = ChunkHash, .Location = OldDiskLocation, .Flags = CasDiskIndexEntry::kTombstone});
+ DeletedChunks.push_back(ChunkHash);
+ }
+
+ m_CasLog.Append(LogEntries);
+ m_CasLog.Flush();
+ {
+ RwLock::ExclusiveLockScope __(m_LocationMapLock);
+ Stopwatch Timer;
+ const auto ____ = MakeGuard([&] {
+ uint64_t ElapsedUs = Timer.GetElapsedTimeUs();
+ ReadBlockTimeUs += ElapsedUs;
+ ReadBlockLongestTimeUs = std::max(ElapsedUs, ReadBlockLongestTimeUs);
+ });
+ for (const CasDiskIndexEntry& Entry : LogEntries)
+ {
+ if (Entry.Flags & CasDiskIndexEntry::kTombstone)
+ {
+ m_LocationMap.erase(Entry.Key);
+ continue;
+ }
+ m_LocationMap[Entry.Key] = Entry.Location;
+ }
+ }
+ },
+ [&GcCtx]() { return GcCtx.CollectSmallObjects(); });
+
+ GcCtx.AddDeletedCids(DeletedChunks);
+}
+
+void
+CasContainerStrategy::MakeIndexSnapshot()
+{
+ uint64_t LogCount = m_CasLog.GetLogCount();
+ if (m_LogFlushPosition == LogCount)
+ {
+ return;
+ }
+
+ ZEN_DEBUG("write store snapshot for '{}'", m_RootDirectory / m_ContainerBaseName);
+ uint64_t EntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}",
+ m_RootDirectory / m_ContainerBaseName,
+ EntryCount,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ namespace fs = std::filesystem;
+
+ fs::path IndexPath = GetIndexPath(m_RootDirectory, m_ContainerBaseName);
+ fs::path TempIndexPath = GetTempIndexPath(m_RootDirectory, m_ContainerBaseName);
+
+ // Move index away, we keep it if something goes wrong
+ if (fs::is_regular_file(TempIndexPath))
+ {
+ fs::remove(TempIndexPath);
+ }
+ if (fs::is_regular_file(IndexPath))
+ {
+ fs::rename(IndexPath, TempIndexPath);
+ }
+
+ try
+ {
+ // Write the current state of the location map to a new index state
+ std::vector<CasDiskIndexEntry> Entries;
+
+ {
+ RwLock::SharedLockScope ___(m_LocationMapLock);
+ Entries.resize(m_LocationMap.size());
+
+ uint64_t EntryIndex = 0;
+ for (auto& Entry : m_LocationMap)
+ {
+ CasDiskIndexEntry& IndexEntry = Entries[EntryIndex++];
+ IndexEntry.Key = Entry.first;
+ IndexEntry.Location = Entry.second;
+ }
+ }
+
+ BasicFile ObjectIndexFile;
+ ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate);
+ CasDiskIndexHeader Header = {.EntryCount = Entries.size(),
+ .LogPosition = LogCount,
+ .PayloadAlignment = gsl::narrow<uint32_t>(m_PayloadAlignment)};
+
+ Header.Checksum = CasDiskIndexHeader::ComputeChecksum(Header);
+
+ ObjectIndexFile.Write(&Header, sizeof(CasDiskIndexEntry), 0);
+ ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(CasDiskIndexEntry), sizeof(CasDiskIndexEntry));
+ ObjectIndexFile.Flush();
+ ObjectIndexFile.Close();
+ EntryCount = Entries.size();
+ m_LogFlushPosition = LogCount;
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what());
+
+ // Restore any previous snapshot
+
+ if (fs::is_regular_file(TempIndexPath))
+ {
+ fs::remove(IndexPath);
+ fs::rename(TempIndexPath, IndexPath);
+ }
+ }
+ if (fs::is_regular_file(TempIndexPath))
+ {
+ fs::remove(TempIndexPath);
+ }
+}
+
+uint64_t
+CasContainerStrategy::ReadIndexFile()
+{
+ std::vector<CasDiskIndexEntry> Entries;
+ std::filesystem::path IndexPath = GetIndexPath(m_RootDirectory, m_ContainerBaseName);
+ if (std::filesystem::is_regular_file(IndexPath))
+ {
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("read store '{}' index containing {} entries in {}",
+ IndexPath,
+ Entries.size(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ BasicFile ObjectIndexFile;
+ ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead);
+ uint64_t Size = ObjectIndexFile.FileSize();
+ if (Size >= sizeof(CasDiskIndexHeader))
+ {
+ uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(CasDiskIndexHeader))) / sizeof(CasDiskIndexEntry);
+ CasDiskIndexHeader Header;
+ ObjectIndexFile.Read(&Header, sizeof(Header), 0);
+ if ((Header.Magic == CasDiskIndexHeader::ExpectedMagic) && (Header.Version == CasDiskIndexHeader::CurrentVersion) &&
+ (Header.Checksum == CasDiskIndexHeader::ComputeChecksum(Header)) && (Header.PayloadAlignment > 0) &&
+ (Header.EntryCount <= ExpectedEntryCount))
+ {
+ Entries.resize(Header.EntryCount);
+ ObjectIndexFile.Read(Entries.data(), Header.EntryCount * sizeof(CasDiskIndexEntry), sizeof(CasDiskIndexHeader));
+ m_PayloadAlignment = Header.PayloadAlignment;
+
+ std::string InvalidEntryReason;
+ for (const CasDiskIndexEntry& Entry : Entries)
+ {
+ if (!ValidateEntry(Entry, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason);
+ continue;
+ }
+ m_LocationMap[Entry.Key] = Entry.Location;
+ }
+
+ return Header.LogPosition;
+ }
+ else
+ {
+ ZEN_WARN("skipping invalid index file '{}'", IndexPath);
+ }
+ }
+ }
+ return 0;
+}
+
+uint64_t
+CasContainerStrategy::ReadLog(uint64_t SkipEntryCount)
+{
+ std::filesystem::path LogPath = GetLogPath(m_RootDirectory, m_ContainerBaseName);
+ if (std::filesystem::is_regular_file(LogPath))
+ {
+ size_t LogEntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("read store '{}' log containing {} entries in {}",
+ m_RootDirectory / m_ContainerBaseName,
+ LogEntryCount,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ TCasLogFile<CasDiskIndexEntry> CasLog;
+ CasLog.Open(LogPath, CasLogFile::Mode::kRead);
+ if (CasLog.Initialize())
+ {
+ uint64_t EntryCount = CasLog.GetLogCount();
+ if (EntryCount < SkipEntryCount)
+ {
+ ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath);
+ SkipEntryCount = 0;
+ }
+ LogEntryCount = EntryCount - SkipEntryCount;
+ CasLog.Replay(
+ [&](const CasDiskIndexEntry& Record) {
+ LogEntryCount++;
+ std::string InvalidEntryReason;
+ if (Record.Flags & CasDiskIndexEntry::kTombstone)
+ {
+ m_LocationMap.erase(Record.Key);
+ return;
+ }
+ if (!ValidateEntry(Record, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason);
+ return;
+ }
+ m_LocationMap[Record.Key] = Record.Location;
+ },
+ SkipEntryCount);
+ return LogEntryCount;
+ }
+ }
+ return 0;
+}
+
+void
+CasContainerStrategy::OpenContainer(bool IsNewStore)
+{
+ // Add .running file and delete on clean on close to detect bad termination
+
+ m_LocationMap.clear();
+
+ std::filesystem::path BasePath = GetBasePath(m_RootDirectory, m_ContainerBaseName);
+
+ if (IsNewStore)
+ {
+ std::filesystem::remove_all(BasePath);
+ }
+
+ m_LogFlushPosition = ReadIndexFile();
+ uint64_t LogEntryCount = ReadLog(m_LogFlushPosition);
+
+ CreateDirectories(BasePath);
+
+ std::filesystem::path LogPath = GetLogPath(m_RootDirectory, m_ContainerBaseName);
+ m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite);
+
+ std::vector<BlockStoreLocation> KnownLocations;
+ KnownLocations.reserve(m_LocationMap.size());
+ for (const auto& Entry : m_LocationMap)
+ {
+ const BlockStoreDiskLocation& Location = Entry.second;
+ KnownLocations.push_back(Location.Get(m_PayloadAlignment));
+ }
+
+ m_BlockStore.Initialize(m_BlocksBasePath, m_MaxBlockSize, BlockStoreDiskLocation::MaxBlockIndex + 1, KnownLocations);
+
+ if (IsNewStore || (LogEntryCount > 0))
+ {
+ MakeIndexSnapshot();
+ }
+
+ // TODO: should validate integrity of container files here
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+namespace {
+ static IoBuffer CreateRandomChunk(uint64_t Size)
+ {
+ static std::random_device rd;
+ static std::mt19937 g(rd());
+
+ std::vector<uint8_t> Values;
+ Values.resize(Size);
+ for (size_t Idx = 0; Idx < Size; ++Idx)
+ {
+ Values[Idx] = static_cast<uint8_t>(Idx);
+ }
+ std::shuffle(Values.begin(), Values.end(), g);
+
+ return IoBufferBuilder::MakeCloneFromMemory(Values.data(), Values.size());
+ }
+} // namespace
+
+TEST_CASE("compactcas.hex")
+{
+ uint32_t Value;
+ std::string HexString;
+ CHECK(!ParseHexNumber("", Value));
+ char Hex[9];
+
+ ToHexNumber(0u, Hex);
+ HexString = std::string(Hex);
+ CHECK(ParseHexNumber(HexString, Value));
+ CHECK(Value == 0u);
+
+ ToHexNumber(std::numeric_limits<std::uint32_t>::max(), Hex);
+ HexString = std::string(Hex);
+ CHECK(HexString == "ffffffff");
+ CHECK(ParseHexNumber(HexString, Value));
+ CHECK(Value == std::numeric_limits<std::uint32_t>::max());
+
+ ToHexNumber(0xadf14711u, Hex);
+ HexString = std::string(Hex);
+ CHECK(HexString == "adf14711");
+ CHECK(ParseHexNumber(HexString, Value));
+ CHECK(Value == 0xadf14711u);
+
+ ToHexNumber(0x80000000u, Hex);
+ HexString = std::string(Hex);
+ CHECK(HexString == "80000000");
+ CHECK(ParseHexNumber(HexString, Value));
+ CHECK(Value == 0x80000000u);
+
+ ToHexNumber(0x718293a4u, Hex);
+ HexString = std::string(Hex);
+ CHECK(HexString == "718293a4");
+ CHECK(ParseHexNumber(HexString, Value));
+ CHECK(Value == 0x718293a4u);
+}
+
+TEST_CASE("compactcas.compact.gc")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ const int kIterationCount = 1000;
+
+ std::vector<IoHash> Keys(kIterationCount);
+
+ {
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 65536, 16, true);
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ CbObjectWriter Cbo;
+ Cbo << "id" << i;
+ CbObject Obj = Cbo.Save();
+
+ IoBuffer ObjBuffer = Obj.GetBuffer().AsIoBuffer();
+ const IoHash Hash = HashBuffer(ObjBuffer);
+
+ Cas.InsertChunk(ObjBuffer, Hash);
+
+ Keys[i] = Hash;
+ }
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ IoBuffer Chunk = Cas.FindChunk(Keys[i]);
+
+ CHECK(!!Chunk);
+
+ CbObject Value = LoadCompactBinaryObject(Chunk);
+
+ CHECK_EQ(Value["id"].AsInt32(), i);
+ }
+ }
+
+ // Validate that we can still read the inserted data after closing
+ // the original cas store
+
+ {
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 65536, 16, false);
+
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ IoBuffer Chunk = Cas.FindChunk(Keys[i]);
+
+ CHECK(!!Chunk);
+
+ CbObject Value = LoadCompactBinaryObject(Chunk);
+
+ CHECK_EQ(Value["id"].AsInt32(), i);
+ }
+ }
+}
+
+TEST_CASE("compactcas.compact.totalsize")
+{
+ std::random_device rd;
+ std::mt19937 g(rd());
+
+ // for (uint32_t i = 0; i < 100; ++i)
+ {
+ ScopedTemporaryDirectory TempDir;
+
+ const uint64_t kChunkSize = 1024;
+ const int32_t kChunkCount = 16;
+
+ {
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 65536, 16, true);
+
+ for (int32_t Idx = 0; Idx < kChunkCount; ++Idx)
+ {
+ IoBuffer Chunk = CreateRandomChunk(kChunkSize);
+ const IoHash Hash = HashBuffer(Chunk);
+ CasStore::InsertResult InsertResult = Cas.InsertChunk(Chunk, Hash);
+ ZEN_ASSERT(InsertResult.New);
+ }
+
+ const uint64_t TotalSize = Cas.StorageSize().DiskSize;
+ CHECK_EQ(kChunkSize * kChunkCount, TotalSize);
+ }
+
+ {
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 65536, 16, false);
+
+ const uint64_t TotalSize = Cas.StorageSize().DiskSize;
+ CHECK_EQ(kChunkSize * kChunkCount, TotalSize);
+ }
+
+ // Re-open again, this time we should have a snapshot
+ {
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 65536, 16, false);
+
+ const uint64_t TotalSize = Cas.StorageSize().DiskSize;
+ CHECK_EQ(kChunkSize * kChunkCount, TotalSize);
+ }
+ }
+}
+
+TEST_CASE("compactcas.gc.basic")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "cb", 65536, 1 << 4, true);
+
+ IoBuffer Chunk = CreateRandomChunk(128);
+ IoHash ChunkHash = IoHash::HashBuffer(Chunk);
+
+ const CasStore::InsertResult InsertResult = Cas.InsertChunk(Chunk, ChunkHash);
+ CHECK(InsertResult.New);
+ Cas.Flush();
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+
+ Cas.CollectGarbage(GcCtx);
+
+ CHECK(!Cas.HaveChunk(ChunkHash));
+}
+
+TEST_CASE("compactcas.gc.removefile")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ IoBuffer Chunk = CreateRandomChunk(128);
+ IoHash ChunkHash = IoHash::HashBuffer(Chunk);
+ {
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "cb", 65536, 1 << 4, true);
+
+ const CasStore::InsertResult InsertResult = Cas.InsertChunk(Chunk, ChunkHash);
+ CHECK(InsertResult.New);
+ const CasStore::InsertResult InsertResultDup = Cas.InsertChunk(Chunk, ChunkHash);
+ CHECK(!InsertResultDup.New);
+ Cas.Flush();
+ }
+
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "cb", 65536, 1 << 4, false);
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+
+ Cas.CollectGarbage(GcCtx);
+
+ CHECK(!Cas.HaveChunk(ChunkHash));
+}
+
+TEST_CASE("compactcas.gc.compact")
+{
+ // for (uint32_t i = 0; i < 100; ++i)
+ {
+ ScopedTemporaryDirectory TempDir;
+
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "cb", 2048, 1 << 4, true);
+
+ uint64_t ChunkSizes[9] = {128, 541, 1023, 781, 218, 37, 4, 997, 5};
+ std::vector<IoBuffer> Chunks;
+ Chunks.reserve(9);
+ for (uint64_t Size : ChunkSizes)
+ {
+ Chunks.push_back(CreateRandomChunk(Size));
+ }
+
+ std::vector<IoHash> ChunkHashes;
+ ChunkHashes.reserve(9);
+ for (const IoBuffer& Chunk : Chunks)
+ {
+ ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size()));
+ }
+
+ CHECK(Cas.InsertChunk(Chunks[0], ChunkHashes[0]).New);
+ CHECK(Cas.InsertChunk(Chunks[1], ChunkHashes[1]).New);
+ CHECK(Cas.InsertChunk(Chunks[2], ChunkHashes[2]).New);
+ CHECK(Cas.InsertChunk(Chunks[3], ChunkHashes[3]).New);
+ CHECK(Cas.InsertChunk(Chunks[4], ChunkHashes[4]).New);
+ CHECK(Cas.InsertChunk(Chunks[5], ChunkHashes[5]).New);
+ CHECK(Cas.InsertChunk(Chunks[6], ChunkHashes[6]).New);
+ CHECK(Cas.InsertChunk(Chunks[7], ChunkHashes[7]).New);
+ CHECK(Cas.InsertChunk(Chunks[8], ChunkHashes[8]).New);
+
+ CHECK(Cas.HaveChunk(ChunkHashes[0]));
+ CHECK(Cas.HaveChunk(ChunkHashes[1]));
+ CHECK(Cas.HaveChunk(ChunkHashes[2]));
+ CHECK(Cas.HaveChunk(ChunkHashes[3]));
+ CHECK(Cas.HaveChunk(ChunkHashes[4]));
+ CHECK(Cas.HaveChunk(ChunkHashes[5]));
+ CHECK(Cas.HaveChunk(ChunkHashes[6]));
+ CHECK(Cas.HaveChunk(ChunkHashes[7]));
+ CHECK(Cas.HaveChunk(ChunkHashes[8]));
+
+ // Keep first and last
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[0]);
+ KeepChunks.push_back(ChunkHashes[8]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ Cas.Flush();
+ Cas.CollectGarbage(GcCtx);
+
+ CHECK(Cas.HaveChunk(ChunkHashes[0]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[1]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[2]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[3]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[4]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[5]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[6]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[7]));
+ CHECK(Cas.HaveChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[0] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[0])));
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8])));
+
+ Cas.InsertChunk(Chunks[1], ChunkHashes[1]);
+ Cas.InsertChunk(Chunks[2], ChunkHashes[2]);
+ Cas.InsertChunk(Chunks[3], ChunkHashes[3]);
+ Cas.InsertChunk(Chunks[4], ChunkHashes[4]);
+ Cas.InsertChunk(Chunks[5], ChunkHashes[5]);
+ Cas.InsertChunk(Chunks[6], ChunkHashes[6]);
+ Cas.InsertChunk(Chunks[7], ChunkHashes[7]);
+ }
+
+ // Keep last
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[8]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ Cas.Flush();
+ Cas.CollectGarbage(GcCtx);
+
+ CHECK(!Cas.HaveChunk(ChunkHashes[0]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[1]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[2]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[3]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[4]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[5]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[6]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[7]));
+ CHECK(Cas.HaveChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8])));
+
+ Cas.InsertChunk(Chunks[1], ChunkHashes[1]);
+ Cas.InsertChunk(Chunks[2], ChunkHashes[2]);
+ Cas.InsertChunk(Chunks[3], ChunkHashes[3]);
+ Cas.InsertChunk(Chunks[4], ChunkHashes[4]);
+ Cas.InsertChunk(Chunks[5], ChunkHashes[5]);
+ Cas.InsertChunk(Chunks[6], ChunkHashes[6]);
+ Cas.InsertChunk(Chunks[7], ChunkHashes[7]);
+ }
+
+ // Keep mixed
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[1]);
+ KeepChunks.push_back(ChunkHashes[4]);
+ KeepChunks.push_back(ChunkHashes[7]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ Cas.Flush();
+ Cas.CollectGarbage(GcCtx);
+
+ CHECK(!Cas.HaveChunk(ChunkHashes[0]));
+ CHECK(Cas.HaveChunk(ChunkHashes[1]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[2]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[3]));
+ CHECK(Cas.HaveChunk(ChunkHashes[4]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[5]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[6]));
+ CHECK(Cas.HaveChunk(ChunkHashes[7]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[1] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[1])));
+ CHECK(ChunkHashes[4] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[4])));
+ CHECK(ChunkHashes[7] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[7])));
+
+ Cas.InsertChunk(Chunks[0], ChunkHashes[0]);
+ Cas.InsertChunk(Chunks[2], ChunkHashes[2]);
+ Cas.InsertChunk(Chunks[3], ChunkHashes[3]);
+ Cas.InsertChunk(Chunks[5], ChunkHashes[5]);
+ Cas.InsertChunk(Chunks[6], ChunkHashes[6]);
+ Cas.InsertChunk(Chunks[8], ChunkHashes[8]);
+ }
+
+ // Keep multiple at end
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[6]);
+ KeepChunks.push_back(ChunkHashes[7]);
+ KeepChunks.push_back(ChunkHashes[8]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ Cas.Flush();
+ Cas.CollectGarbage(GcCtx);
+
+ CHECK(!Cas.HaveChunk(ChunkHashes[0]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[1]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[2]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[3]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[4]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[5]));
+ CHECK(Cas.HaveChunk(ChunkHashes[6]));
+ CHECK(Cas.HaveChunk(ChunkHashes[7]));
+ CHECK(Cas.HaveChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[6] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[6])));
+ CHECK(ChunkHashes[7] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[7])));
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8])));
+
+ Cas.InsertChunk(Chunks[0], ChunkHashes[0]);
+ Cas.InsertChunk(Chunks[1], ChunkHashes[1]);
+ Cas.InsertChunk(Chunks[2], ChunkHashes[2]);
+ Cas.InsertChunk(Chunks[3], ChunkHashes[3]);
+ Cas.InsertChunk(Chunks[4], ChunkHashes[4]);
+ Cas.InsertChunk(Chunks[5], ChunkHashes[5]);
+ }
+
+ // Keep every other
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[0]);
+ KeepChunks.push_back(ChunkHashes[2]);
+ KeepChunks.push_back(ChunkHashes[4]);
+ KeepChunks.push_back(ChunkHashes[6]);
+ KeepChunks.push_back(ChunkHashes[8]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ Cas.Flush();
+ Cas.CollectGarbage(GcCtx);
+
+ CHECK(Cas.HaveChunk(ChunkHashes[0]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[1]));
+ CHECK(Cas.HaveChunk(ChunkHashes[2]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[3]));
+ CHECK(Cas.HaveChunk(ChunkHashes[4]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[5]));
+ CHECK(Cas.HaveChunk(ChunkHashes[6]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[7]));
+ CHECK(Cas.HaveChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[0] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[0])));
+ CHECK(ChunkHashes[2] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[2])));
+ CHECK(ChunkHashes[4] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[4])));
+ CHECK(ChunkHashes[6] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[6])));
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8])));
+
+ Cas.InsertChunk(Chunks[1], ChunkHashes[1]);
+ Cas.InsertChunk(Chunks[3], ChunkHashes[3]);
+ Cas.InsertChunk(Chunks[5], ChunkHashes[5]);
+ Cas.InsertChunk(Chunks[7], ChunkHashes[7]);
+ }
+
+ // Verify that we nicely appended blocks even after all GC operations
+ CHECK(ChunkHashes[0] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[0])));
+ CHECK(ChunkHashes[1] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[1])));
+ CHECK(ChunkHashes[2] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[2])));
+ CHECK(ChunkHashes[3] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[3])));
+ CHECK(ChunkHashes[4] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[4])));
+ CHECK(ChunkHashes[5] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[5])));
+ CHECK(ChunkHashes[6] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[6])));
+ CHECK(ChunkHashes[7] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[7])));
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[8])));
+ }
+}
+
+TEST_CASE("compactcas.gc.deleteblockonopen")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ uint64_t ChunkSizes[20] = {128, 541, 311, 181, 218, 37, 4, 397, 5, 92, 551, 721, 31, 92, 16, 99, 131, 41, 541, 84};
+ std::vector<IoBuffer> Chunks;
+ Chunks.reserve(20);
+ for (uint64_t Size : ChunkSizes)
+ {
+ Chunks.push_back(CreateRandomChunk(Size));
+ }
+
+ std::vector<IoHash> ChunkHashes;
+ ChunkHashes.reserve(20);
+ for (const IoBuffer& Chunk : Chunks)
+ {
+ ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size()));
+ }
+
+ {
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 1024, 16, true);
+
+ for (size_t i = 0; i < 20; i++)
+ {
+ CHECK(Cas.InsertChunk(Chunks[i], ChunkHashes[i]).New);
+ }
+
+ // GC every other block
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ for (size_t i = 0; i < 20; i += 2)
+ {
+ KeepChunks.push_back(ChunkHashes[i]);
+ }
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ Cas.Flush();
+ Cas.CollectGarbage(GcCtx);
+
+ for (size_t i = 0; i < 20; i += 2)
+ {
+ CHECK(Cas.HaveChunk(ChunkHashes[i]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[i + 1]));
+ CHECK(ChunkHashes[i] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[i])));
+ }
+ }
+ }
+ {
+ // Re-open
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 1024, 16, false);
+
+ for (size_t i = 0; i < 20; i += 2)
+ {
+ CHECK(Cas.HaveChunk(ChunkHashes[i]));
+ CHECK(!Cas.HaveChunk(ChunkHashes[i + 1]));
+ CHECK(ChunkHashes[i] == IoHash::HashBuffer(Cas.FindChunk(ChunkHashes[i])));
+ }
+ }
+}
+
+TEST_CASE("compactcas.gc.handleopeniobuffer")
+{
+ ScopedTemporaryDirectory TempDir;
+
+ uint64_t ChunkSizes[20] = {128, 541, 311, 181, 218, 37, 4, 397, 5, 92, 551, 721, 31, 92, 16, 99, 131, 41, 541, 84};
+ std::vector<IoBuffer> Chunks;
+ Chunks.reserve(20);
+ for (const uint64_t& Size : ChunkSizes)
+ {
+ Chunks.push_back(CreateRandomChunk(Size));
+ }
+
+ std::vector<IoHash> ChunkHashes;
+ ChunkHashes.reserve(20);
+ for (const IoBuffer& Chunk : Chunks)
+ {
+ ChunkHashes.push_back(IoHash::HashBuffer(Chunk.Data(), Chunk.Size()));
+ }
+
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 1024, 16, true);
+
+ for (size_t i = 0; i < 20; i++)
+ {
+ CHECK(Cas.InsertChunk(Chunks[i], ChunkHashes[i]).New);
+ }
+
+ IoBuffer RetainChunk = Cas.FindChunk(ChunkHashes[5]);
+ Cas.Flush();
+
+ // GC everything
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ Cas.CollectGarbage(GcCtx);
+
+ for (size_t i = 0; i < 20; i++)
+ {
+ CHECK(!Cas.HaveChunk(ChunkHashes[i]));
+ }
+
+ CHECK(ChunkHashes[5] == IoHash::HashBuffer(RetainChunk));
+}
+
+TEST_CASE("compactcas.threadedinsert")
+{
+ // for (uint32_t i = 0; i < 100; ++i)
+ {
+ ScopedTemporaryDirectory TempDir;
+
+ const uint64_t kChunkSize = 1048;
+ const int32_t kChunkCount = 4096;
+ uint64_t ExpectedSize = 0;
+
+ std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> Chunks;
+ Chunks.reserve(kChunkCount);
+
+ for (int32_t Idx = 0; Idx < kChunkCount; ++Idx)
+ {
+ while (true)
+ {
+ IoBuffer Chunk = CreateRandomChunk(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ if (Chunks.contains(Hash))
+ {
+ continue;
+ }
+ Chunks[Hash] = Chunk;
+ ExpectedSize += Chunk.Size();
+ break;
+ }
+ }
+
+ std::atomic<size_t> WorkCompleted = 0;
+ WorkerThreadPool ThreadPool(4);
+ GcManager Gc;
+ CasContainerStrategy Cas(Gc);
+ Cas.Initialize(TempDir.Path(), "test", 32768, 16, true);
+ {
+ for (const auto& Chunk : Chunks)
+ {
+ const IoHash& Hash = Chunk.first;
+ const IoBuffer& Buffer = Chunk.second;
+ ThreadPool.ScheduleWork([&Cas, &WorkCompleted, Buffer, Hash]() {
+ CasStore::InsertResult InsertResult = Cas.InsertChunk(Buffer, Hash);
+ ZEN_ASSERT(InsertResult.New);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+ }
+
+ WorkCompleted = 0;
+ const uint64_t TotalSize = Cas.StorageSize().DiskSize;
+ CHECK_LE(ExpectedSize, TotalSize);
+ CHECK_GE(ExpectedSize + 32768, TotalSize);
+
+ {
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Cas, &WorkCompleted, &Chunk]() {
+ IoHash ChunkHash = Chunk.first;
+ IoBuffer Buffer = Cas.FindChunk(ChunkHash);
+ IoHash Hash = IoHash::HashBuffer(Buffer);
+ CHECK(ChunkHash == Hash);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < Chunks.size())
+ {
+ Sleep(1);
+ }
+ }
+
+ std::unordered_set<IoHash, IoHash::Hasher> GcChunkHashes;
+ GcChunkHashes.reserve(Chunks.size());
+ for (const auto& Chunk : Chunks)
+ {
+ GcChunkHashes.insert(Chunk.first);
+ }
+ {
+ WorkCompleted = 0;
+ std::unordered_map<IoHash, IoBuffer, IoHash::Hasher> NewChunks;
+ NewChunks.reserve(kChunkCount);
+
+ for (int32_t Idx = 0; Idx < kChunkCount; ++Idx)
+ {
+ IoBuffer Chunk = CreateRandomChunk(kChunkSize);
+ IoHash Hash = HashBuffer(Chunk);
+ NewChunks[Hash] = Chunk;
+ }
+
+ std::atomic_uint32_t AddedChunkCount;
+
+ for (const auto& Chunk : NewChunks)
+ {
+ ThreadPool.ScheduleWork([&Cas, &WorkCompleted, Chunk, &AddedChunkCount]() {
+ Cas.InsertChunk(Chunk.second, Chunk.first);
+ AddedChunkCount.fetch_add(1);
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ for (const auto& Chunk : Chunks)
+ {
+ ThreadPool.ScheduleWork([&Cas, &WorkCompleted, Chunk]() {
+ IoHash ChunkHash = Chunk.first;
+ IoBuffer Buffer = Cas.FindChunk(ChunkHash);
+ if (Buffer)
+ {
+ CHECK(ChunkHash == IoHash::HashBuffer(Buffer));
+ }
+ WorkCompleted.fetch_add(1);
+ });
+ }
+
+ while (AddedChunkCount.load() < NewChunks.size())
+ {
+ // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope
+ for (const auto& Chunk : NewChunks)
+ {
+ if (Cas.HaveChunk(Chunk.first))
+ {
+ GcChunkHashes.emplace(Chunk.first);
+ }
+ }
+ std::vector<IoHash> KeepHashes(GcChunkHashes.begin(), GcChunkHashes.end());
+ size_t C = 0;
+ while (C < KeepHashes.size())
+ {
+ if (C % 155 == 0)
+ {
+ if (C < KeepHashes.size() - 1)
+ {
+ KeepHashes[C] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ if (C + 3 < KeepHashes.size() - 1)
+ {
+ KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ }
+ C++;
+ }
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ GcCtx.AddRetainedCids(KeepHashes);
+ Cas.CollectGarbage(GcCtx);
+ const HashKeySet& Deleted = GcCtx.DeletedCids();
+ Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); });
+ }
+
+ while (WorkCompleted < NewChunks.size() + Chunks.size())
+ {
+ Sleep(1);
+ }
+
+ // Need to be careful since we might GC blocks we don't know outside of RwLock::ExclusiveLockScope
+ for (const auto& Chunk : NewChunks)
+ {
+ if (Cas.HaveChunk(Chunk.first))
+ {
+ GcChunkHashes.emplace(Chunk.first);
+ }
+ }
+ std::vector<IoHash> KeepHashes(GcChunkHashes.begin(), GcChunkHashes.end());
+ size_t C = 0;
+ while (C < KeepHashes.size())
+ {
+ if (C % 155 == 0)
+ {
+ if (C < KeepHashes.size() - 1)
+ {
+ KeepHashes[C] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ if (C + 3 < KeepHashes.size() - 1)
+ {
+ KeepHashes[C + 3] = KeepHashes[KeepHashes.size() - 1];
+ KeepHashes.pop_back();
+ }
+ }
+ C++;
+ }
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ GcCtx.AddRetainedCids(KeepHashes);
+ Cas.CollectGarbage(GcCtx);
+ const HashKeySet& Deleted = GcCtx.DeletedCids();
+ Deleted.IterateHashes([&GcChunkHashes](const IoHash& ChunkHash) { GcChunkHashes.erase(ChunkHash); });
+ }
+ {
+ WorkCompleted = 0;
+ for (const IoHash& ChunkHash : GcChunkHashes)
+ {
+ ThreadPool.ScheduleWork([&Cas, &WorkCompleted, ChunkHash]() {
+ CHECK(Cas.HaveChunk(ChunkHash));
+ CHECK(ChunkHash == IoHash::HashBuffer(Cas.FindChunk(ChunkHash)));
+ WorkCompleted.fetch_add(1);
+ });
+ }
+ while (WorkCompleted < GcChunkHashes.size())
+ {
+ Sleep(1);
+ }
+ }
+ }
+}
+
+#endif
+
+void
+compactcas_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenstore/compactcas.h b/src/zenstore/compactcas.h
new file mode 100644
index 000000000..b0c6699eb
--- /dev/null
+++ b/src/zenstore/compactcas.h
@@ -0,0 +1,95 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+#include <zenstore/blockstore.h>
+#include <zenstore/caslog.h>
+#include <zenstore/gc.h>
+
+#include "cas.h"
+
+#include <atomic>
+#include <limits>
+#include <unordered_map>
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+#pragma pack(push)
+#pragma pack(1)
+
+struct CasDiskIndexEntry
+{
+ static const uint8_t kTombstone = 0x01;
+
+ IoHash Key;
+ BlockStoreDiskLocation Location;
+ ZenContentType ContentType = ZenContentType::kUnknownContentType;
+ uint8_t Flags = 0;
+};
+
+#pragma pack(pop)
+
+static_assert(sizeof(CasDiskIndexEntry) == 32);
+
+/** This implements a storage strategy for small CAS values
+ *
+ * New chunks are simply appended to a small object file, and an index is
+ * maintained to allow chunks to be looked up within the active small object
+ * files
+ *
+ */
+
+struct CasContainerStrategy final : public GcStorage
+{
+ CasContainerStrategy(GcManager& Gc);
+ ~CasContainerStrategy();
+
+ CasStore::InsertResult InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash);
+ IoBuffer FindChunk(const IoHash& ChunkHash);
+ bool HaveChunk(const IoHash& ChunkHash);
+ void FilterChunks(HashKeySet& InOutChunks);
+ void Initialize(const std::filesystem::path& RootDirectory,
+ const std::string_view ContainerBaseName,
+ uint32_t MaxBlockSize,
+ uint64_t Alignment,
+ bool IsNewStore);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ virtual void CollectGarbage(GcContext& GcCtx) override;
+ virtual GcStorageSize StorageSize() const override { return {.DiskSize = m_BlockStore.TotalSize()}; }
+
+private:
+ CasStore::InsertResult InsertChunk(const void* ChunkData, size_t ChunkSize, const IoHash& ChunkHash);
+ void MakeIndexSnapshot();
+ uint64_t ReadIndexFile();
+ uint64_t ReadLog(uint64_t SkipEntryCount);
+ void OpenContainer(bool IsNewStore);
+
+ spdlog::logger& Log() { return m_Log; }
+
+ std::filesystem::path m_RootDirectory;
+ spdlog::logger& m_Log;
+ uint64_t m_PayloadAlignment = 1u << 4;
+ uint64_t m_MaxBlockSize = 1u << 28;
+ bool m_IsInitialized = false;
+ TCasLogFile<CasDiskIndexEntry> m_CasLog;
+ uint64_t m_LogFlushPosition = 0;
+ std::string m_ContainerBaseName;
+ std::filesystem::path m_BlocksBasePath;
+ BlockStore m_BlockStore;
+
+ RwLock m_LocationMapLock;
+ typedef std::unordered_map<IoHash, BlockStoreDiskLocation, IoHash::Hasher> LocationMap_t;
+ LocationMap_t m_LocationMap;
+};
+
+void compactcas_forcelink();
+
+} // namespace zen
diff --git a/src/zenstore/filecas.cpp b/src/zenstore/filecas.cpp
new file mode 100644
index 000000000..1d25920c4
--- /dev/null
+++ b/src/zenstore/filecas.cpp
@@ -0,0 +1,1452 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "filecas.h"
+
+#include <zencore/compress.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/memory.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+#include <zencore/thread.h>
+#include <zencore/timer.h>
+#include <zencore/uid.h>
+#include <zenstore/gc.h>
+#include <zenstore/scrubcontext.h>
+#include <zenutil/basicfile.h>
+
+#if ZEN_WITH_TESTS
+# include <zencore/compactbinarybuilder.h>
+#endif
+
+#include <gsl/gsl-lite.hpp>
+
+#include <barrier>
+#include <filesystem>
+#include <functional>
+#include <unordered_map>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <xxhash.h>
+#if ZEN_PLATFORM_WINDOWS
+# include <atlfile.h>
+#endif
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen {
+
+namespace filecas::impl {
+ const char* IndexExtension = ".uidx";
+ const char* LogExtension = ".ulog";
+
+ std::filesystem::path GetIndexPath(const std::filesystem::path& RootDir) { return RootDir / fmt::format("cas{}", IndexExtension); }
+
+ std::filesystem::path GetTempIndexPath(const std::filesystem::path& RootDir)
+ {
+ return RootDir / fmt::format("cas.tmp{}", IndexExtension);
+ }
+
+ std::filesystem::path GetLogPath(const std::filesystem::path& RootDir) { return RootDir / fmt::format("cas{}", LogExtension); }
+
+#pragma pack(push)
+#pragma pack(1)
+
+ struct FileCasIndexHeader
+ {
+ static constexpr uint32_t ExpectedMagic = 0x75696478; // 'uidx';
+ static constexpr uint32_t CurrentVersion = 1;
+
+ uint32_t Magic = ExpectedMagic;
+ uint32_t Version = CurrentVersion;
+ uint64_t EntryCount = 0;
+ uint64_t LogPosition = 0;
+ uint32_t Reserved = 0;
+ uint32_t Checksum = 0;
+
+ static uint32_t ComputeChecksum(const FileCasIndexHeader& Header)
+ {
+ return XXH32(&Header.Magic, sizeof(FileCasIndexHeader) - sizeof(uint32_t), 0xC0C0'BABA);
+ }
+ };
+
+ static_assert(sizeof(FileCasIndexHeader) == 32);
+
+#pragma pack(pop)
+
+} // namespace filecas::impl
+
+FileCasStrategy::ShardingHelper::ShardingHelper(const std::filesystem::path& RootPath, const IoHash& ChunkHash)
+{
+ ShardedPath.Append(RootPath.c_str());
+
+ ExtendableStringBuilder<64> HashString;
+ ChunkHash.ToHexString(HashString);
+
+ const char* str = HashString.c_str();
+
+ // Shard into a path with two directory levels containing 12 bits and 8 bits
+ // respectively.
+ //
+ // This results in a maximum of 4096 * 256 directories
+ //
+ // The numbers have been chosen somewhat arbitrarily but are large to scale
+ // to very large chunk repositories without creating too many directories
+ // on a single level since NTFS does not deal very well with this.
+ //
+ // It may or may not make sense to make this a configurable policy, and it
+ // would probably be a good idea to measure performance for different
+ // policies and chunk counts
+
+ ShardedPath.AppendSeparator();
+ ShardedPath.AppendAsciiRange(str, str + 3);
+
+ ShardedPath.AppendSeparator();
+ ShardedPath.AppendAsciiRange(str + 3, str + 5);
+ Shard2len = ShardedPath.Size();
+
+ ShardedPath.AppendSeparator();
+ ShardedPath.AppendAsciiRange(str + 5, str + 40);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+FileCasStrategy::FileCasStrategy(GcManager& Gc) : GcStorage(Gc), m_Log(logging::Get("filecas"))
+{
+}
+
+FileCasStrategy::~FileCasStrategy()
+{
+}
+
+void
+FileCasStrategy::Initialize(const std::filesystem::path& RootDirectory, bool IsNewStore)
+{
+ using namespace filecas::impl;
+
+ m_IsInitialized = true;
+
+ m_RootDirectory = RootDirectory;
+
+ m_Index.clear();
+
+ std::filesystem::path LogPath = GetLogPath(m_RootDirectory);
+ std::filesystem::path IndexPath = GetIndexPath(m_RootDirectory);
+
+ if (IsNewStore)
+ {
+ std::filesystem::remove(LogPath);
+ std::filesystem::remove(IndexPath);
+
+ if (std::filesystem::is_directory(m_RootDirectory))
+ {
+ // We need to explicitly only delete sharded root folders as the cas manifest, tinyobject and smallobject cas folders may reside
+ // in this folder as well
+ struct Visitor : public FileSystemTraversal::TreeVisitor
+ {
+ virtual void VisitFile(const std::filesystem::path&, const path_view&, uint64_t) override
+ {
+ // We don't care about files
+ }
+ static bool IsHexChar(std::filesystem::path::value_type C)
+ {
+ return std::find(&HexChars[0], &HexChars[16], C) != &HexChars[16];
+ }
+ virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent,
+ [[maybe_unused]] const path_view& DirectoryName) override
+ {
+ if (DirectoryName.length() == 3)
+ {
+ if (IsHexChar(DirectoryName[0]) && IsHexChar(DirectoryName[1]) && IsHexChar(DirectoryName[2]))
+ {
+ ShardedRoots.push_back(Parent / DirectoryName);
+ }
+ }
+ return false;
+ }
+ std::vector<std::filesystem::path> ShardedRoots;
+ } CasVisitor;
+
+ FileSystemTraversal Traversal;
+ Traversal.TraverseFileSystem(m_RootDirectory, CasVisitor);
+ for (const std::filesystem::path& SharededRoot : CasVisitor.ShardedRoots)
+ {
+ std::filesystem::remove_all(SharededRoot);
+ }
+ }
+ }
+
+ m_LogFlushPosition = ReadIndexFile();
+ uint64_t LogEntryCount = ReadLog(m_LogFlushPosition);
+ for (const auto& Entry : m_Index)
+ {
+ m_TotalSize.fetch_add(Entry.second.Size, std::memory_order::relaxed);
+ }
+
+ CreateDirectories(m_RootDirectory);
+ m_CasLog.Open(LogPath, CasLogFile::Mode::kWrite);
+
+ if (IsNewStore || LogEntryCount > 0)
+ {
+ MakeIndexSnapshot();
+ }
+}
+
+#if ZEN_PLATFORM_WINDOWS
+static void
+DeletePayloadFileOnClose(const void* FileHandle)
+{
+ const HANDLE WinFileHandle = (const HANDLE)FileHandle;
+ // This will cause the file to be deleted when the last handle to it is closed
+ FILE_DISPOSITION_INFO Fdi{};
+ Fdi.DeleteFile = TRUE;
+ BOOL Success = SetFileInformationByHandle(WinFileHandle, FileDispositionInfo, &Fdi, sizeof Fdi);
+
+ if (!Success)
+ {
+ // TODO: We should provide information to this function to tell it if the payload is temporary or not and if we are allowed
+ // to delete it.
+ ZEN_WARN("Failed to flag CAS temporary payload file '{}' for deletion: '{}'",
+ PathFromHandle(WinFileHandle),
+ GetLastErrorAsString());
+ }
+}
+#endif
+
+CasStore::InsertResult
+FileCasStrategy::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, CasStore::InsertMode Mode)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+#if !ZEN_WITH_TESTS
+ ZEN_ASSERT(Chunk.GetContentType() == ZenContentType::kCompressedBinary);
+#endif
+
+ if (Mode == CasStore::InsertMode::kCopyOnly)
+ {
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ if (m_Index.contains(ChunkHash))
+ {
+ return CasStore::InsertResult{.New = false};
+ }
+ }
+ return InsertChunk(Chunk.Data(), Chunk.Size(), ChunkHash);
+ }
+
+ // File-based chunks have special case handling whereby we move the file into
+ // place in the file store directory, thus avoiding unnecessary copying
+
+ IoBufferFileReference FileRef;
+ if (Chunk.IsWholeFile() && Chunk.GetFileReference(/* out */ FileRef))
+ {
+ {
+ bool Exists = true;
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ Exists = m_Index.contains(ChunkHash);
+ }
+ if (Exists)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ DeletePayloadFileOnClose(FileRef.FileHandle);
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ std::filesystem::path FilePath = PathFromHandle(FileRef.FileHandle);
+ if (unlink(FilePath.c_str()) < 0)
+ {
+ int UnlinkError = zen::GetLastError();
+ if (UnlinkError != ENOENT)
+ {
+ ZEN_WARN("Failed to unlink CAS temporary payload file '{}': '{}'",
+ FilePath.string(),
+ GetSystemErrorAsString(UnlinkError));
+ }
+ }
+#endif
+ return CasStore::InsertResult{.New = false};
+ }
+ }
+
+ ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash);
+
+ RwLock::ExclusiveLockScope HashLock(LockForHash(ChunkHash));
+
+#if ZEN_PLATFORM_WINDOWS
+ const HANDLE ChunkFileHandle = FileRef.FileHandle;
+ // See if file already exists
+ {
+ CAtlFile PayloadFile;
+
+ if (HRESULT hRes = PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING); SUCCEEDED(hRes))
+ {
+ // If we succeeded in opening the target file then we don't need to do anything else because it already exists
+ // and should contain the content we were about to insert
+
+ // We do need to ensure the source file goes away on close, however
+ size_t ChunkSize = Chunk.GetSize();
+ uint64_t FileSize = 0;
+ if (HRESULT hSizeRes = PayloadFile.GetSize(FileSize); SUCCEEDED(hSizeRes) && FileSize == ChunkSize)
+ {
+ HashLock.ReleaseNow();
+
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second;
+ }
+ if (IsNew)
+
+ {
+ m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed);
+ }
+
+ DeletePayloadFileOnClose(ChunkFileHandle);
+
+ return CasStore::InsertResult{.New = IsNew};
+ }
+ else
+ {
+ ZEN_WARN("get file size FAILED or file size mismatch of file cas '{}'. Expected {}, found {}. Trying to overwrite",
+ Name.ShardedPath.ToUtf8(),
+ ChunkSize,
+ FileSize);
+ }
+ }
+ else
+ {
+ if (hRes == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND))
+ {
+ // Shard directory does not exist
+ }
+ else if (hRes == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND))
+ {
+ // Shard directory exists, but not the file
+ }
+ else if (hRes == HRESULT_FROM_WIN32(ERROR_SHARING_VIOLATION))
+ {
+ // Sharing violation, likely because we are trying to open a file
+ // which has been renamed on another thread, and the file handle
+ // used to rename it is still open. We handle this case below
+ // instead of here
+ }
+ else
+ {
+ ZEN_INFO("Unexpected error opening file '{}': {}", Name.ShardedPath.ToUtf8(), hRes);
+ }
+ }
+ }
+
+ std::filesystem::path FullPath(Name.ShardedPath.c_str());
+
+ std::filesystem::path FilePath = FullPath.parent_path();
+ std::wstring FileName = FullPath.native();
+
+ const DWORD BufferSize = sizeof(FILE_RENAME_INFO) + gsl::narrow<DWORD>(FileName.size() * sizeof(WCHAR));
+ FILE_RENAME_INFO* RenameInfo = reinterpret_cast<FILE_RENAME_INFO*>(Memory::Alloc(BufferSize));
+ memset(RenameInfo, 0, BufferSize);
+
+ RenameInfo->ReplaceIfExists = FALSE;
+ RenameInfo->FileNameLength = gsl::narrow<DWORD>(FileName.size());
+ memcpy(RenameInfo->FileName, FileName.c_str(), FileName.size() * sizeof(WCHAR));
+ RenameInfo->FileName[FileName.size()] = 0;
+
+ auto $ = MakeGuard([&] { Memory::Free(RenameInfo); });
+
+ // Try to move file into place
+ BOOL Success = SetFileInformationByHandle(ChunkFileHandle, FileRenameInfo, RenameInfo, BufferSize);
+
+ if (!Success)
+ {
+ // The rename/move could fail because the target directory does not yet exist. This code attempts
+ // to create it
+
+ CAtlFile DirHandle;
+
+ auto InternalCreateDirectoryHandle = [&] {
+ return DirHandle.Create(FilePath.c_str(),
+ GENERIC_READ | GENERIC_WRITE,
+ FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+ OPEN_EXISTING,
+ FILE_FLAG_BACKUP_SEMANTICS);
+ };
+
+ // It's possible for several threads to enter this logic trying to create the same
+ // directory. Only one will create the directory of course, but all threads will
+ // make it through okay
+
+ HRESULT hRes = InternalCreateDirectoryHandle();
+
+ if (FAILED(hRes))
+ {
+ // TODO: we can handle directory creation more intelligently and efficiently than
+ // this currently does
+
+ CreateDirectories(FilePath.c_str());
+
+ hRes = InternalCreateDirectoryHandle();
+ }
+
+ if (FAILED(hRes))
+ {
+ ThrowSystemException(hRes, fmt::format("Failed to open shard directory '{}'", FilePath));
+ }
+
+ // Retry rename/move
+
+ Success = SetFileInformationByHandle(ChunkFileHandle, FileRenameInfo, RenameInfo, BufferSize);
+ }
+
+ if (Success)
+ {
+ m_CasLog.Append({.Key = ChunkHash, .Size = Chunk.Size()});
+
+ HashLock.ReleaseNow();
+
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed);
+ }
+
+ return CasStore::InsertResult{.New = IsNew};
+ }
+
+ const DWORD LastError = GetLastError();
+
+ if ((LastError == ERROR_FILE_EXISTS) || (LastError == ERROR_ALREADY_EXISTS))
+ {
+ HashLock.ReleaseNow();
+ DeletePayloadFileOnClose(ChunkFileHandle);
+
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed);
+ }
+
+ return CasStore::InsertResult{.New = IsNew};
+ }
+
+ ZEN_WARN("rename of CAS payload file failed ('{}'), falling back to regular write for insert of {}",
+ GetSystemErrorAsString(LastError),
+ ChunkHash);
+
+ DeletePayloadFileOnClose(ChunkFileHandle);
+
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ std::filesystem::path SourcePath = PathFromHandle(FileRef.FileHandle);
+ std::filesystem::path DestPath = Name.ShardedPath.c_str();
+ int Ret = link(SourcePath.c_str(), DestPath.c_str());
+ if (Ret < 0 && zen::GetLastError() == ENOENT)
+ {
+ // Destination directory doesn't exist. Create it any try again.
+ CreateDirectories(DestPath.parent_path().c_str());
+ Ret = link(SourcePath.c_str(), DestPath.c_str());
+ }
+ int LinkError = zen::GetLastError();
+
+ if (unlink(SourcePath.c_str()) < 0)
+ {
+ int UnlinkError = zen::GetLastError();
+ if (UnlinkError != ENOENT)
+ {
+ ZEN_WARN("Failed to unlink CAS temporary payload file '{}': '{}'",
+ SourcePath.string(),
+ GetSystemErrorAsString(UnlinkError));
+ }
+ }
+
+ // It is possible that someone beat us to it in linking the file. In that
+ // case a "file exists" error is okay. All others are not.
+ if (Ret < 0)
+ {
+ if (LinkError == EEXIST)
+ {
+ HashLock.ReleaseNow();
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed);
+ }
+ return CasStore::InsertResult{.New = IsNew};
+ }
+
+ ZEN_WARN("link of CAS payload file failed ('{}'), falling back to regular write for insert of {}",
+ GetSystemErrorAsString(LinkError),
+ ChunkHash);
+ }
+ else
+ {
+ HashLock.ReleaseNow();
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = Chunk.Size()}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(Chunk.Size(), std::memory_order::relaxed);
+ }
+ return CasStore::InsertResult{.New = IsNew};
+ }
+#endif // ZEN_PLATFORM_*
+ }
+
+ return InsertChunk(Chunk.Data(), Chunk.Size(), ChunkHash);
+}
+
+CasStore::InsertResult
+FileCasStrategy::InsertChunk(const void* const ChunkData, const size_t ChunkSize, const IoHash& ChunkHash)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ if (m_Index.contains(ChunkHash))
+ {
+ return {.New = false};
+ }
+ }
+
+ ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash);
+
+ // See if file already exists
+
+#if ZEN_PLATFORM_WINDOWS
+ CAtlFile PayloadFile;
+
+ HRESULT hRes = PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING);
+
+ if (SUCCEEDED(hRes))
+ {
+ // If we succeeded in opening the file then we don't need to do anything else because it already exists and should contain the
+ // content we were about to insert
+
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed);
+ }
+ return CasStore::InsertResult{.New = IsNew};
+ }
+
+ PayloadFile.Close();
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ if (access(Name.ShardedPath.c_str(), F_OK) == 0)
+ {
+ return CasStore::InsertResult{.New = false};
+ }
+#endif
+
+ RwLock::ExclusiveLockScope HashLock(LockForHash(ChunkHash));
+
+#if ZEN_PLATFORM_WINDOWS
+ // For now, use double-checked locking to see if someone else was first
+
+ hRes = PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING);
+
+ if (SUCCEEDED(hRes))
+ {
+ uint64_t FileSize = 0;
+ if (HRESULT hSizeRes = PayloadFile.GetSize(FileSize); SUCCEEDED(hSizeRes) && FileSize == ChunkSize)
+ {
+ // If we succeeded in opening the file then and the size is correct we don't need to do anything
+ // else because someone else managed to create the file before we did. Just return.
+
+ HashLock.ReleaseNow();
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed);
+ }
+ return CasStore::InsertResult{.New = IsNew};
+ }
+ else
+ {
+ ZEN_WARN("get file size FAILED or file size mismatch of file cas '{}'. Expected {}, found {}. Trying to overwrite",
+ Name.ShardedPath.ToUtf8(),
+ ChunkSize,
+ FileSize);
+ }
+ }
+
+ if ((hRes != HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) && (hRes != HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)))
+ {
+ ZEN_WARN("Unexpected error code when opening shard file for read: {:#x}", uint32_t(hRes));
+ }
+
+ auto InternalCreateFile = [&] { return PayloadFile.Create(Name.ShardedPath.c_str(), GENERIC_WRITE, FILE_SHARE_DELETE, CREATE_ALWAYS); };
+
+ hRes = InternalCreateFile();
+
+ if (hRes == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND))
+ {
+ // Ensure parent directories exist and retry file creation
+ CreateDirectories(std::wstring_view(Name.ShardedPath.c_str(), Name.Shard2len));
+ hRes = InternalCreateFile();
+ }
+
+ if (FAILED(hRes))
+ {
+ ThrowSystemException(hRes, fmt::format("Failed to open shard file '{}'", Name.ShardedPath.ToUtf8()));
+ }
+#else
+ // Attempt to exclusively create the file.
+ auto InternalCreateFile = [&] {
+ int Fd = open(Name.ShardedPath.c_str(), O_WRONLY | O_CREAT | O_EXCL | O_CLOEXEC, 0666);
+ if (Fd >= 0)
+ {
+ fchmod(Fd, 0666);
+ }
+ return Fd;
+ };
+ int Fd = InternalCreateFile();
+ if (Fd < 0)
+ {
+ switch (zen::GetLastError())
+ {
+ case EEXIST:
+ // Another thread has beat us to it so we're golden.
+ {
+ HashLock.ReleaseNow();
+
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed);
+ }
+ return {.New = IsNew};
+ }
+ break;
+
+ case ENOENT:
+ if (zen::CreateDirectories(std::string_view(Name.ShardedPath.c_str(), Name.Shard2len)))
+ {
+ Fd = InternalCreateFile();
+ if (Fd >= 0)
+ {
+ break;
+ }
+ }
+ ThrowLastError(fmt::format("Failed creating shard directory '{}'", Name.ShardedPath));
+
+ default:
+ ThrowLastError(fmt::format("Unexpected error occurred opening shard file '{}'", Name.ShardedPath.ToUtf8()));
+ }
+ }
+
+ struct FdWrapper
+ {
+ ~FdWrapper() { Close(); }
+ void Write(const void* Cursor, size_t Size) { (void)!write(Fd, Cursor, Size); }
+ void Close()
+ {
+ if (Fd >= 0)
+ {
+ close(Fd);
+ Fd = -1;
+ }
+ }
+ int Fd;
+ } PayloadFile = {Fd};
+#endif // ZEN_PLATFORM_WINDOWS
+
+ size_t ChunkRemain = ChunkSize;
+ auto ChunkCursor = reinterpret_cast<const uint8_t*>(ChunkData);
+
+ while (ChunkRemain != 0)
+ {
+ uint32_t ByteCount = uint32_t(std::min<size_t>(4 * 1024 * 1024ull, ChunkRemain));
+
+ PayloadFile.Write(ChunkCursor, ByteCount);
+
+ ChunkCursor += ByteCount;
+ ChunkRemain -= ByteCount;
+ }
+
+ // We cannot rely on RAII to close the file handle since it would be closed
+ // *after* the lock is released due to the initialization order
+ PayloadFile.Close();
+
+ m_CasLog.Append({.Key = ChunkHash, .Size = ChunkSize});
+
+ HashLock.ReleaseNow();
+
+ bool IsNew = false;
+ {
+ RwLock::ExclusiveLockScope __(m_Lock);
+ IsNew = m_Index.insert({ChunkHash, IndexEntry{.Size = ChunkSize}}).second;
+ }
+ if (IsNew)
+ {
+ m_TotalSize.fetch_add(static_cast<uint64_t>(ChunkSize), std::memory_order::relaxed);
+ }
+
+ return {.New = IsNew};
+}
+
+IoBuffer
+FileCasStrategy::FindChunk(const IoHash& ChunkHash)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+ {
+ RwLock::SharedLockScope _(m_Lock);
+ if (!m_Index.contains(ChunkHash))
+ {
+ return {};
+ }
+ }
+
+ ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash);
+
+ RwLock::SharedLockScope _(LockForHash(ChunkHash));
+
+ return IoBufferBuilder::MakeFromFile(Name.ShardedPath.c_str());
+}
+
+bool
+FileCasStrategy::HaveChunk(const IoHash& ChunkHash)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+ RwLock::SharedLockScope _(m_Lock);
+ return m_Index.contains(ChunkHash);
+}
+
+void
+FileCasStrategy::DeleteChunk(const IoHash& ChunkHash, std::error_code& Ec)
+{
+ ShardingHelper Name(m_RootDirectory.c_str(), ChunkHash);
+
+ uint64_t FileSize = static_cast<uint64_t>(std::filesystem::file_size(Name.ShardedPath.c_str(), Ec));
+ if (Ec)
+ {
+ ZEN_WARN("get file size FAILED, file cas '{}'", Name.ShardedPath.ToUtf8());
+ FileSize = 0;
+ }
+
+ ZEN_DEBUG("deleting CAS payload file '{}' {}", Name.ShardedPath.ToUtf8(), NiceBytes(FileSize));
+ std::filesystem::remove(Name.ShardedPath.c_str(), Ec);
+
+ if (!Ec || !std::filesystem::exists(Name.ShardedPath.c_str()))
+ {
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+ if (auto It = m_Index.find(ChunkHash); It != m_Index.end())
+ {
+ m_TotalSize.fetch_sub(It->second.Size, std::memory_order_relaxed);
+ m_Index.erase(It);
+ }
+ }
+ m_CasLog.Append({.Key = ChunkHash, .Flags = FileCasIndexEntry::kTombStone, .Size = FileSize});
+ }
+}
+
+void
+FileCasStrategy::FilterChunks(HashKeySet& InOutChunks)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+ // NOTE: it's not a problem now, but in the future if a GC should happen while this
+ // is in flight, the result could be wrong since chunks could go away in the meantime.
+ //
+ // It would be good to have a pinning mechanism to make this less likely but
+ // given that chunks could go away at any point after the results are returned to
+ // a caller, this is something which needs to be taken into account by anyone consuming
+ // this functionality in any case
+
+ InOutChunks.RemoveHashesIf([&](const IoHash& Hash) { return HaveChunk(Hash); });
+}
+
+void
+FileCasStrategy::IterateChunks(std::function<void(const IoHash& Hash, IoBuffer&& Payload)>&& Callback)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+ RwLock::SharedLockScope _(m_Lock);
+ for (const auto& It : m_Index)
+ {
+ const IoHash& NameHash = It.first;
+ ShardingHelper Name(m_RootDirectory.c_str(), NameHash);
+ IoBuffer Payload = IoBufferBuilder::MakeFromFile(Name.ShardedPath.c_str());
+ Callback(NameHash, std::move(Payload));
+ }
+}
+
+void
+FileCasStrategy::Flush()
+{
+ // Since we don't keep files open after writing there's nothing specific
+ // to flush here.
+ //
+ // Depending on what semantics we want Flush() to provide, it could be
+ // argued that this should just flush the volume which we are using to
+ // store the CAS files on here, to ensure metadata is flushed along
+ // with file data
+ //
+ // Related: to facilitate more targeted validation during recovery we could
+ // maintain a log of when chunks were created
+}
+
+void
+FileCasStrategy::Scrub(ScrubContext& Ctx)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+ std::vector<IoHash> BadHashes;
+ uint64_t ChunkCount{0}, ChunkBytes{0};
+
+ {
+ std::vector<FileCasStrategy::FileCasIndexEntry> ScannedEntries = FileCasStrategy::ScanFolderForCasFiles(m_RootDirectory);
+ RwLock::ExclusiveLockScope _(m_Lock);
+ for (const FileCasStrategy::FileCasIndexEntry& Entry : ScannedEntries)
+ {
+ if (m_Index.insert({Entry.Key, {.Size = Entry.Size}}).second)
+ {
+ m_TotalSize.fetch_add(static_cast<uint64_t>(Entry.Size), std::memory_order::relaxed);
+ m_CasLog.Append({.Key = Entry.Key, .Size = Entry.Size});
+ }
+ }
+ }
+
+ IterateChunks([&](const IoHash& Hash, IoBuffer&& Payload) {
+ if (!Payload)
+ {
+ BadHashes.push_back(Hash);
+ return;
+ }
+ ++ChunkCount;
+ ChunkBytes += Payload.GetSize();
+
+ IoHash RawHash;
+ uint64_t RawSize;
+ if (CompressedBuffer::ValidateCompressedHeader(Payload, RawHash, RawSize))
+ {
+ if (RawHash != Hash)
+ {
+ // Hash mismatch
+ BadHashes.push_back(Hash);
+ return;
+ }
+ return;
+ }
+#if ZEN_WITH_TESTS
+ IoHash ComputedHash = IoHash::HashBuffer(CompositeBuffer(SharedBuffer(std::move(Payload))));
+ if (ComputedHash == Hash)
+ {
+ return;
+ }
+#endif
+ BadHashes.push_back(Hash);
+ });
+
+ Ctx.ReportScrubbed(ChunkCount, ChunkBytes);
+
+ if (!BadHashes.empty())
+ {
+ ZEN_WARN("file CAS scrubbing: {} bad chunks found", BadHashes.size());
+
+ if (Ctx.RunRecovery())
+ {
+ ZEN_WARN("recovery: deleting backing files for {} bad chunks which were identified as bad", BadHashes.size());
+
+ for (const IoHash& Hash : BadHashes)
+ {
+ std::error_code Ec;
+ DeleteChunk(Hash, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("failed to delete file for chunk {}", Hash);
+ }
+ }
+ }
+ }
+
+ // Let whomever it concerns know about the bad chunks. This could
+ // be used to invalidate higher level data structures more efficiently
+ // than a full validation pass might be able to do
+ Ctx.ReportBadCidChunks(BadHashes);
+
+ ZEN_INFO("file CAS scrubbed: {} chunks ({})", ChunkCount, NiceBytes(ChunkBytes));
+}
+
+void
+FileCasStrategy::CollectGarbage(GcContext& GcCtx)
+{
+ ZEN_ASSERT(m_IsInitialized);
+
+ ZEN_DEBUG("collecting garbage from {}", m_RootDirectory);
+
+ std::vector<IoHash> ChunksToDelete;
+ std::atomic<uint64_t> ChunksToDeleteBytes{0};
+ std::atomic<uint64_t> ChunkCount{0}, ChunkBytes{0};
+
+ std::vector<IoHash> CandidateCas;
+ CandidateCas.resize(1);
+
+ uint64_t DeletedCount = 0;
+ uint64_t OldTotalSize = m_TotalSize.load(std::memory_order::relaxed);
+
+ Stopwatch TotalTimer;
+ const auto _ = MakeGuard([&] {
+ ZEN_DEBUG("garbage collect for '{}' DONE after {}, deleted {} out of {} files, removed {} out of {}",
+ m_RootDirectory,
+ NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()),
+ DeletedCount,
+ ChunkCount,
+ NiceBytes(OldTotalSize - m_TotalSize.load(std::memory_order::relaxed)),
+ NiceBytes(OldTotalSize));
+ });
+
+ IterateChunks([&](const IoHash& Hash, IoBuffer&& Payload) {
+ bool KeepThis = false;
+ CandidateCas[0] = Hash;
+ GcCtx.FilterCids(CandidateCas, [&](const IoHash& Hash) {
+ ZEN_UNUSED(Hash);
+ KeepThis = true;
+ });
+
+ const uint64_t FileSize = Payload.GetSize();
+
+ if (!KeepThis)
+ {
+ ChunksToDelete.push_back(Hash);
+ ChunksToDeleteBytes.fetch_add(FileSize);
+ }
+
+ ++ChunkCount;
+ ChunkBytes.fetch_add(FileSize);
+ });
+
+ // TODO, any entires we did not encounter during our IterateChunks should be removed from the index
+
+ if (ChunksToDelete.empty())
+ {
+ ZEN_DEBUG("gc for '{}' SKIPPED, nothing to delete", m_RootDirectory);
+ return;
+ }
+
+ ZEN_DEBUG("deleting file CAS garbage for '{}': {} out of {} chunks ({})",
+ m_RootDirectory,
+ ChunksToDelete.size(),
+ ChunkCount.load(),
+ NiceBytes(ChunksToDeleteBytes));
+
+ if (GcCtx.IsDeletionMode() == false)
+ {
+ ZEN_DEBUG("NOTE: not actually deleting anything since deletion is disabled");
+
+ return;
+ }
+
+ for (const IoHash& Hash : ChunksToDelete)
+ {
+ ZEN_TRACE("deleting chunk {}", Hash);
+
+ std::error_code Ec;
+ DeleteChunk(Hash, Ec);
+
+ if (Ec)
+ {
+ ZEN_WARN("gc for '{}' failed to delete file for chunk {}: '{}'", m_RootDirectory, Hash, Ec.message());
+ continue;
+ }
+ DeletedCount++;
+ }
+
+ GcCtx.AddDeletedCids(ChunksToDelete);
+}
+
+bool
+FileCasStrategy::ValidateEntry(const FileCasIndexEntry& Entry, std::string& OutReason)
+{
+ if (Entry.Key == IoHash::Zero)
+ {
+ OutReason = fmt::format("Invalid hash key {}", Entry.Key.ToHexString());
+ return false;
+ }
+ if (Entry.Flags & (~FileCasIndexEntry::kTombStone))
+ {
+ OutReason = fmt::format("Invalid flags {} for entry {}", Entry.Flags, Entry.Key.ToHexString());
+ return false;
+ }
+ if (Entry.IsFlagSet(FileCasIndexEntry::kTombStone))
+ {
+ return true;
+ }
+ uint64_t Size = Entry.Size;
+ if (Size == 0)
+ {
+ OutReason = fmt::format("Invalid size {} for entry {}", Size, Entry.Key.ToHexString());
+ return false;
+ }
+ return true;
+}
+
+void
+FileCasStrategy::MakeIndexSnapshot()
+{
+ using namespace filecas::impl;
+
+ uint64_t LogCount = m_CasLog.GetLogCount();
+ if (m_LogFlushPosition == LogCount)
+ {
+ return;
+ }
+ ZEN_DEBUG("write store snapshot for '{}'", m_RootDirectory);
+ uint64_t EntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("wrote store snapshot for '{}' containing {} entries in {}",
+ m_RootDirectory,
+ EntryCount,
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ namespace fs = std::filesystem;
+
+ fs::path IndexPath = GetIndexPath(m_RootDirectory);
+ fs::path STmpIndexPath = GetTempIndexPath(m_RootDirectory);
+
+ // Move index away, we keep it if something goes wrong
+ if (fs::is_regular_file(STmpIndexPath))
+ {
+ fs::remove(STmpIndexPath);
+ }
+ if (fs::is_regular_file(IndexPath))
+ {
+ fs::rename(IndexPath, STmpIndexPath);
+ }
+
+ try
+ {
+ // Write the current state of the location map to a new index state
+ std::vector<FileCasIndexEntry> Entries;
+
+ {
+ Entries.resize(m_Index.size());
+
+ uint64_t EntryIndex = 0;
+ for (auto& Entry : m_Index)
+ {
+ FileCasIndexEntry& IndexEntry = Entries[EntryIndex++];
+ IndexEntry.Key = Entry.first;
+ IndexEntry.Size = Entry.second.Size;
+ }
+ }
+
+ BasicFile ObjectIndexFile;
+ ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kTruncate);
+ filecas::impl::FileCasIndexHeader Header = {.EntryCount = Entries.size(), .LogPosition = LogCount};
+
+ Header.Checksum = filecas::impl::FileCasIndexHeader::ComputeChecksum(Header);
+
+ ObjectIndexFile.Write(&Header, sizeof(filecas::impl::FileCasIndexHeader), 0);
+ ObjectIndexFile.Write(Entries.data(), Entries.size() * sizeof(FileCasIndexEntry), sizeof(filecas::impl::FileCasIndexHeader));
+ ObjectIndexFile.Flush();
+ ObjectIndexFile.Close();
+ EntryCount = Entries.size();
+ m_LogFlushPosition = LogCount;
+ }
+ catch (std::exception& Err)
+ {
+ ZEN_ERROR("snapshot FAILED, reason: '{}'", Err.what());
+
+ // Restore any previous snapshot
+
+ if (fs::is_regular_file(STmpIndexPath))
+ {
+ fs::remove(IndexPath);
+ fs::rename(STmpIndexPath, IndexPath);
+ }
+ }
+ if (fs::is_regular_file(STmpIndexPath))
+ {
+ fs::remove(STmpIndexPath);
+ }
+}
+uint64_t
+FileCasStrategy::ReadIndexFile()
+{
+ using namespace filecas::impl;
+
+ std::vector<FileCasIndexEntry> Entries;
+ std::filesystem::path IndexPath = GetIndexPath(m_RootDirectory);
+ if (std::filesystem::is_regular_file(IndexPath))
+ {
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("read store '{}' index containing {} entries in {}",
+ IndexPath,
+ Entries.size(),
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+
+ BasicFile ObjectIndexFile;
+ ObjectIndexFile.Open(IndexPath, BasicFile::Mode::kRead);
+ uint64_t Size = ObjectIndexFile.FileSize();
+ if (Size >= sizeof(FileCasIndexHeader))
+ {
+ uint64_t ExpectedEntryCount = (Size - sizeof(sizeof(FileCasIndexHeader))) / sizeof(FileCasIndexEntry);
+ FileCasIndexHeader Header;
+ ObjectIndexFile.Read(&Header, sizeof(Header), 0);
+ if ((Header.Magic == FileCasIndexHeader::ExpectedMagic) && (Header.Version == FileCasIndexHeader::CurrentVersion) &&
+ (Header.Checksum == FileCasIndexHeader::ComputeChecksum(Header)) && (Header.EntryCount <= ExpectedEntryCount))
+ {
+ Entries.resize(Header.EntryCount);
+ ObjectIndexFile.Read(Entries.data(), Header.EntryCount * sizeof(FileCasIndexEntry), sizeof(FileCasIndexHeader));
+
+ std::string InvalidEntryReason;
+ for (const FileCasIndexEntry& Entry : Entries)
+ {
+ if (!ValidateEntry(Entry, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", IndexPath, InvalidEntryReason);
+ continue;
+ }
+ m_Index.insert_or_assign(Entry.Key, IndexEntry{.Size = Entry.Size});
+ }
+
+ return Header.LogPosition;
+ }
+ else
+ {
+ ZEN_WARN("skipping invalid index file '{}'", IndexPath);
+ }
+ }
+ return 0;
+ }
+
+ if (std::filesystem::is_directory(m_RootDirectory))
+ {
+ ZEN_INFO("missing index for file cas, scanning for cas files in {}", m_RootDirectory);
+ TCasLogFile<FileCasIndexEntry> CasLog;
+ uint64_t TotalSize = 0;
+ Stopwatch TotalTimer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("scanned file cas folder '{}' DONE after {}, found {} files totalling {}",
+ m_RootDirectory,
+ NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()),
+ CasLog.GetLogCount(),
+ NiceBytes(TotalSize));
+ });
+
+ std::filesystem::path LogPath = GetLogPath(m_RootDirectory);
+
+ std::vector<FileCasStrategy::FileCasIndexEntry> ScannedEntries = FileCasStrategy::ScanFolderForCasFiles(m_RootDirectory);
+ CasLog.Open(LogPath, CasLogFile::Mode::kTruncate);
+ std::string InvalidEntryReason;
+ for (const FileCasStrategy::FileCasIndexEntry& Entry : ScannedEntries)
+ {
+ if (!ValidateEntry(Entry, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", m_RootDirectory, InvalidEntryReason);
+ continue;
+ }
+ m_Index.insert_or_assign(Entry.Key, IndexEntry{.Size = Entry.Size});
+ CasLog.Append(Entry);
+ }
+
+ CasLog.Close();
+ }
+
+ return 0;
+}
+
+uint64_t
+FileCasStrategy::ReadLog(uint64_t SkipEntryCount)
+{
+ using namespace filecas::impl;
+
+ std::filesystem::path LogPath = GetLogPath(m_RootDirectory);
+ if (std::filesystem::is_regular_file(LogPath))
+ {
+ uint64_t LogEntryCount = 0;
+ Stopwatch Timer;
+ const auto _ = MakeGuard([&] {
+ ZEN_INFO("read store '{}' log containing {} entries in {}", LogPath, LogEntryCount, NiceTimeSpanMs(Timer.GetElapsedTimeMs()));
+ });
+ TCasLogFile<FileCasIndexEntry> CasLog;
+ CasLog.Open(LogPath, CasLogFile::Mode::kRead);
+ if (CasLog.Initialize())
+ {
+ uint64_t EntryCount = CasLog.GetLogCount();
+ if (EntryCount < SkipEntryCount)
+ {
+ ZEN_WARN("reading full log at '{}', reason: Log position from index snapshot is out of range", LogPath);
+ SkipEntryCount = 0;
+ }
+ LogEntryCount = EntryCount - SkipEntryCount;
+ m_Index.reserve(LogEntryCount);
+ uint64_t InvalidEntryCount = 0;
+ CasLog.Replay(
+ [&](const FileCasIndexEntry& Record) {
+ std::string InvalidEntryReason;
+ if (Record.Flags & FileCasIndexEntry::kTombStone)
+ {
+ m_Index.erase(Record.Key);
+ return;
+ }
+ if (!ValidateEntry(Record, InvalidEntryReason))
+ {
+ ZEN_WARN("skipping invalid entry in '{}', reason: '{}'", LogPath, InvalidEntryReason);
+ ++InvalidEntryCount;
+ return;
+ }
+ m_Index.insert_or_assign(Record.Key, IndexEntry{.Size = Record.Size});
+ },
+ SkipEntryCount);
+ if (InvalidEntryCount)
+ {
+ ZEN_WARN("found {} invalid entries in '{}'", InvalidEntryCount, LogPath);
+ }
+ return LogEntryCount;
+ }
+ }
+ return 0;
+}
+
+std::vector<FileCasStrategy::FileCasIndexEntry>
+FileCasStrategy::ScanFolderForCasFiles(const std::filesystem::path& RootDir)
+{
+ using namespace filecas::impl;
+
+ std::vector<FileCasIndexEntry> Entries;
+ struct Visitor : public FileSystemTraversal::TreeVisitor
+ {
+ Visitor(const std::filesystem::path& RootDir, std::vector<FileCasIndexEntry>& Entries) : RootDirectory(RootDir), Entries(Entries) {}
+ virtual void VisitFile(const std::filesystem::path& Parent, const path_view& File, uint64_t FileSize) override
+ {
+ std::filesystem::path RelPath = std::filesystem::relative(Parent, RootDirectory);
+
+ std::filesystem::path::string_type PathString = RelPath.native();
+
+ if ((PathString.size() == (3 + 2 + 1)) && (File.size() == (40 - 3 - 2)))
+ {
+ if (PathString.at(3) == std::filesystem::path::preferred_separator)
+ {
+ PathString.erase(3, 1);
+ }
+ PathString.append(File);
+
+ // TODO: should validate that we're actually dealing with a valid hex string here
+#if ZEN_PLATFORM_WINDOWS
+ StringBuilder<64> Utf8;
+ WideToUtf8(PathString, Utf8);
+ IoHash NameHash = IoHash::FromHexString({Utf8.Data(), Utf8.Size()});
+#else
+ IoHash NameHash = IoHash::FromHexString(PathString);
+#endif
+ Entries.emplace_back(FileCasIndexEntry{.Key = NameHash, .Size = FileSize});
+ }
+ }
+
+ virtual bool VisitDirectory([[maybe_unused]] const std::filesystem::path& Parent,
+ [[maybe_unused]] const path_view& DirectoryName) override
+ {
+ return true;
+ }
+
+ const std::filesystem::path& RootDirectory;
+ std::vector<FileCasIndexEntry>& Entries;
+ } CasVisitor{RootDir, Entries};
+
+ FileSystemTraversal Traversal;
+ Traversal.TraverseFileSystem(RootDir, CasVisitor);
+ return Entries;
+};
+
+ //////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("cas.file.move")
+{
+ // specifying an absolute path here can be helpful when using procmon to dig into things
+ ScopedTemporaryDirectory TempDir; // {"d:\\filecas_testdir"};
+
+ GcManager Gc;
+
+ FileCasStrategy FileCas(Gc);
+ FileCas.Initialize(TempDir.Path() / "cas", /* IsNewStore */ true);
+
+ {
+ std::filesystem::path Payload1Path{TempDir.Path() / "payload_1"};
+
+ IoBuffer ZeroBytes{1024 * 1024};
+ IoHash ZeroHash = IoHash::HashBuffer(ZeroBytes);
+
+ BasicFile PayloadFile;
+ PayloadFile.Open(Payload1Path, BasicFile::Mode::kTruncate);
+ PayloadFile.Write(ZeroBytes, 0);
+ PayloadFile.Close();
+
+ IoBuffer Payload1 = IoBufferBuilder::MakeFromTemporaryFile(Payload1Path);
+
+ CasStore::InsertResult Result = FileCas.InsertChunk(Payload1, ZeroHash);
+ CHECK_EQ(Result.New, true);
+ }
+
+# if 0
+ SUBCASE("stresstest")
+ {
+ std::vector<IoHash> PayloadHashes;
+
+ const int kWorkers = 64;
+ const int kItemCount = 128;
+
+ for (int w = 0; w < kWorkers; ++w)
+ {
+ for (int i = 0; i < kItemCount; ++i)
+ {
+ IoBuffer Payload{1024};
+ *reinterpret_cast<int*>(Payload.MutableData()) = i;
+ PayloadHashes.push_back(IoHash::HashBuffer(Payload));
+
+ std::filesystem::path PayloadPath{TempDir.Path() / fmt::format("payload_{}_{}", w, i)};
+ WriteFile(PayloadPath, Payload);
+ }
+ }
+
+ std::barrier Sync{kWorkers};
+
+ auto PopulateAll = [&](int w) {
+ std::vector<IoBuffer> Buffers;
+
+ for (int i = 0; i < kItemCount; ++i)
+ {
+ std::filesystem::path PayloadPath{TempDir.Path() / fmt::format("payload_{}_{}", w, i)};
+ IoBuffer Payload = IoBufferBuilder::MakeFromTemporaryFile(PayloadPath);
+ Buffers.push_back(Payload);
+ Sync.arrive_and_wait();
+ CasStore::InsertResult Result = FileCas.InsertChunk(Payload, PayloadHashes[i]);
+ }
+ };
+
+ std::vector<std::jthread> Threads;
+
+ for (int i = 0; i < kWorkers; ++i)
+ {
+ Threads.push_back(std::jthread(PopulateAll, i));
+ }
+
+ for (std::jthread& Thread : Threads)
+ {
+ Thread.join();
+ }
+ }
+# endif
+}
+
+TEST_CASE("cas.file.gc")
+{
+ // specifying an absolute path here can be helpful when using procmon to dig into things
+ ScopedTemporaryDirectory TempDir; // {"d:\\filecas_testdir"};
+
+ GcManager Gc;
+ FileCasStrategy FileCas(Gc);
+ FileCas.Initialize(TempDir.Path() / "cas", /* IsNewStore */ true);
+
+ const int kIterationCount = 1000;
+ std::vector<IoHash> Keys{kIterationCount};
+
+ auto InsertChunks = [&] {
+ for (int i = 0; i < kIterationCount; ++i)
+ {
+ CbObjectWriter Cbo;
+ Cbo << "id" << i;
+ CbObject Obj = Cbo.Save();
+
+ IoBuffer ObjBuffer = Obj.GetBuffer().AsIoBuffer();
+ IoHash Hash = HashBuffer(ObjBuffer);
+
+ FileCas.InsertChunk(ObjBuffer, Hash);
+
+ Keys[i] = Hash;
+ }
+ };
+
+ // Drop everything
+
+ {
+ InsertChunks();
+
+ GcContext Ctx(GcClock::Now() - std::chrono::hours(24));
+ FileCas.CollectGarbage(Ctx);
+
+ for (const IoHash& Key : Keys)
+ {
+ IoBuffer Chunk = FileCas.FindChunk(Key);
+
+ CHECK(!Chunk);
+ }
+ }
+
+ // Keep roughly half of the chunks
+
+ {
+ InsertChunks();
+
+ GcContext Ctx(GcClock::Now() - std::chrono::hours(24));
+
+ for (const IoHash& Key : Keys)
+ {
+ if (Key.Hash[0] & 1)
+ {
+ Ctx.AddRetainedCids(std::vector<IoHash>{Key});
+ }
+ }
+
+ FileCas.CollectGarbage(Ctx);
+
+ for (const IoHash& Key : Keys)
+ {
+ if (Key.Hash[0] & 1)
+ {
+ CHECK(FileCas.FindChunk(Key));
+ }
+ else
+ {
+ CHECK(!FileCas.FindChunk(Key));
+ }
+ }
+ }
+}
+
+#endif
+
+void
+filecas_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenstore/filecas.h b/src/zenstore/filecas.h
new file mode 100644
index 000000000..420b3a634
--- /dev/null
+++ b/src/zenstore/filecas.h
@@ -0,0 +1,102 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#include <zencore/filesystem.h>
+#include <zencore/iobuffer.h>
+#include <zencore/iohash.h>
+#include <zencore/thread.h>
+#include <zenstore/caslog.h>
+#include <zenstore/gc.h>
+
+#include "cas.h"
+
+#include <atomic>
+#include <functional>
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+class BasicFile;
+
+/** CAS storage strategy using a file-per-chunk storage strategy
+ */
+
+struct FileCasStrategy final : public GcStorage
+{
+ FileCasStrategy(GcManager& Gc);
+ ~FileCasStrategy();
+
+ void Initialize(const std::filesystem::path& RootDirectory, bool IsNewStore);
+ CasStore::InsertResult InsertChunk(IoBuffer Chunk,
+ const IoHash& ChunkHash,
+ CasStore::InsertMode Mode = CasStore::InsertMode::kMayBeMovedInPlace);
+ IoBuffer FindChunk(const IoHash& ChunkHash);
+ bool HaveChunk(const IoHash& ChunkHash);
+ void FilterChunks(HashKeySet& InOutChunks);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ virtual void CollectGarbage(GcContext& GcCtx) override;
+ virtual GcStorageSize StorageSize() const override { return {.DiskSize = m_TotalSize.load(std::memory_order::relaxed)}; }
+
+private:
+ void MakeIndexSnapshot();
+ uint64_t ReadIndexFile();
+ uint64_t ReadLog(uint64_t LogPosition);
+
+ struct IndexEntry
+ {
+ uint64_t Size = 0;
+ };
+ using IndexMap = tsl::robin_map<IoHash, IndexEntry, IoHash::Hasher>;
+
+ CasStore::InsertResult InsertChunk(const void* ChunkData, size_t ChunkSize, const IoHash& ChunkHash);
+
+ std::filesystem::path m_RootDirectory;
+ RwLock m_Lock;
+ IndexMap m_Index;
+ RwLock m_ShardLocks[256]; // TODO: these should be spaced out so they don't share cache lines
+ spdlog::logger& m_Log;
+ spdlog::logger& Log() { return m_Log; }
+ std::atomic_uint64_t m_TotalSize{};
+ bool m_IsInitialized = false;
+
+ struct FileCasIndexEntry
+ {
+ static const uint32_t kTombStone = 0x0000'0001;
+
+ bool IsFlagSet(const uint32_t Flag) const { return (Flags & kTombStone) == Flag; }
+
+ IoHash Key;
+ uint32_t Flags = 0;
+ uint64_t Size = 0;
+ };
+ static bool ValidateEntry(const FileCasIndexEntry& Entry, std::string& OutReason);
+ static std::vector<FileCasStrategy::FileCasIndexEntry> ScanFolderForCasFiles(const std::filesystem::path& RootDir);
+
+ static_assert(sizeof(FileCasIndexEntry) == 32);
+
+ TCasLogFile<FileCasIndexEntry> m_CasLog;
+ uint64_t m_LogFlushPosition = 0;
+
+ inline RwLock& LockForHash(const IoHash& Hash) { return m_ShardLocks[Hash.Hash[19]]; }
+ void IterateChunks(std::function<void(const IoHash& Hash, IoBuffer&& Payload)>&& Callback);
+ void DeleteChunk(const IoHash& ChunkHash, std::error_code& Ec);
+
+ struct ShardingHelper
+ {
+ ShardingHelper(const std::filesystem::path& RootPath, const IoHash& ChunkHash);
+
+ size_t Shard2len = 0;
+ ExtendablePathBuilder<128> ShardedPath;
+ };
+};
+
+void filecas_forcelink();
+
+} // namespace zen
diff --git a/src/zenstore/gc.cpp b/src/zenstore/gc.cpp
new file mode 100644
index 000000000..370c3c965
--- /dev/null
+++ b/src/zenstore/gc.cpp
@@ -0,0 +1,1312 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenstore/gc.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinaryvalidation.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/scopeguard.h>
+#include <zencore/string.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+#include <zencore/timer.h>
+#include <zenstore/cidstore.h>
+
+#include "cas.h"
+
+#include <fmt/format.h>
+#include <filesystem>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#else
+# include <fcntl.h>
+# include <sys/file.h>
+# include <sys/stat.h>
+# include <unistd.h>
+#endif
+
+#if ZEN_WITH_TESTS
+# include <zencore/compress.h>
+# include <algorithm>
+# include <random>
+#endif
+
+template<>
+struct fmt::formatter<zen::GcClock::TimePoint> : formatter<string_view>
+{
+ template<typename FormatContext>
+ auto format(const zen::GcClock::TimePoint& TimePoint, FormatContext& ctx)
+ {
+ std::time_t Time = std::chrono::system_clock::to_time_t(TimePoint);
+ zen::ExtendableStringBuilder<32> String;
+ String << std::ctime(&Time);
+ return formatter<string_view>::format(String.ToView(), ctx);
+ }
+};
+
+namespace zen {
+
+using namespace std::literals;
+namespace fs = std::filesystem;
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace {
+ std::error_code CreateGCReserve(const std::filesystem::path& Path, uint64_t Size)
+ {
+ if (Size == 0)
+ {
+ std::filesystem::remove(Path);
+ return std::error_code{};
+ }
+ CreateDirectories(Path.parent_path());
+ if (std::filesystem::is_regular_file(Path) && std::filesystem::file_size(Path) == Size)
+ {
+ return std::error_code();
+ }
+#if ZEN_PLATFORM_WINDOWS
+ DWORD dwCreationDisposition = CREATE_ALWAYS;
+ DWORD dwDesiredAccess = GENERIC_READ | GENERIC_WRITE;
+
+ const DWORD dwShareMode = 0;
+ const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL;
+ HANDLE hTemplateFile = nullptr;
+
+ HANDLE FileHandle = CreateFile(Path.c_str(),
+ dwDesiredAccess,
+ dwShareMode,
+ /* lpSecurityAttributes */ nullptr,
+ dwCreationDisposition,
+ dwFlagsAndAttributes,
+ hTemplateFile);
+
+ if (FileHandle == INVALID_HANDLE_VALUE)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+ bool Keep = true;
+ auto _ = MakeGuard([&]() {
+ ::CloseHandle(FileHandle);
+ if (!Keep)
+ {
+ ::DeleteFile(Path.c_str());
+ }
+ });
+ LARGE_INTEGER liFileSize;
+ liFileSize.QuadPart = Size;
+ BOOL OK = ::SetFilePointerEx(FileHandle, liFileSize, 0, FILE_BEGIN);
+ if (!OK)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+ OK = ::SetEndOfFile(FileHandle);
+ if (!OK)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+ Keep = true;
+#else
+ int OpenFlags = O_CLOEXEC | O_RDWR | O_CREAT;
+ int Fd = open(Path.c_str(), OpenFlags, 0666);
+ if (Fd < 0)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+
+ bool Keep = true;
+ auto _ = MakeGuard([&]() {
+ close(Fd);
+ if (!Keep)
+ {
+ unlink(Path.c_str());
+ }
+ });
+
+ if (fchmod(Fd, 0666) < 0)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+
+# if ZEN_PLATFORM_MAC
+ if (ftruncate(Fd, (off_t)Size) < 0)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+# else
+ if (ftruncate64(Fd, (off64_t)Size) < 0)
+ {
+ return MakeErrorCodeFromLastError();
+ }
+ int Error = posix_fallocate64(Fd, 0, (off64_t)Size);
+ if (Error)
+ {
+ return MakeErrorCode(Error);
+ }
+# endif
+ Keep = true;
+#endif
+ return std::error_code{};
+ }
+
+} // namespace
+
+//////////////////////////////////////////////////////////////////////////
+
+CbObject
+LoadCompactBinaryObject(const fs::path& Path)
+{
+ FileContents Result = ReadFile(Path);
+
+ if (!Result.ErrorCode)
+ {
+ IoBuffer Buffer = Result.Flatten();
+ if (CbValidateError Error = ValidateCompactBinary(Buffer, CbValidateMode::All); Error == CbValidateError::None)
+ {
+ return LoadCompactBinaryObject(Buffer);
+ }
+ }
+
+ return CbObject();
+}
+
+void
+SaveCompactBinaryObject(const fs::path& Path, const CbObject& Object)
+{
+ WriteFile(Path, Object.GetBuffer().AsIoBuffer());
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+struct GcContext::GcState
+{
+ using CacheKeyContexts = std::unordered_map<std::string, std::vector<IoHash>>;
+
+ CacheKeyContexts m_ExpiredCacheKeys;
+ HashKeySet m_RetainedCids;
+ HashKeySet m_DeletedCids;
+ GcClock::TimePoint m_ExpireTime;
+ bool m_DeletionMode = true;
+ bool m_CollectSmallObjects = false;
+
+ std::filesystem::path DiskReservePath;
+};
+
+GcContext::GcContext(const GcClock::TimePoint& ExpireTime) : m_State(std::make_unique<GcState>())
+{
+ m_State->m_ExpireTime = ExpireTime;
+}
+
+GcContext::~GcContext()
+{
+}
+
+void
+GcContext::AddRetainedCids(std::span<const IoHash> Cids)
+{
+ m_State->m_RetainedCids.AddHashesToSet(Cids);
+}
+
+void
+GcContext::SetExpiredCacheKeys(const std::string& CacheKeyContext, std::vector<IoHash>&& ExpiredKeys)
+{
+ m_State->m_ExpiredCacheKeys[CacheKeyContext] = std::move(ExpiredKeys);
+}
+
+void
+GcContext::IterateCids(std::function<void(const IoHash&)> Callback)
+{
+ m_State->m_RetainedCids.IterateHashes([&](const IoHash& Hash) { Callback(Hash); });
+}
+
+void
+GcContext::FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&)> KeepFunc)
+{
+ m_State->m_RetainedCids.FilterHashes(Cid, [&](const IoHash& Hash) { KeepFunc(Hash); });
+}
+
+void
+GcContext::FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&, bool)>&& FilterFunc)
+{
+ m_State->m_RetainedCids.FilterHashes(Cid, std::move(FilterFunc));
+}
+
+void
+GcContext::AddDeletedCids(std::span<const IoHash> Cas)
+{
+ m_State->m_DeletedCids.AddHashesToSet(Cas);
+}
+
+const HashKeySet&
+GcContext::DeletedCids()
+{
+ return m_State->m_DeletedCids;
+}
+
+std::span<const IoHash>
+GcContext::ExpiredCacheKeys(const std::string& CacheKeyContext) const
+{
+ return m_State->m_ExpiredCacheKeys[CacheKeyContext];
+}
+
+bool
+GcContext::IsDeletionMode() const
+{
+ return m_State->m_DeletionMode;
+}
+
+void
+GcContext::SetDeletionMode(bool NewState)
+{
+ m_State->m_DeletionMode = NewState;
+}
+
+bool
+GcContext::CollectSmallObjects() const
+{
+ return m_State->m_CollectSmallObjects;
+}
+
+void
+GcContext::CollectSmallObjects(bool NewState)
+{
+ m_State->m_CollectSmallObjects = NewState;
+}
+
+GcClock::TimePoint
+GcContext::ExpireTime() const
+{
+ return m_State->m_ExpireTime;
+}
+
+void
+GcContext::DiskReservePath(const std::filesystem::path& Path)
+{
+ m_State->DiskReservePath = Path;
+}
+
+uint64_t
+GcContext::ClaimGCReserve()
+{
+ if (!std::filesystem::is_regular_file(m_State->DiskReservePath))
+ {
+ return 0;
+ }
+ uint64_t ReclaimedSize = std::filesystem::file_size(m_State->DiskReservePath);
+ if (std::filesystem::remove(m_State->DiskReservePath))
+ {
+ return ReclaimedSize;
+ }
+ return 0;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+GcContributor::GcContributor(GcManager& Gc) : m_Gc(Gc)
+{
+ m_Gc.AddGcContributor(this);
+}
+
+GcContributor::~GcContributor()
+{
+ m_Gc.RemoveGcContributor(this);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+GcStorage::GcStorage(GcManager& Gc) : m_Gc(Gc)
+{
+ m_Gc.AddGcStorage(this);
+}
+
+GcStorage::~GcStorage()
+{
+ m_Gc.RemoveGcStorage(this);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+GcManager::GcManager() : m_Log(logging::Get("gc"))
+{
+}
+
+GcManager::~GcManager()
+{
+}
+
+void
+GcManager::AddGcContributor(GcContributor* Contributor)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_GcContribs.push_back(Contributor);
+}
+
+void
+GcManager::RemoveGcContributor(GcContributor* Contributor)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ std::erase_if(m_GcContribs, [&](GcContributor* $) { return $ == Contributor; });
+}
+
+void
+GcManager::AddGcStorage(GcStorage* Storage)
+{
+ ZEN_ASSERT(Storage != nullptr);
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_GcStorage.push_back(Storage);
+}
+
+void
+GcManager::RemoveGcStorage(GcStorage* Storage)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ std::erase_if(m_GcStorage, [&](GcStorage* $) { return $ == Storage; });
+}
+
+void
+GcManager::CollectGarbage(GcContext& GcCtx)
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ // First gather reference set
+ {
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] { ZEN_INFO("gathered references in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); });
+ for (GcContributor* Contributor : m_GcContribs)
+ {
+ Contributor->GatherReferences(GcCtx);
+ }
+ }
+
+ // Then trim storage
+ {
+ GcStorageSize GCTotalSizeDiff;
+ Stopwatch Timer;
+ const auto Guard = MakeGuard([&] {
+ ZEN_INFO("collected garbage in {}. Removed {} disk space, {} memory",
+ NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
+ NiceBytes(GCTotalSizeDiff.DiskSize),
+ NiceBytes(GCTotalSizeDiff.MemorySize));
+ });
+ for (GcStorage* Storage : m_GcStorage)
+ {
+ const auto PreSize = Storage->StorageSize();
+ Storage->CollectGarbage(GcCtx);
+ const auto PostSize = Storage->StorageSize();
+ GCTotalSizeDiff.DiskSize += PreSize.DiskSize > PostSize.DiskSize ? PreSize.DiskSize - PostSize.DiskSize : 0;
+ GCTotalSizeDiff.MemorySize += PreSize.MemorySize > PostSize.MemorySize ? PreSize.MemorySize - PostSize.MemorySize : 0;
+ }
+ }
+}
+
+GcStorageSize
+GcManager::TotalStorageSize() const
+{
+ RwLock::SharedLockScope _(m_Lock);
+
+ GcStorageSize TotalSize;
+
+ for (GcStorage* Storage : m_GcStorage)
+ {
+ const auto Size = Storage->StorageSize();
+ TotalSize.DiskSize += Size.DiskSize;
+ TotalSize.MemorySize += Size.MemorySize;
+ }
+
+ return TotalSize;
+}
+
+#if ZEN_USE_REF_TRACKING
+void
+GcManager::OnNewCidReferences(std::span<IoHash> Hashes)
+{
+ ZEN_UNUSED(Hashes);
+}
+
+void
+GcManager::OnCommittedCidReferences(std::span<IoHash> Hashes)
+{
+ ZEN_UNUSED(Hashes);
+}
+
+void
+GcManager::OnDroppedCidReferences(std::span<IoHash> Hashes)
+{
+ ZEN_UNUSED(Hashes);
+}
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+void
+DiskUsageWindow::KeepRange(GcClock::Tick StartTick, GcClock::Tick EndTick)
+{
+ auto It = m_LogWindow.begin();
+ if (It == m_LogWindow.end())
+ {
+ return;
+ }
+ while (It->SampleTime < StartTick)
+ {
+ ++It;
+ if (It == m_LogWindow.end())
+ {
+ m_LogWindow.clear();
+ return;
+ }
+ }
+ m_LogWindow.erase(m_LogWindow.begin(), It);
+
+ It = m_LogWindow.begin();
+ while (It != m_LogWindow.end())
+ {
+ if (It->SampleTime >= EndTick)
+ {
+ m_LogWindow.erase(It, m_LogWindow.end());
+ return;
+ }
+ It++;
+ }
+}
+
+std::vector<uint64_t>
+DiskUsageWindow::GetDiskDeltas(GcClock::Tick StartTick, GcClock::Tick EndTick, GcClock::Tick DeltaWidth, uint64_t& OutMaxDelta) const
+{
+ ZEN_ASSERT(StartTick != -1);
+ ZEN_ASSERT(DeltaWidth > 0);
+
+ std::vector<uint64_t> Result;
+ Result.reserve((EndTick - StartTick + DeltaWidth - 1) / DeltaWidth);
+
+ size_t WindowSize = m_LogWindow.size();
+ GcClock::Tick FirstWindowTick = WindowSize < 2 ? EndTick : m_LogWindow[1].SampleTime;
+
+ GcClock::Tick RangeStart = StartTick;
+ while (FirstWindowTick >= RangeStart + DeltaWidth && RangeStart < EndTick)
+ {
+ Result.push_back(0);
+ RangeStart += DeltaWidth;
+ }
+
+ uint64_t DeltaSum = 0;
+ size_t WindowIndex = 1;
+ while (WindowIndex < WindowSize && RangeStart < EndTick)
+ {
+ const DiskUsageEntry& Entry = m_LogWindow[WindowIndex];
+ if (Entry.SampleTime < RangeStart)
+ {
+ ++WindowIndex;
+ continue;
+ }
+ GcClock::Tick RangeEnd = Min(EndTick, RangeStart + DeltaWidth);
+ ZEN_ASSERT(Entry.SampleTime >= RangeStart);
+ if (Entry.SampleTime >= RangeEnd)
+ {
+ Result.push_back(DeltaSum);
+ OutMaxDelta = Max(DeltaSum, OutMaxDelta);
+ DeltaSum = 0;
+ RangeStart = RangeEnd;
+ continue;
+ }
+ const DiskUsageEntry& PrevEntry = m_LogWindow[WindowIndex - 1];
+ if (Entry.DiskUsage > PrevEntry.DiskUsage)
+ {
+ uint64_t Delta = Entry.DiskUsage - PrevEntry.DiskUsage;
+ DeltaSum += Delta;
+ }
+ WindowIndex++;
+ }
+
+ while (RangeStart < EndTick)
+ {
+ Result.push_back(DeltaSum);
+ OutMaxDelta = Max(DeltaSum, OutMaxDelta);
+ DeltaSum = 0;
+ RangeStart += DeltaWidth;
+ }
+ return Result;
+}
+
+GcClock::Tick
+DiskUsageWindow::FindTimepointThatRemoves(uint64_t Amount, GcClock::Tick EndTick) const
+{
+ ZEN_ASSERT(Amount > 0);
+ uint64_t RemainingToFind = Amount;
+ size_t Offset = 1;
+ while (Offset < m_LogWindow.size())
+ {
+ const DiskUsageEntry& Entry = m_LogWindow[Offset];
+ if (Entry.SampleTime >= EndTick)
+ {
+ return EndTick;
+ }
+ const DiskUsageEntry& PreviousEntry = m_LogWindow[Offset - 1];
+ uint64_t Delta = Entry.DiskUsage > PreviousEntry.DiskUsage ? Entry.DiskUsage - PreviousEntry.DiskUsage : 0;
+ if (Delta >= RemainingToFind)
+ {
+ return m_LogWindow[Offset].SampleTime + 1;
+ }
+ RemainingToFind -= Delta;
+ Offset++;
+ }
+ return EndTick;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+GcScheduler::GcScheduler(GcManager& GcManager) : m_Log(logging::Get("gc")), m_GcManager(GcManager)
+{
+}
+
+GcScheduler::~GcScheduler()
+{
+ Shutdown();
+}
+
+void
+GcScheduler::Initialize(const GcSchedulerConfig& Config)
+{
+ using namespace std::chrono;
+
+ m_Config = Config;
+
+ if (m_Config.Interval.count() && m_Config.Interval < m_Config.MonitorInterval)
+ {
+ m_Config.Interval = m_Config.MonitorInterval;
+ }
+
+ std::filesystem::create_directories(Config.RootDirectory);
+
+ std::error_code Ec = CreateGCReserve(m_Config.RootDirectory / "reserve.gc", m_Config.DiskReserveSize);
+ if (Ec)
+ {
+ ZEN_WARN("unable to create GC reserve at '{}' with size {}, reason '{}'",
+ m_Config.RootDirectory / "reserve.gc",
+ NiceBytes(m_Config.DiskReserveSize),
+ Ec.message());
+ }
+
+ m_LastGcTime = GcClock::Now();
+ m_LastGcExpireTime = GcClock::TimePoint::min();
+
+ if (CbObject SchedulerState = LoadCompactBinaryObject(Config.RootDirectory / "gc_state"))
+ {
+ m_LastGcTime = GcClock::TimePoint(GcClock::Duration(SchedulerState["LastGcTime"sv].AsInt64()));
+ m_LastGcExpireTime =
+ GcClock::TimePoint(GcClock::Duration(SchedulerState["LastGcExpireTime"].AsInt64(GcClock::Duration::min().count())));
+ if (m_LastGcTime + m_Config.Interval < GcClock::Now())
+ {
+ // TODO: Trigger GC?
+ m_LastGcTime = GcClock::Now();
+ }
+ }
+
+ m_DiskUsageLog.Open(m_Config.RootDirectory / "gc.dlog", CasLogFile::Mode::kWrite);
+ m_DiskUsageLog.Initialize();
+ const GcClock::Tick LastGCTick = m_LastGcTime.time_since_epoch().count();
+ m_DiskUsageLog.Replay(
+ [this, LastGCTick](const DiskUsageWindow::DiskUsageEntry& Entry) {
+ if (Entry.SampleTime >= m_LastGcExpireTime.time_since_epoch().count())
+ {
+ m_DiskUsageWindow.Append(Entry);
+ }
+ },
+ 0);
+
+ m_NextGcTime = NextGcTime(m_LastGcTime);
+ m_GcThread = std::thread(&GcScheduler::SchedulerThread, this);
+}
+
+void
+GcScheduler::Shutdown()
+{
+ if (static_cast<uint32_t>(GcSchedulerStatus::kStopped) != m_Status)
+ {
+ bool GcIsRunning = m_Status == static_cast<uint32_t>(GcSchedulerStatus::kRunning);
+ m_Status = static_cast<uint32_t>(GcSchedulerStatus::kStopped);
+ m_GcSignal.notify_one();
+
+ if (m_GcThread.joinable())
+ {
+ if (GcIsRunning)
+ {
+ ZEN_INFO("Waiting for garbage collection to complete");
+ }
+ m_GcThread.join();
+ }
+ }
+ m_DiskUsageLog.Flush();
+ m_DiskUsageLog.Close();
+}
+
+bool
+GcScheduler::Trigger(const GcScheduler::TriggerParams& Params)
+{
+ if (m_Config.Enabled)
+ {
+ std::unique_lock Lock(m_GcMutex);
+ if (static_cast<uint32_t>(GcSchedulerStatus::kIdle) == m_Status)
+ {
+ m_TriggerParams = Params;
+ uint32_t IdleState = static_cast<uint32_t>(GcSchedulerStatus::kIdle);
+ if (m_Status.compare_exchange_strong(IdleState, static_cast<uint32_t>(GcSchedulerStatus::kRunning)))
+ {
+ m_GcSignal.notify_one();
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+void
+GcScheduler::SchedulerThread()
+{
+ std::chrono::seconds WaitTime{0};
+
+ for (;;)
+ {
+ bool Timeout = false;
+ {
+ ZEN_ASSERT(WaitTime.count() >= 0);
+ std::unique_lock Lock(m_GcMutex);
+ Timeout = std::cv_status::timeout == m_GcSignal.wait_for(Lock, WaitTime);
+ }
+
+ if (Status() == GcSchedulerStatus::kStopped)
+ {
+ break;
+ }
+
+ if (!m_Config.Enabled)
+ {
+ WaitTime = std::chrono::seconds::max();
+ continue;
+ }
+
+ if (!Timeout && Status() == GcSchedulerStatus::kIdle)
+ {
+ continue;
+ }
+
+ bool Delete = true;
+ bool CollectSmallObjects = m_Config.CollectSmallObjects;
+ std::chrono::seconds MaxCacheDuration = m_Config.MaxCacheDuration;
+ uint64_t DiskSizeSoftLimit = m_Config.DiskSizeSoftLimit;
+ GcClock::TimePoint Now = GcClock::Now();
+ if (m_TriggerParams)
+ {
+ const auto TriggerParams = m_TriggerParams.value();
+ m_TriggerParams.reset();
+
+ CollectSmallObjects = TriggerParams.CollectSmallObjects;
+ if (TriggerParams.MaxCacheDuration != std::chrono::seconds::max())
+ {
+ MaxCacheDuration = TriggerParams.MaxCacheDuration;
+ }
+ if (TriggerParams.DiskSizeSoftLimit != 0)
+ {
+ DiskSizeSoftLimit = TriggerParams.DiskSizeSoftLimit;
+ }
+ }
+
+ GcClock::TimePoint ExpireTime = MaxCacheDuration == GcClock::Duration::max() ? GcClock::TimePoint::min() : Now - MaxCacheDuration;
+
+ std::error_code Ec;
+ const GcStorageSize TotalSize = m_GcManager.TotalStorageSize();
+
+ if (Timeout && Status() == GcSchedulerStatus::kIdle)
+ {
+ DiskSpace Space = DiskSpaceInfo(m_Config.RootDirectory, Ec);
+ if (Ec)
+ {
+ ZEN_WARN("get disk space info FAILED, reason: '{}'", Ec.message());
+ }
+
+ const int64_t PressureGraphLength = 30;
+ const std::chrono::duration LoadGraphTime = PressureGraphLength * m_Config.MonitorInterval;
+ std::vector<uint64_t> DiskDeltas;
+ uint64_t MaxLoad = 0;
+ {
+ const GcClock::Tick EpochTickCount = GcClock::Now().time_since_epoch().count();
+ std::unique_lock Lock(m_GcMutex);
+ m_DiskUsageWindow.Append({.SampleTime = EpochTickCount, .DiskUsage = TotalSize.DiskSize});
+ m_DiskUsageLog.Append({.SampleTime = EpochTickCount, .DiskUsage = TotalSize.DiskSize});
+ const GcClock::TimePoint LoadGraphStartTime = Now - LoadGraphTime;
+ GcClock::Tick Start = LoadGraphStartTime.time_since_epoch().count();
+ GcClock::Tick End = Now.time_since_epoch().count();
+ DiskDeltas = m_DiskUsageWindow.GetDiskDeltas(Start,
+ End,
+ Max(1, (End - Start + PressureGraphLength - 1) / PressureGraphLength),
+ MaxLoad);
+ }
+
+ std::string LoadGraph;
+ LoadGraph.resize(DiskDeltas.size(), '0');
+ if (DiskDeltas.size() > 0 && MaxLoad > 0)
+ {
+ char LoadIndicator[11] = "0123456789";
+ for (size_t Index = 0; Index < DiskDeltas.size(); ++Index)
+ {
+ size_t LoadIndex = (9 * DiskDeltas[Index] + MaxLoad - 1) / MaxLoad;
+ LoadGraph[Index] = LoadIndicator[LoadIndex];
+ }
+ }
+
+ uint64_t GcDiskSpaceGoal = 0;
+ if (DiskSizeSoftLimit != 0 && TotalSize.DiskSize > DiskSizeSoftLimit)
+ {
+ GcDiskSpaceGoal = TotalSize.DiskSize - DiskSizeSoftLimit;
+ std::unique_lock Lock(m_GcMutex);
+ GcClock::Tick AgeTick = m_DiskUsageWindow.FindTimepointThatRemoves(GcDiskSpaceGoal, Now.time_since_epoch().count());
+ GcClock::TimePoint SizeBasedExpireTime = GcClock::TimePointFromTick(AgeTick);
+ if (SizeBasedExpireTime > ExpireTime)
+ {
+ ExpireTime = SizeBasedExpireTime;
+ }
+ }
+
+ bool DiskSpaceGCTriggered = GcDiskSpaceGoal > 0;
+
+ std::chrono::seconds RemaingTime = std::chrono::duration_cast<std::chrono::seconds>(m_NextGcTime - GcClock::Now());
+
+ if (RemaingTime < std::chrono::seconds::zero())
+ {
+ RemaingTime = std::chrono::seconds::zero();
+ }
+
+ bool TimeBasedGCTriggered = !DiskSpaceGCTriggered && RemaingTime.count() == 0;
+ ZEN_INFO(
+ "{} in use,{} {} of total {} free disk space, disk writes last {} per {} [{}], peak {}/s. {}",
+ NiceBytes(TotalSize.DiskSize),
+ DiskSizeSoftLimit == 0 ? "" : fmt::format(" {} soft limit,", NiceBytes(DiskSizeSoftLimit)),
+ NiceBytes(Space.Free),
+ NiceBytes(Space.Total),
+ NiceTimeSpanMs(uint64_t(std::chrono::milliseconds(LoadGraphTime).count())),
+ NiceTimeSpanMs(uint64_t(std::chrono::milliseconds(LoadGraphTime).count() / PressureGraphLength)),
+ LoadGraph,
+ NiceBytes(MaxLoad * uint64_t(std::chrono::seconds(1).count()) / uint64_t(std::chrono::seconds(LoadGraphTime).count())),
+ DiskSpaceGCTriggered ? fmt::format("Disk use threshold triggered, trying to reclaim {}. ", NiceBytes(GcDiskSpaceGoal))
+ : TimeBasedGCTriggered ? "GC schedule triggered."
+ : m_NextGcTime == GcClock::TimePoint::max()
+ ? ""
+ : fmt::format("{} until next scheduled GC.", NiceTimeSpanMs(uint64_t(std::chrono::milliseconds(RemaingTime).count()))));
+
+ if (!DiskSpaceGCTriggered && !TimeBasedGCTriggered)
+ {
+ WaitTime = m_Config.MonitorInterval < RemaingTime ? m_Config.MonitorInterval : RemaingTime;
+ continue;
+ }
+
+ WaitTime = m_Config.MonitorInterval;
+ uint32_t IdleState = static_cast<uint32_t>(GcSchedulerStatus::kIdle);
+ if (!m_Status.compare_exchange_strong(IdleState, static_cast<uint32_t>(GcSchedulerStatus::kRunning)))
+ {
+ continue;
+ }
+ }
+
+ CollectGarbage(ExpireTime, Delete, CollectSmallObjects);
+
+ uint32_t RunningState = static_cast<uint32_t>(GcSchedulerStatus::kRunning);
+ if (!m_Status.compare_exchange_strong(RunningState, static_cast<uint32_t>(GcSchedulerStatus::kIdle)))
+ {
+ ZEN_ASSERT(m_Status == static_cast<uint32_t>(GcSchedulerStatus::kStopped));
+ break;
+ }
+
+ WaitTime = m_Config.MonitorInterval;
+ }
+}
+
+GcClock::TimePoint
+GcScheduler::NextGcTime(GcClock::TimePoint CurrentTime)
+{
+ if (m_Config.Interval.count())
+ {
+ return CurrentTime + m_Config.Interval;
+ }
+ else
+ {
+ return GcClock::TimePoint::max();
+ }
+}
+
+void
+GcScheduler::CollectGarbage(const GcClock::TimePoint& ExpireTime, bool Delete, bool CollectSmallObjects)
+{
+ GcContext GcCtx(ExpireTime);
+ GcCtx.SetDeletionMode(Delete);
+ GcCtx.CollectSmallObjects(CollectSmallObjects);
+ // GcCtx.MaxCacheDuration(MaxCacheDuration);
+ GcCtx.DiskReservePath(m_Config.RootDirectory / "reserve.gc");
+
+ ZEN_INFO("garbage collection STARTING, small objects gc {}, cutoff time {}",
+ GcCtx.CollectSmallObjects() ? "ENABLED"sv : "DISABLED"sv,
+ ExpireTime);
+ {
+ Stopwatch Timer;
+ const auto __ = MakeGuard([&] { ZEN_INFO("garbage collection DONE in {}", NiceTimeSpanMs(Timer.GetElapsedTimeMs())); });
+
+ m_GcManager.CollectGarbage(GcCtx);
+
+ if (Delete)
+ {
+ m_LastGcExpireTime = ExpireTime;
+ std::unique_lock Lock(m_GcMutex);
+ m_DiskUsageWindow.KeepRange(ExpireTime.time_since_epoch().count(), GcClock::Duration::max().count());
+ }
+
+ m_LastGcTime = GcClock::Now();
+ m_NextGcTime = NextGcTime(m_LastGcTime);
+
+ {
+ const fs::path Path = m_Config.RootDirectory / "gc_state";
+ ZEN_DEBUG("saving scheduler state to '{}'", Path);
+ CbObjectWriter SchedulerState;
+ SchedulerState << "LastGcTime"sv << static_cast<int64_t>(m_LastGcTime.time_since_epoch().count());
+ SchedulerState << "LastGcExpireTime"sv << static_cast<int64_t>(m_LastGcExpireTime.time_since_epoch().count());
+ SaveCompactBinaryObject(Path, SchedulerState.Save());
+ }
+
+ std::error_code Ec = CreateGCReserve(m_Config.RootDirectory / "reserve.gc", m_Config.DiskReserveSize);
+ if (Ec)
+ {
+ ZEN_WARN("unable to create GC reserve at '{}' with size {}, reason: '{}'",
+ m_Config.RootDirectory / "reserve.gc",
+ NiceBytes(m_Config.DiskReserveSize),
+ Ec.message());
+ }
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+#if ZEN_WITH_TESTS
+
+namespace gc::impl {
+ static IoBuffer CreateChunk(uint64_t Size)
+ {
+ static std::random_device rd;
+ static std::mt19937 g(rd());
+
+ std::vector<uint8_t> Values;
+ Values.resize(Size);
+ for (size_t Idx = 0; Idx < Size; ++Idx)
+ {
+ Values[Idx] = static_cast<uint8_t>(Idx);
+ }
+ std::shuffle(Values.begin(), Values.end(), g);
+
+ return IoBufferBuilder::MakeCloneFromMemory(Values.data(), Values.size());
+ }
+
+ static CompressedBuffer Compress(IoBuffer Buffer)
+ {
+ return CompressedBuffer::Compress(SharedBuffer::MakeView(Buffer.GetData(), Buffer.GetSize()));
+ }
+} // namespace gc::impl
+
+TEST_CASE("gc.basic")
+{
+ using namespace gc::impl;
+
+ ScopedTemporaryDirectory TempDir;
+
+ CidStoreConfiguration CasConfig;
+ CasConfig.RootDirectory = TempDir.Path() / "cas";
+
+ GcManager Gc;
+ CidStore CidStore(Gc);
+
+ CidStore.Initialize(CasConfig);
+
+ IoBuffer Chunk = CreateChunk(128);
+ auto CompressedChunk = Compress(Chunk);
+
+ const auto InsertResult = CidStore.AddChunk(CompressedChunk.GetCompressed().Flatten().AsIoBuffer(), CompressedChunk.DecodeRawHash());
+ CHECK(InsertResult.New);
+
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+
+ CidStore.Flush();
+ Gc.CollectGarbage(GcCtx);
+
+ CHECK(!CidStore.ContainsChunk(CompressedChunk.DecodeRawHash()));
+}
+
+TEST_CASE("gc.full")
+{
+ using namespace gc::impl;
+
+ ScopedTemporaryDirectory TempDir;
+
+ CidStoreConfiguration CasConfig;
+ CasConfig.RootDirectory = TempDir.Path() / "cas";
+
+ GcManager Gc;
+ std::unique_ptr<CasStore> CasStore = CreateCasStore(Gc);
+
+ CasStore->Initialize(CasConfig);
+
+ uint64_t ChunkSizes[9] = {128, 541, 1023, 781, 218, 37, 4, 997, 5};
+ IoBuffer Chunks[9] = {CreateChunk(ChunkSizes[0]),
+ CreateChunk(ChunkSizes[1]),
+ CreateChunk(ChunkSizes[2]),
+ CreateChunk(ChunkSizes[3]),
+ CreateChunk(ChunkSizes[4]),
+ CreateChunk(ChunkSizes[5]),
+ CreateChunk(ChunkSizes[6]),
+ CreateChunk(ChunkSizes[7]),
+ CreateChunk(ChunkSizes[8])};
+ IoHash ChunkHashes[9] = {
+ IoHash::HashBuffer(Chunks[0].Data(), Chunks[0].Size()),
+ IoHash::HashBuffer(Chunks[1].Data(), Chunks[1].Size()),
+ IoHash::HashBuffer(Chunks[2].Data(), Chunks[2].Size()),
+ IoHash::HashBuffer(Chunks[3].Data(), Chunks[3].Size()),
+ IoHash::HashBuffer(Chunks[4].Data(), Chunks[4].Size()),
+ IoHash::HashBuffer(Chunks[5].Data(), Chunks[5].Size()),
+ IoHash::HashBuffer(Chunks[6].Data(), Chunks[6].Size()),
+ IoHash::HashBuffer(Chunks[7].Data(), Chunks[7].Size()),
+ IoHash::HashBuffer(Chunks[8].Data(), Chunks[8].Size()),
+ };
+
+ CasStore->InsertChunk(Chunks[0], ChunkHashes[0]);
+ CasStore->InsertChunk(Chunks[1], ChunkHashes[1]);
+ CasStore->InsertChunk(Chunks[2], ChunkHashes[2]);
+ CasStore->InsertChunk(Chunks[3], ChunkHashes[3]);
+ CasStore->InsertChunk(Chunks[4], ChunkHashes[4]);
+ CasStore->InsertChunk(Chunks[5], ChunkHashes[5]);
+ CasStore->InsertChunk(Chunks[6], ChunkHashes[6]);
+ CasStore->InsertChunk(Chunks[7], ChunkHashes[7]);
+ CasStore->InsertChunk(Chunks[8], ChunkHashes[8]);
+
+ CidStoreSize InitialSize = CasStore->TotalSize();
+
+ // Keep first and last
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[0]);
+ KeepChunks.push_back(ChunkHashes[8]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ CasStore->Flush();
+ Gc.CollectGarbage(GcCtx);
+
+ CHECK(CasStore->ContainsChunk(ChunkHashes[0]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[1]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[2]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[3]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[4]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[5]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[6]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[7]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[0] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[0])));
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8])));
+ }
+
+ CasStore->InsertChunk(Chunks[1], ChunkHashes[1]);
+ CasStore->InsertChunk(Chunks[2], ChunkHashes[2]);
+ CasStore->InsertChunk(Chunks[3], ChunkHashes[3]);
+ CasStore->InsertChunk(Chunks[4], ChunkHashes[4]);
+ CasStore->InsertChunk(Chunks[5], ChunkHashes[5]);
+ CasStore->InsertChunk(Chunks[6], ChunkHashes[6]);
+ CasStore->InsertChunk(Chunks[7], ChunkHashes[7]);
+
+ // Keep last
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[8]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ CasStore->Flush();
+ Gc.CollectGarbage(GcCtx);
+
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[0]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[1]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[2]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[3]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[4]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[5]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[6]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[7]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8])));
+
+ CasStore->InsertChunk(Chunks[1], ChunkHashes[1]);
+ CasStore->InsertChunk(Chunks[2], ChunkHashes[2]);
+ CasStore->InsertChunk(Chunks[3], ChunkHashes[3]);
+ CasStore->InsertChunk(Chunks[4], ChunkHashes[4]);
+ CasStore->InsertChunk(Chunks[5], ChunkHashes[5]);
+ CasStore->InsertChunk(Chunks[6], ChunkHashes[6]);
+ CasStore->InsertChunk(Chunks[7], ChunkHashes[7]);
+ }
+
+ // Keep mixed
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[1]);
+ KeepChunks.push_back(ChunkHashes[4]);
+ KeepChunks.push_back(ChunkHashes[7]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ CasStore->Flush();
+ Gc.CollectGarbage(GcCtx);
+
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[0]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[1]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[2]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[3]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[4]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[5]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[6]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[7]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[1] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[1])));
+ CHECK(ChunkHashes[4] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[4])));
+ CHECK(ChunkHashes[7] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[7])));
+
+ CasStore->InsertChunk(Chunks[0], ChunkHashes[0]);
+ CasStore->InsertChunk(Chunks[2], ChunkHashes[2]);
+ CasStore->InsertChunk(Chunks[3], ChunkHashes[3]);
+ CasStore->InsertChunk(Chunks[5], ChunkHashes[5]);
+ CasStore->InsertChunk(Chunks[6], ChunkHashes[6]);
+ CasStore->InsertChunk(Chunks[8], ChunkHashes[8]);
+ }
+
+ // Keep multiple at end
+ {
+ GcContext GcCtx(GcClock::Now() - std::chrono::hours(24));
+ GcCtx.CollectSmallObjects(true);
+ std::vector<IoHash> KeepChunks;
+ KeepChunks.push_back(ChunkHashes[6]);
+ KeepChunks.push_back(ChunkHashes[7]);
+ KeepChunks.push_back(ChunkHashes[8]);
+ GcCtx.AddRetainedCids(KeepChunks);
+
+ CasStore->Flush();
+ Gc.CollectGarbage(GcCtx);
+
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[0]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[1]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[2]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[3]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[4]));
+ CHECK(!CasStore->ContainsChunk(ChunkHashes[5]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[6]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[7]));
+ CHECK(CasStore->ContainsChunk(ChunkHashes[8]));
+
+ CHECK(ChunkHashes[6] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[6])));
+ CHECK(ChunkHashes[7] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[7])));
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8])));
+
+ CasStore->InsertChunk(Chunks[0], ChunkHashes[0]);
+ CasStore->InsertChunk(Chunks[1], ChunkHashes[1]);
+ CasStore->InsertChunk(Chunks[2], ChunkHashes[2]);
+ CasStore->InsertChunk(Chunks[3], ChunkHashes[3]);
+ CasStore->InsertChunk(Chunks[4], ChunkHashes[4]);
+ CasStore->InsertChunk(Chunks[5], ChunkHashes[5]);
+ }
+
+ // Verify that we nicely appended blocks even after all GC operations
+ CHECK(ChunkHashes[0] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[0])));
+ CHECK(ChunkHashes[1] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[1])));
+ CHECK(ChunkHashes[2] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[2])));
+ CHECK(ChunkHashes[3] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[3])));
+ CHECK(ChunkHashes[4] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[4])));
+ CHECK(ChunkHashes[5] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[5])));
+ CHECK(ChunkHashes[6] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[6])));
+ CHECK(ChunkHashes[7] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[7])));
+ CHECK(ChunkHashes[8] == IoHash::HashBuffer(CasStore->FindChunk(ChunkHashes[8])));
+
+ auto FinalSize = CasStore->TotalSize();
+
+ CHECK_LE(InitialSize.TinySize, FinalSize.TinySize);
+ CHECK_GE(InitialSize.TinySize + (1u << 28), FinalSize.TinySize);
+}
+
+TEST_CASE("gc.diskusagewindow")
+{
+ using namespace gc::impl;
+
+ DiskUsageWindow Stats;
+ Stats.Append({.SampleTime = 0, .DiskUsage = 0}); // 0 0
+ Stats.Append({.SampleTime = 10, .DiskUsage = 10}); // 1 10
+ Stats.Append({.SampleTime = 20, .DiskUsage = 20}); // 2 10
+ Stats.Append({.SampleTime = 30, .DiskUsage = 20}); // 3 0
+ Stats.Append({.SampleTime = 40, .DiskUsage = 15}); // 4 0
+ Stats.Append({.SampleTime = 50, .DiskUsage = 25}); // 5 10
+ Stats.Append({.SampleTime = 60, .DiskUsage = 30}); // 6 5
+ Stats.Append({.SampleTime = 70, .DiskUsage = 45}); // 7 15
+
+ SUBCASE("Truncate start")
+ {
+ Stats.KeepRange(-15, 31);
+ CHECK(Stats.m_LogWindow.size() == 4);
+ CHECK(Stats.m_LogWindow[0].SampleTime == 0);
+ CHECK(Stats.m_LogWindow[3].SampleTime == 30);
+ }
+
+ SUBCASE("Truncate end")
+ {
+ Stats.KeepRange(70, 71);
+ CHECK(Stats.m_LogWindow.size() == 1);
+ CHECK(Stats.m_LogWindow[0].SampleTime == 70);
+ }
+
+ SUBCASE("Truncate middle")
+ {
+ Stats.KeepRange(29, 69);
+ CHECK(Stats.m_LogWindow.size() == 4);
+ CHECK(Stats.m_LogWindow[0].SampleTime == 30);
+ CHECK(Stats.m_LogWindow[3].SampleTime == 60);
+ }
+
+ SUBCASE("Full range")
+ {
+ uint64_t MaxDelta = 0;
+ // 0-10, 10-20, 20-30, 30-40, 40-50, 50-60, 60-70, 70-80
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(0, 80, 10, MaxDelta);
+ CHECK(DiskDeltas.size() == 8);
+ CHECK(MaxDelta == 15);
+ CHECK(DiskDeltas[0] == 0);
+ CHECK(DiskDeltas[1] == 10);
+ CHECK(DiskDeltas[2] == 10);
+ CHECK(DiskDeltas[3] == 0);
+ CHECK(DiskDeltas[4] == 0);
+ CHECK(DiskDeltas[5] == 10);
+ CHECK(DiskDeltas[6] == 5);
+ CHECK(DiskDeltas[7] == 15);
+ }
+
+ SUBCASE("Sub range")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(20, 40, 10, MaxDelta);
+ CHECK(DiskDeltas.size() == 2);
+ CHECK(MaxDelta == 10);
+ CHECK(DiskDeltas[0] == 10); // [20:30]
+ CHECK(DiskDeltas[1] == 0); // [30:40]
+ }
+ SUBCASE("Unaligned sub range 1")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(21, 51, 10, MaxDelta);
+ CHECK(DiskDeltas.size() == 3);
+ CHECK(MaxDelta == 10);
+ CHECK(DiskDeltas[0] == 0); // [21:31]
+ CHECK(DiskDeltas[1] == 0); // [31:41]
+ CHECK(DiskDeltas[2] == 10); // [41:51]
+ }
+ SUBCASE("Unaligned end range")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(29, 79, 10, MaxDelta);
+ CHECK(DiskDeltas.size() == 5);
+ CHECK(MaxDelta == 15);
+ CHECK(DiskDeltas[0] == 0); // [29:39]
+ CHECK(DiskDeltas[1] == 0); // [39:49]
+ CHECK(DiskDeltas[2] == 10); // [49:59]
+ CHECK(DiskDeltas[3] == 5); // [59:69]
+ CHECK(DiskDeltas[4] == 15); // [69:79]
+ }
+ SUBCASE("Ahead of window")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(-40, 0, 10, MaxDelta);
+ CHECK(DiskDeltas.size() == 4);
+ CHECK(MaxDelta == 0);
+ CHECK(DiskDeltas[0] == 0); // [-40:-30]
+ CHECK(DiskDeltas[1] == 0); // [-30:-20]
+ CHECK(DiskDeltas[2] == 0); // [-20:-10]
+ CHECK(DiskDeltas[3] == 0); // [-10:0]
+ }
+ SUBCASE("After of window")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(90, 120, 10, MaxDelta);
+ CHECK(DiskDeltas.size() == 3);
+ CHECK(MaxDelta == 0);
+ CHECK(DiskDeltas[0] == 0); // [90:100]
+ CHECK(DiskDeltas[1] == 0); // [100:110]
+ CHECK(DiskDeltas[2] == 0); // [110:120]
+ }
+ SUBCASE("Encapsulating window")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(-20, 100, 10, MaxDelta);
+ CHECK(DiskDeltas.size() == 12);
+ CHECK(MaxDelta == 15);
+ CHECK(DiskDeltas[0] == 0); // [-20:-10]
+ CHECK(DiskDeltas[1] == 0); // [ -10:0]
+ CHECK(DiskDeltas[2] == 0); // [0:10]
+ CHECK(DiskDeltas[3] == 10); // [10:20]
+ CHECK(DiskDeltas[4] == 10); // [20:30]
+ CHECK(DiskDeltas[5] == 0); // [30:40]
+ CHECK(DiskDeltas[6] == 0); // [40:50]
+ CHECK(DiskDeltas[7] == 10); // [50:60]
+ CHECK(DiskDeltas[8] == 5); // [60:70]
+ CHECK(DiskDeltas[9] == 15); // [70:80]
+ CHECK(DiskDeltas[10] == 0); // [80:90]
+ CHECK(DiskDeltas[11] == 0); // [90:100]
+ }
+
+ SUBCASE("Full range half stride")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(0, 80, 20, MaxDelta);
+ CHECK(DiskDeltas.size() == 4);
+ CHECK(MaxDelta == 20);
+ CHECK(DiskDeltas[0] == 10); // [0:20]
+ CHECK(DiskDeltas[1] == 10); // [20:40]
+ CHECK(DiskDeltas[2] == 10); // [40:60]
+ CHECK(DiskDeltas[3] == 20); // [60:80]
+ }
+
+ SUBCASE("Partial odd stride")
+ {
+ uint64_t MaxDelta = 0;
+ std::vector<uint64_t> DiskDeltas = Stats.GetDiskDeltas(13, 67, 18, MaxDelta);
+ CHECK(DiskDeltas.size() == 3);
+ CHECK(MaxDelta == 15);
+ CHECK(DiskDeltas[0] == 10); // [13:31]
+ CHECK(DiskDeltas[1] == 0); // [31:49]
+ CHECK(DiskDeltas[2] == 15); // [49:67]
+ }
+
+ SUBCASE("Find size window")
+ {
+ DiskUsageWindow Empty;
+ CHECK(Empty.FindTimepointThatRemoves(15u, 10000) == 10000);
+
+ CHECK(Stats.FindTimepointThatRemoves(15u, 40) == 21);
+ CHECK(Stats.FindTimepointThatRemoves(15u, 20) == 20);
+ CHECK(Stats.FindTimepointThatRemoves(100000u, 50) == 50);
+ CHECK(Stats.FindTimepointThatRemoves(100000u, 1000));
+ }
+}
+#endif
+
+void
+gc_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenstore/hashkeyset.cpp b/src/zenstore/hashkeyset.cpp
new file mode 100644
index 000000000..a5436f5cb
--- /dev/null
+++ b/src/zenstore/hashkeyset.cpp
@@ -0,0 +1,60 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenstore/hashkeyset.h>
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+void
+HashKeySet::AddHashToSet(const IoHash& HashToAdd)
+{
+ m_HashSet.insert(HashToAdd);
+}
+
+void
+HashKeySet::AddHashesToSet(std::span<const IoHash> HashesToAdd)
+{
+ m_HashSet.insert(HashesToAdd.begin(), HashesToAdd.end());
+}
+
+void
+HashKeySet::RemoveHashesIf(std::function<bool(const IoHash& CandidateHash)>&& Predicate)
+{
+ for (auto It = begin(m_HashSet), ItEnd = end(m_HashSet); It != ItEnd;)
+ {
+ if (Predicate(*It))
+ {
+ It = m_HashSet.erase(It);
+ }
+ else
+ {
+ ++It;
+ }
+ }
+}
+
+void
+HashKeySet::IterateHashes(std::function<void(const IoHash& Hash)>&& Callback) const
+{
+ for (auto It = begin(m_HashSet), ItEnd = end(m_HashSet); It != ItEnd; ++It)
+ {
+ Callback(*It);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+//
+// Testing related code follows...
+//
+
+#if ZEN_WITH_TESTS
+
+void
+hashkeyset_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/blockstore.h b/src/zenstore/include/zenstore/blockstore.h
new file mode 100644
index 000000000..857ccae38
--- /dev/null
+++ b/src/zenstore/include/zenstore/blockstore.h
@@ -0,0 +1,175 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/filesystem.h>
+#include <zencore/zencore.h>
+#include <zenutil/basicfile.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct BlockStoreLocation
+{
+ uint32_t BlockIndex;
+ uint64_t Offset;
+ uint64_t Size;
+
+ inline auto operator<=>(const BlockStoreLocation& Rhs) const = default;
+};
+
+#pragma pack(push)
+#pragma pack(1)
+
+struct BlockStoreDiskLocation
+{
+ constexpr static uint32_t MaxBlockIndexBits = 20;
+ constexpr static uint32_t MaxOffsetBits = 28;
+ constexpr static uint32_t MaxBlockIndex = (1ul << BlockStoreDiskLocation::MaxBlockIndexBits) - 1ul;
+ constexpr static uint32_t MaxOffset = (1ul << BlockStoreDiskLocation::MaxOffsetBits) - 1ul;
+
+ BlockStoreDiskLocation(const BlockStoreLocation& Location, uint64_t OffsetAlignment)
+ {
+ Init(Location.BlockIndex, Location.Offset / OffsetAlignment, Location.Size);
+ }
+
+ BlockStoreDiskLocation() = default;
+
+ inline BlockStoreLocation Get(uint64_t OffsetAlignment) const
+ {
+ uint64_t PackedOffset = 0;
+ memcpy(&PackedOffset, &m_Offset, sizeof m_Offset);
+ return {.BlockIndex = static_cast<std::uint32_t>(PackedOffset >> MaxOffsetBits),
+ .Offset = (PackedOffset & MaxOffset) * OffsetAlignment,
+ .Size = GetSize()};
+ }
+
+ inline uint32_t GetBlockIndex() const
+ {
+ uint64_t PackedOffset = 0;
+ memcpy(&PackedOffset, &m_Offset, sizeof m_Offset);
+ return static_cast<std::uint32_t>(PackedOffset >> MaxOffsetBits);
+ }
+
+ inline uint64_t GetOffset(uint64_t OffsetAlignment) const
+ {
+ uint64_t PackedOffset = 0;
+ memcpy(&PackedOffset, &m_Offset, sizeof m_Offset);
+ return (PackedOffset & MaxOffset) * OffsetAlignment;
+ }
+
+ inline uint64_t GetSize() const { return m_Size; }
+
+ inline auto operator<=>(const BlockStoreDiskLocation& Rhs) const = default;
+
+private:
+ inline void Init(uint32_t BlockIndex, uint64_t Offset, uint64_t Size)
+ {
+ ZEN_ASSERT(BlockIndex <= MaxBlockIndex);
+ ZEN_ASSERT(Offset <= MaxOffset);
+ ZEN_ASSERT(Size <= std::numeric_limits<std::uint32_t>::max());
+
+ m_Size = static_cast<uint32_t>(Size);
+ uint64_t PackedOffset = (static_cast<uint64_t>(BlockIndex) << MaxOffsetBits) + Offset;
+ memcpy(&m_Offset[0], &PackedOffset, sizeof m_Offset);
+ }
+
+ uint32_t m_Size;
+ uint8_t m_Offset[6];
+};
+
+#pragma pack(pop)
+
+struct BlockStoreFile : public RefCounted
+{
+ explicit BlockStoreFile(const std::filesystem::path& BlockPath);
+ ~BlockStoreFile();
+ const std::filesystem::path& GetPath() const;
+ void Open();
+ void Create(uint64_t InitialSize);
+ void MarkAsDeleteOnClose();
+ uint64_t FileSize();
+ IoBuffer GetChunk(uint64_t Offset, uint64_t Size);
+ void Read(void* Data, uint64_t Size, uint64_t FileOffset);
+ void Write(const void* Data, uint64_t Size, uint64_t FileOffset);
+ void Flush();
+ BasicFile& GetBasicFile();
+ void StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun);
+
+private:
+ const std::filesystem::path m_Path;
+ IoBuffer m_IoBuffer;
+ BasicFile m_File;
+};
+
+class BlockStore
+{
+public:
+ struct ReclaimSnapshotState
+ {
+ std::unordered_set<uint32_t> m_ActiveWriteBlocks;
+ size_t BlockCount;
+ };
+
+ typedef std::vector<std::pair<size_t, BlockStoreLocation>> MovedChunksArray;
+ typedef std::vector<size_t> ChunkIndexArray;
+
+ typedef std::function<void(const MovedChunksArray& MovedChunks, const ChunkIndexArray& RemovedChunks)> ReclaimCallback;
+ typedef std::function<uint64_t()> ClaimDiskReserveCallback;
+ typedef std::function<void(size_t ChunkIndex, const void* Data, uint64_t Size)> IterateChunksSmallSizeCallback;
+ typedef std::function<void(size_t ChunkIndex, BlockStoreFile& File, uint64_t Offset, uint64_t Size)> IterateChunksLargeSizeCallback;
+ typedef std::function<void(const BlockStoreLocation& Location)> WriteChunkCallback;
+
+ void Initialize(const std::filesystem::path& BlocksBasePath,
+ uint64_t MaxBlockSize,
+ uint64_t MaxBlockCount,
+ const std::vector<BlockStoreLocation>& KnownLocations);
+ void Close();
+
+ void WriteChunk(const void* Data, uint64_t Size, uint64_t Alignment, const WriteChunkCallback& Callback);
+
+ IoBuffer TryGetChunk(const BlockStoreLocation& Location) const;
+ void Flush();
+
+ ReclaimSnapshotState GetReclaimSnapshotState();
+ void ReclaimSpace(
+ const ReclaimSnapshotState& Snapshot,
+ const std::vector<BlockStoreLocation>& ChunkLocations,
+ const ChunkIndexArray& KeepChunkIndexes,
+ uint64_t PayloadAlignment,
+ bool DryRun,
+ const ReclaimCallback& ChangeCallback = [](const MovedChunksArray&, const ChunkIndexArray&) {},
+ const ClaimDiskReserveCallback& DiskReserveCallback = []() { return 0; });
+
+ void IterateChunks(const std::vector<BlockStoreLocation>& ChunkLocations,
+ const IterateChunksSmallSizeCallback& SmallSizeCallback,
+ const IterateChunksLargeSizeCallback& LargeSizeCallback);
+
+ static const char* GetBlockFileExtension();
+ static std::filesystem::path GetBlockPath(const std::filesystem::path& BlocksBasePath, const uint32_t BlockIndex);
+
+ inline uint64_t TotalSize() const { return m_TotalSize.load(std::memory_order::relaxed); }
+
+private:
+ std::unordered_map<uint32_t, Ref<BlockStoreFile>> m_ChunkBlocks;
+
+ mutable RwLock m_InsertLock; // used to serialize inserts
+ Ref<BlockStoreFile> m_WriteBlock;
+ std::uint64_t m_CurrentInsertOffset = 0;
+ std::atomic_uint32_t m_WriteBlockIndex{};
+ std::vector<uint32_t> m_ActiveWriteBlocks;
+
+ uint64_t m_MaxBlockSize = 1u << 28;
+ uint64_t m_MaxBlockCount = BlockStoreDiskLocation::MaxBlockIndex + 1;
+ std::filesystem::path m_BlocksBasePath;
+
+ std::atomic_uint64_t m_TotalSize{};
+};
+
+void blockstore_forcelink();
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/caslog.h b/src/zenstore/include/zenstore/caslog.h
new file mode 100644
index 000000000..d8c3f22f3
--- /dev/null
+++ b/src/zenstore/include/zenstore/caslog.h
@@ -0,0 +1,91 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/uid.h>
+#include <zenutil/basicfile.h>
+
+namespace zen {
+
+class CasLogFile
+{
+public:
+ CasLogFile();
+ ~CasLogFile();
+
+ enum class Mode
+ {
+ kRead,
+ kWrite,
+ kTruncate
+ };
+
+ static bool IsValid(std::filesystem::path FileName, size_t RecordSize);
+ void Open(std::filesystem::path FileName, size_t RecordSize, Mode Mode);
+ void Append(const void* DataPointer, uint64_t DataSize);
+ void Replay(std::function<void(const void*)>&& Handler, uint64_t SkipEntryCount);
+ void Flush();
+ void Close();
+ uint64_t GetLogSize();
+ uint64_t GetLogCount();
+
+private:
+ struct FileHeader
+ {
+ uint8_t Magic[16];
+ uint32_t RecordSize = 0;
+ Oid LogId;
+ uint32_t ValidatedTail = 0;
+ uint32_t Pad[6];
+ uint32_t Checksum = 0;
+
+ static const inline uint8_t MagicSequence[16] = {'.', '-', '=', ' ', 'C', 'A', 'S', 'L', 'O', 'G', 'v', '1', ' ', '=', '-', '.'};
+
+ ZENCORE_API uint32_t ComputeChecksum();
+ void Finalize() { Checksum = ComputeChecksum(); }
+ };
+
+ static_assert(sizeof(FileHeader) == 64);
+
+private:
+ void Open(std::filesystem::path FileName, size_t RecordSize, BasicFile::Mode Mode);
+
+ BasicFile m_File;
+ FileHeader m_Header;
+ size_t m_RecordSize = 1;
+ std::atomic<uint64_t> m_AppendOffset = 0;
+};
+
+template<typename T>
+class TCasLogFile : public CasLogFile
+{
+public:
+ static bool IsValid(std::filesystem::path FileName) { return CasLogFile::IsValid(FileName, sizeof(T)); }
+ void Open(std::filesystem::path FileName, Mode Mode) { CasLogFile::Open(FileName, sizeof(T), Mode); }
+
+ // This should be called before the Replay() is called to do some basic sanity checking
+ bool Initialize() { return true; }
+
+ void Replay(Invocable<const T&> auto Handler, uint64_t SkipEntryCount)
+ {
+ CasLogFile::Replay(
+ [&](const void* VoidPtr) {
+ const T& Record = *reinterpret_cast<const T*>(VoidPtr);
+
+ Handler(Record);
+ },
+ SkipEntryCount);
+ }
+
+ void Append(const T& Record)
+ {
+ // TODO: implement some more efficent path here so we don't end up with
+ // a syscall per append
+
+ CasLogFile::Append(&Record, sizeof Record);
+ }
+
+ void Append(const std::span<T>& Records) { CasLogFile::Append(Records.data(), sizeof(T) * Records.size()); }
+};
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/cidstore.h b/src/zenstore/include/zenstore/cidstore.h
new file mode 100644
index 000000000..16ca78225
--- /dev/null
+++ b/src/zenstore/include/zenstore/cidstore.h
@@ -0,0 +1,87 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenstore.h"
+
+#include <zencore/iohash.h>
+#include <zenstore/hashkeyset.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <tsl/robin_map.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <filesystem>
+
+namespace zen {
+
+class GcManager;
+class CasStore;
+class CompressedBuffer;
+class IoBuffer;
+class ScrubContext;
+
+/** Content Store
+ *
+ * Data in the content store is referenced by content identifiers (CIDs), it works
+ * with compressed buffers so the CID is expected to be the RAW hash. It stores the
+ * chunk directly under the RAW hash.
+ * This class maps uncompressed hashes (CIDs) to compressed hashes and may
+ * be used to deal with other kinds of indirections in the future. For example, if we want
+ * to support chunking then a CID may represent a list of chunks which could be concatenated
+ * to form the referenced chunk.
+ *
+ */
+
+struct CidStoreSize
+{
+ uint64_t TinySize = 0;
+ uint64_t SmallSize = 0;
+ uint64_t LargeSize = 0;
+ uint64_t TotalSize = 0;
+};
+
+struct CidStoreConfiguration
+{
+ // Root directory for CAS store
+ std::filesystem::path RootDirectory;
+
+ // Threshold below which values are considered 'tiny' and managed using the 'tiny values' strategy
+ uint64_t TinyValueThreshold = 1024;
+
+ // Threshold above which values are considered 'huge' and managed using the 'huge values' strategy
+ uint64_t HugeValueThreshold = 1024 * 1024;
+};
+
+class CidStore
+{
+public:
+ CidStore(GcManager& Gc);
+ ~CidStore();
+
+ struct InsertResult
+ {
+ bool New = false;
+ };
+ enum class InsertMode
+ {
+ kCopyOnly,
+ kMayBeMovedInPlace
+ };
+
+ void Initialize(const CidStoreConfiguration& Config);
+ InsertResult AddChunk(const IoBuffer& ChunkData, const IoHash& RawHash, InsertMode Mode = InsertMode::kMayBeMovedInPlace);
+ IoBuffer FindChunkByCid(const IoHash& DecompressedId);
+ bool ContainsChunk(const IoHash& DecompressedId);
+ void FilterChunks(HashKeySet& InOutChunks);
+ void Flush();
+ void Scrub(ScrubContext& Ctx);
+ CidStoreSize TotalSize() const;
+
+private:
+ struct Impl;
+ std::unique_ptr<CasStore> m_CasStore;
+ std::unique_ptr<Impl> m_Impl;
+};
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/gc.h b/src/zenstore/include/zenstore/gc.h
new file mode 100644
index 000000000..e0354b331
--- /dev/null
+++ b/src/zenstore/include/zenstore/gc.h
@@ -0,0 +1,242 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+#include <zencore/thread.h>
+#include <zenstore/caslog.h>
+
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <filesystem>
+#include <functional>
+#include <optional>
+#include <span>
+#include <thread>
+
+#define ZEN_USE_REF_TRACKING 0 // This is not currently functional
+
+namespace spdlog {
+class logger;
+}
+
+namespace zen {
+
+class HashKeySet;
+class GcManager;
+class CidStore;
+struct IoHash;
+
+/** GC clock
+ */
+class GcClock
+{
+public:
+ using Clock = std::chrono::system_clock;
+ using TimePoint = Clock::time_point;
+ using Duration = Clock::duration;
+ using Tick = int64_t;
+
+ static Tick TickCount() { return Now().time_since_epoch().count(); }
+ static TimePoint Now() { return Clock::now(); }
+ static TimePoint TimePointFromTick(const Tick TickCount) { return TimePoint{Duration{TickCount}}; }
+};
+
+/** Garbage Collection context object
+ */
+class GcContext
+{
+public:
+ GcContext(const GcClock::TimePoint& ExpireTime);
+ ~GcContext();
+
+ void AddRetainedCids(std::span<const IoHash> Cid);
+ void SetExpiredCacheKeys(const std::string& CacheKeyContext, std::vector<IoHash>&& ExpiredKeys);
+
+ void IterateCids(std::function<void(const IoHash&)> Callback);
+
+ void FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&)> KeepFunc);
+ void FilterCids(std::span<const IoHash> Cid, std::function<void(const IoHash&, bool)>&& FilterFunc);
+
+ void AddDeletedCids(std::span<const IoHash> Cas);
+ const HashKeySet& DeletedCids();
+
+ std::span<const IoHash> ExpiredCacheKeys(const std::string& CacheKeyContext) const;
+
+ bool IsDeletionMode() const;
+ void SetDeletionMode(bool NewState);
+
+ bool CollectSmallObjects() const;
+ void CollectSmallObjects(bool NewState);
+
+ GcClock::TimePoint ExpireTime() const;
+
+ void DiskReservePath(const std::filesystem::path& Path);
+ uint64_t ClaimGCReserve();
+
+private:
+ struct GcState;
+
+ std::unique_ptr<GcState> m_State;
+};
+
+/** GC root contributor
+
+ Higher level data structures provide roots for the garbage collector,
+ which ultimately determine what is garbage and what data we need to
+ retain.
+
+ */
+class GcContributor
+{
+public:
+ GcContributor(GcManager& Gc);
+ ~GcContributor();
+
+ virtual void GatherReferences(GcContext& GcCtx) = 0;
+
+protected:
+ GcManager& m_Gc;
+};
+
+struct GcStorageSize
+{
+ uint64_t DiskSize{};
+ uint64_t MemorySize{};
+};
+
+/** GC storage provider
+ */
+class GcStorage
+{
+public:
+ GcStorage(GcManager& Gc);
+ ~GcStorage();
+
+ virtual void CollectGarbage(GcContext& GcCtx) = 0;
+ virtual GcStorageSize StorageSize() const = 0;
+
+private:
+ GcManager& m_Gc;
+};
+
+/** GC orchestrator
+ */
+class GcManager
+{
+public:
+ GcManager();
+ ~GcManager();
+
+ void AddGcContributor(GcContributor* Contributor);
+ void RemoveGcContributor(GcContributor* Contributor);
+
+ void AddGcStorage(GcStorage* Contributor);
+ void RemoveGcStorage(GcStorage* Contributor);
+
+ void CollectGarbage(GcContext& GcCtx);
+
+ GcStorageSize TotalStorageSize() const;
+
+#if ZEN_USE_REF_TRACKING
+ void OnNewCidReferences(std::span<IoHash> Hashes);
+ void OnCommittedCidReferences(std::span<IoHash> Hashes);
+ void OnDroppedCidReferences(std::span<IoHash> Hashes);
+#endif
+
+private:
+ spdlog::logger& Log() { return m_Log; }
+ spdlog::logger& m_Log;
+ mutable RwLock m_Lock;
+ std::vector<GcContributor*> m_GcContribs;
+ std::vector<GcStorage*> m_GcStorage;
+ CidStore* m_CidStore = nullptr;
+};
+
+enum class GcSchedulerStatus : uint32_t
+{
+ kIdle,
+ kRunning,
+ kStopped
+};
+
+struct GcSchedulerConfig
+{
+ std::filesystem::path RootDirectory;
+ std::chrono::seconds MonitorInterval{30};
+ std::chrono::seconds Interval{};
+ std::chrono::seconds MaxCacheDuration{86400};
+ bool CollectSmallObjects = true;
+ bool Enabled = true;
+ uint64_t DiskReserveSize = 1ul << 28;
+ uint64_t DiskSizeSoftLimit = 0;
+};
+
+class DiskUsageWindow
+{
+public:
+ struct DiskUsageEntry
+ {
+ GcClock::Tick SampleTime;
+ uint64_t DiskUsage;
+ };
+
+ std::vector<DiskUsageEntry> m_LogWindow;
+ inline void Append(const DiskUsageEntry& Entry) { m_LogWindow.push_back(Entry); }
+ inline void Append(DiskUsageEntry&& Entry) { m_LogWindow.emplace_back(std::move(Entry)); }
+ void KeepRange(GcClock::Tick StartTick, GcClock::Tick EndTick);
+ std::vector<uint64_t> GetDiskDeltas(GcClock::Tick StartTick,
+ GcClock::Tick EndTick,
+ GcClock::Tick DeltaWidth,
+ uint64_t& OutMaxDelta) const;
+ GcClock::Tick FindTimepointThatRemoves(uint64_t Amount, GcClock::Tick EndTick) const;
+};
+
+/**
+ * GC scheduler
+ */
+class GcScheduler
+{
+public:
+ GcScheduler(GcManager& GcManager);
+ ~GcScheduler();
+
+ void Initialize(const GcSchedulerConfig& Config);
+ void Shutdown();
+ GcSchedulerStatus Status() const { return static_cast<GcSchedulerStatus>(m_Status.load()); }
+
+ struct TriggerParams
+ {
+ bool CollectSmallObjects = false;
+ std::chrono::seconds MaxCacheDuration = std::chrono::seconds::max();
+ uint64_t DiskSizeSoftLimit = 0;
+ };
+
+ bool Trigger(const TriggerParams& Params);
+
+private:
+ void SchedulerThread();
+ void CollectGarbage(const GcClock::TimePoint& ExpireTime, bool Delete, bool CollectSmallObjects);
+ GcClock::TimePoint NextGcTime(GcClock::TimePoint CurrentTime);
+ spdlog::logger& Log() { return m_Log; }
+
+ spdlog::logger& m_Log;
+ GcManager& m_GcManager;
+ GcSchedulerConfig m_Config;
+ GcClock::TimePoint m_LastGcTime{};
+ GcClock::TimePoint m_LastGcExpireTime{};
+ GcClock::TimePoint m_NextGcTime{};
+ std::atomic_uint32_t m_Status{};
+ std::thread m_GcThread;
+ std::mutex m_GcMutex;
+ std::condition_variable m_GcSignal;
+ std::optional<TriggerParams> m_TriggerParams;
+
+ TCasLogFile<DiskUsageWindow::DiskUsageEntry> m_DiskUsageLog;
+ DiskUsageWindow m_DiskUsageWindow;
+};
+
+void gc_forcelink();
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/hashkeyset.h b/src/zenstore/include/zenstore/hashkeyset.h
new file mode 100644
index 000000000..411a6256e
--- /dev/null
+++ b/src/zenstore/include/zenstore/hashkeyset.h
@@ -0,0 +1,54 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include "zenstore.h"
+
+#include <zencore/iohash.h>
+
+#include <functional>
+#include <unordered_set>
+
+namespace zen {
+
+/** Manage a set of IoHash values
+ */
+
+class HashKeySet
+{
+public:
+ void AddHashToSet(const IoHash& HashToAdd);
+ void AddHashesToSet(std::span<const IoHash> HashesToAdd);
+ void RemoveHashesIf(std::function<bool(const IoHash& CandidateHash)>&& Predicate);
+ void IterateHashes(std::function<void(const IoHash& Hash)>&& Callback) const;
+ [[nodiscard]] inline bool ContainsHash(const IoHash& Hash) const { return m_HashSet.find(Hash) != m_HashSet.end(); }
+ [[nodiscard]] inline bool IsEmpty() const { return m_HashSet.empty(); }
+ [[nodiscard]] inline size_t GetSize() const { return m_HashSet.size(); }
+
+ inline void FilterHashes(std::span<const IoHash> Candidates, Invocable<const IoHash&> auto MatchFunc) const
+ {
+ for (const IoHash& Candidate : Candidates)
+ {
+ if (ContainsHash(Candidate))
+ {
+ MatchFunc(Candidate);
+ }
+ }
+ }
+
+ inline void FilterHashes(std::span<const IoHash> Candidates, Invocable<const IoHash&, bool> auto MatchFunc) const
+ {
+ for (const IoHash& Candidate : Candidates)
+ {
+ MatchFunc(Candidate, ContainsHash(Candidate));
+ }
+ }
+
+private:
+ // Q: should we protect this with a lock, or is that a higher level concern?
+ std::unordered_set<IoHash, IoHash::Hasher> m_HashSet;
+};
+
+void hashkeyset_forcelink();
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/scrubcontext.h b/src/zenstore/include/zenstore/scrubcontext.h
new file mode 100644
index 000000000..0b884fcc6
--- /dev/null
+++ b/src/zenstore/include/zenstore/scrubcontext.h
@@ -0,0 +1,41 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/timer.h>
+#include <zenstore/hashkeyset.h>
+
+namespace zen {
+
+/** Context object for data scrubbing
+ *
+ * Data scrubbing is when we traverse stored data to validate it and
+ * optionally correct/recover
+ */
+
+class ScrubContext
+{
+public:
+ virtual void ReportBadCidChunks(std::span<IoHash> BadCasChunks) { m_BadCid.AddHashesToSet(BadCasChunks); }
+ inline uint64_t ScrubTimestamp() const { return m_ScrubTime; }
+ inline bool RunRecovery() const { return m_Recover; }
+ void ReportScrubbed(uint64_t ChunkCount, uint64_t ChunkBytes)
+ {
+ m_ChunkCount.fetch_add(ChunkCount);
+ m_ByteCount.fetch_add(ChunkBytes);
+ }
+
+ inline uint64_t ScrubbedChunks() const { return m_ChunkCount; }
+ inline uint64_t ScrubbedBytes() const { return m_ByteCount; }
+
+ const HashKeySet BadCids() const { return m_BadCid; }
+
+private:
+ uint64_t m_ScrubTime = GetHifreqTimerValue();
+ bool m_Recover = true;
+ std::atomic<uint64_t> m_ChunkCount{0};
+ std::atomic<uint64_t> m_ByteCount{0};
+ HashKeySet m_BadCid;
+};
+
+} // namespace zen
diff --git a/src/zenstore/include/zenstore/zenstore.h b/src/zenstore/include/zenstore/zenstore.h
new file mode 100644
index 000000000..46d62029d
--- /dev/null
+++ b/src/zenstore/include/zenstore/zenstore.h
@@ -0,0 +1,13 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/zencore.h>
+
+#define ZENSTORE_API
+
+namespace zen {
+
+ZENSTORE_API void zenstore_forcelinktests();
+
+}
diff --git a/src/zenstore/xmake.lua b/src/zenstore/xmake.lua
new file mode 100644
index 000000000..4469c5650
--- /dev/null
+++ b/src/zenstore/xmake.lua
@@ -0,0 +1,9 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zenstore')
+ set_kind("static")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_includedirs("include", {public=true})
+ add_deps("zencore", "zenutil")
+ add_packages("vcpkg::robin-map")
diff --git a/src/zenstore/zenstore.cpp b/src/zenstore/zenstore.cpp
new file mode 100644
index 000000000..d87652fde
--- /dev/null
+++ b/src/zenstore/zenstore.cpp
@@ -0,0 +1,32 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenstore/zenstore.h"
+
+#if ZEN_WITH_TESTS
+
+# include <zenstore/blockstore.h>
+# include <zenstore/gc.h>
+# include <zenstore/hashkeyset.h>
+# include <zenutil/basicfile.h>
+
+# include "cas.h"
+# include "compactcas.h"
+# include "filecas.h"
+
+namespace zen {
+
+void
+zenstore_forcelinktests()
+{
+ basicfile_forcelink();
+ CAS_forcelink();
+ filecas_forcelink();
+ blockstore_forcelink();
+ compactcas_forcelink();
+ gc_forcelink();
+ hashkeyset_forcelink();
+}
+
+} // namespace zen
+
+#endif
diff --git a/src/zentest-appstub/xmake.lua b/src/zentest-appstub/xmake.lua
new file mode 100644
index 000000000..d8e0283c1
--- /dev/null
+++ b/src/zentest-appstub/xmake.lua
@@ -0,0 +1,16 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target("zentest-appstub")
+ set_kind("binary")
+ add_headerfiles("**.h")
+ add_files("*.cpp")
+
+ if is_os("linux") then
+ add_syslinks("pthread")
+ end
+
+ if is_plat("macosx") then
+ add_ldflags("-framework CoreFoundation")
+ add_ldflags("-framework Security")
+ add_ldflags("-framework SystemConfiguration")
+ end
diff --git a/src/zentest-appstub/zentest-appstub.cpp b/src/zentest-appstub/zentest-appstub.cpp
new file mode 100644
index 000000000..66e6e03fd
--- /dev/null
+++ b/src/zentest-appstub/zentest-appstub.cpp
@@ -0,0 +1,34 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <stdio.h>
+#include <cstdlib>
+#include <cstring>
+#include <thread>
+
+using namespace std::chrono_literals;
+
+int
+main(int argc, char* argv[])
+{
+ int ExitCode = 0;
+
+ for (int i = 0; i < argc; ++i)
+ {
+ if (std::strncmp(argv[i], "-t=", 3) == 0)
+ {
+ const int SleepTime = std::atoi(argv[i] + 3);
+
+ printf("[zentest] sleeping for %ds...\n", SleepTime);
+
+ std::this_thread::sleep_for(SleepTime * 1s);
+ }
+ else if (std::strncmp(argv[i], "-f=", 3) == 0)
+ {
+ ExitCode = std::atoi(argv[i] + 3);
+ }
+ }
+
+ printf("[zentest] exiting with exit code: %d\n", ExitCode);
+
+ return ExitCode;
+}
diff --git a/src/zenutil/basicfile.cpp b/src/zenutil/basicfile.cpp
new file mode 100644
index 000000000..1e6043d7e
--- /dev/null
+++ b/src/zenutil/basicfile.cpp
@@ -0,0 +1,575 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenutil/basicfile.h"
+
+#include <zencore/compactbinary.h>
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/testing.h>
+#include <zencore/testutils.h>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+#else
+# include <fcntl.h>
+# include <sys/file.h>
+# include <sys/stat.h>
+# include <unistd.h>
+#endif
+
+#include <fmt/format.h>
+#include <gsl/gsl-lite.hpp>
+
+namespace zen {
+
+BasicFile::~BasicFile()
+{
+ Close();
+}
+
+void
+BasicFile::Open(const std::filesystem::path& FileName, Mode Mode)
+{
+ std::error_code Ec;
+ Open(FileName, Mode, Ec);
+
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("failed to open file '{}'", FileName));
+ }
+}
+
+void
+BasicFile::Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec)
+{
+ Ec.clear();
+
+#if ZEN_PLATFORM_WINDOWS
+ DWORD dwCreationDisposition = 0;
+ DWORD dwDesiredAccess = 0;
+ switch (Mode)
+ {
+ case Mode::kRead:
+ dwCreationDisposition |= OPEN_EXISTING;
+ dwDesiredAccess |= GENERIC_READ;
+ break;
+ case Mode::kWrite:
+ dwCreationDisposition |= OPEN_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE);
+ break;
+ case Mode::kDelete:
+ dwCreationDisposition |= OPEN_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE);
+ break;
+ case Mode::kTruncate:
+ dwCreationDisposition |= CREATE_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE);
+ break;
+ case Mode::kTruncateDelete:
+ dwCreationDisposition |= CREATE_ALWAYS;
+ dwDesiredAccess |= (GENERIC_READ | GENERIC_WRITE | DELETE);
+ break;
+ }
+
+ const DWORD dwShareMode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
+ const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL;
+ HANDLE hTemplateFile = nullptr;
+
+ HANDLE FileHandle = CreateFile(FileName.c_str(),
+ dwDesiredAccess,
+ dwShareMode,
+ /* lpSecurityAttributes */ nullptr,
+ dwCreationDisposition,
+ dwFlagsAndAttributes,
+ hTemplateFile);
+
+ if (FileHandle == INVALID_HANDLE_VALUE)
+ {
+ Ec = MakeErrorCodeFromLastError();
+
+ return;
+ }
+#else
+ int OpenFlags = O_CLOEXEC;
+ switch (Mode)
+ {
+ case Mode::kRead:
+ OpenFlags |= O_RDONLY;
+ break;
+ case Mode::kWrite:
+ case Mode::kDelete:
+ OpenFlags |= (O_RDWR | O_CREAT);
+ break;
+ case Mode::kTruncate:
+ case Mode::kTruncateDelete:
+ OpenFlags |= (O_RDWR | O_CREAT | O_TRUNC);
+ break;
+ }
+
+ int Fd = open(FileName.c_str(), OpenFlags, 0666);
+ if (Fd < 0)
+ {
+ Ec = MakeErrorCodeFromLastError();
+ return;
+ }
+ if (Mode != Mode::kRead)
+ {
+ fchmod(Fd, 0666);
+ }
+
+ void* FileHandle = (void*)(uintptr_t(Fd));
+#endif
+
+ m_FileHandle = FileHandle;
+}
+
+void
+BasicFile::Close()
+{
+ if (m_FileHandle)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ ::CloseHandle(m_FileHandle);
+#else
+ int Fd = int(uintptr_t(m_FileHandle));
+ close(Fd);
+#endif
+ m_FileHandle = nullptr;
+ }
+}
+
+void
+BasicFile::Read(void* Data, uint64_t BytesToRead, uint64_t FileOffset)
+{
+ const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024;
+
+ while (BytesToRead)
+ {
+ const uint64_t NumberOfBytesToRead = Min(BytesToRead, MaxChunkSize);
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(FileOffset >> 32);
+
+ DWORD dwNumberOfBytesRead = 0;
+ BOOL Success = ::ReadFile(m_FileHandle, Data, DWORD(NumberOfBytesToRead), &dwNumberOfBytesRead, &Ovl);
+
+ ZEN_ASSERT(dwNumberOfBytesRead == NumberOfBytesToRead);
+#else
+ static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files");
+ int Fd = int(uintptr_t(m_FileHandle));
+ int BytesRead = pread(Fd, Data, NumberOfBytesToRead, FileOffset);
+ bool Success = (BytesRead > 0);
+#endif
+
+ if (!Success)
+ {
+ ThrowLastError(fmt::format("Failed to read from file '{}'", zen::PathFromHandle(m_FileHandle)));
+ }
+
+ BytesToRead -= NumberOfBytesToRead;
+ FileOffset += NumberOfBytesToRead;
+ Data = reinterpret_cast<uint8_t*>(Data) + NumberOfBytesToRead;
+ }
+}
+
+IoBuffer
+BasicFile::ReadAll()
+{
+ IoBuffer Buffer(FileSize());
+ Read(Buffer.MutableData(), Buffer.Size(), 0);
+ return Buffer;
+}
+
+void
+BasicFile::StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun)
+{
+ StreamByteRange(0, FileSize(), std::move(ChunkFun));
+}
+
+void
+BasicFile::StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun)
+{
+ const uint64_t ChunkSize = 128 * 1024;
+ IoBuffer ReadBuffer{ChunkSize};
+ void* BufferPtr = ReadBuffer.MutableData();
+
+ uint64_t RemainBytes = Size;
+ uint64_t CurrentOffset = FileOffset;
+
+ while (RemainBytes)
+ {
+ const uint64_t ThisChunkBytes = zen::Min(ChunkSize, RemainBytes);
+
+ Read(BufferPtr, ThisChunkBytes, CurrentOffset);
+
+ ChunkFun(BufferPtr, ThisChunkBytes);
+
+ CurrentOffset += ThisChunkBytes;
+ RemainBytes -= ThisChunkBytes;
+ }
+}
+
+void
+BasicFile::Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec)
+{
+ Write(Data.GetData(), Data.GetSize(), FileOffset, Ec);
+}
+
+void
+BasicFile::Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec)
+{
+ Ec.clear();
+
+ const uint64_t MaxChunkSize = 2u * 1024 * 1024 * 1024;
+
+ while (Size)
+ {
+ const uint64_t NumberOfBytesToWrite = Min(Size, MaxChunkSize);
+
+#if ZEN_PLATFORM_WINDOWS
+ OVERLAPPED Ovl{};
+
+ Ovl.Offset = DWORD(FileOffset & 0xffff'ffffu);
+ Ovl.OffsetHigh = DWORD(FileOffset >> 32);
+
+ DWORD dwNumberOfBytesWritten = 0;
+
+ BOOL Success = ::WriteFile(m_FileHandle, Data, DWORD(NumberOfBytesToWrite), &dwNumberOfBytesWritten, &Ovl);
+#else
+ static_assert(sizeof(off_t) >= sizeof(uint64_t), "sizeof(off_t) does not support large files");
+ int Fd = int(uintptr_t(m_FileHandle));
+ int BytesWritten = pwrite(Fd, Data, NumberOfBytesToWrite, FileOffset);
+ bool Success = (BytesWritten > 0);
+#endif
+
+ if (!Success)
+ {
+ Ec = MakeErrorCodeFromLastError();
+
+ return;
+ }
+
+ Size -= NumberOfBytesToWrite;
+ FileOffset += NumberOfBytesToWrite;
+ Data = reinterpret_cast<const uint8_t*>(Data) + NumberOfBytesToWrite;
+ }
+}
+
+void
+BasicFile::Write(MemoryView Data, uint64_t FileOffset)
+{
+ Write(Data.GetData(), Data.GetSize(), FileOffset);
+}
+
+void
+BasicFile::Write(const void* Data, uint64_t Size, uint64_t Offset)
+{
+ std::error_code Ec;
+ Write(Data, Size, Offset, Ec);
+
+ if (Ec)
+ {
+ throw std::system_error(Ec, fmt::format("Failed to write to file '{}'", zen::PathFromHandle(m_FileHandle)));
+ }
+}
+
+void
+BasicFile::WriteAll(IoBuffer Data, std::error_code& Ec)
+{
+ Write(Data.Data(), Data.Size(), 0, Ec);
+}
+
+void
+BasicFile::Flush()
+{
+#if ZEN_PLATFORM_WINDOWS
+ FlushFileBuffers(m_FileHandle);
+#else
+ int Fd = int(uintptr_t(m_FileHandle));
+ fsync(Fd);
+#endif
+}
+
+uint64_t
+BasicFile::FileSize()
+{
+#if ZEN_PLATFORM_WINDOWS
+ ULARGE_INTEGER liFileSize;
+ liFileSize.LowPart = ::GetFileSize(m_FileHandle, &liFileSize.HighPart);
+ if (liFileSize.LowPart == INVALID_FILE_SIZE)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to get file size from file '{}'", PathFromHandle(m_FileHandle)));
+ }
+ }
+ return uint64_t(liFileSize.QuadPart);
+#else
+ int Fd = int(uintptr_t(m_FileHandle));
+ static_assert(sizeof(decltype(stat::st_size)) == sizeof(uint64_t), "fstat() doesn't support large files");
+ struct stat Stat;
+ fstat(Fd, &Stat);
+ return uint64_t(Stat.st_size);
+#endif
+}
+
+void
+BasicFile::SetFileSize(uint64_t FileSize)
+{
+#if ZEN_PLATFORM_WINDOWS
+ LARGE_INTEGER liFileSize;
+ liFileSize.QuadPart = FileSize;
+ BOOL OK = ::SetFilePointerEx(m_FileHandle, liFileSize, 0, FILE_BEGIN);
+ if (OK == FALSE)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set file pointer to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+ OK = ::SetEndOfFile(m_FileHandle);
+ if (OK == FALSE)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set end of file to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+#elif ZEN_PLATFORM_MAC
+ int Fd = int(intptr_t(m_FileHandle));
+ if (ftruncate(Fd, (off_t)FileSize) < 0)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+#else
+ int Fd = int(intptr_t(m_FileHandle));
+ if (ftruncate64(Fd, (off64_t)FileSize) < 0)
+ {
+ int Error = zen::GetLastError();
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to set truncate file to {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+ if (FileSize > 0)
+ {
+ int Error = posix_fallocate64(Fd, 0, (off64_t)FileSize);
+ if (Error)
+ {
+ ThrowSystemError(Error, fmt::format("Failed to allocate space of {} for file {}", FileSize, PathFromHandle(m_FileHandle)));
+ }
+ }
+#endif
+}
+
+void*
+BasicFile::Detach()
+{
+ void* FileHandle = m_FileHandle;
+ m_FileHandle = 0;
+ return FileHandle;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+TemporaryFile::~TemporaryFile()
+{
+ Close();
+}
+
+void
+TemporaryFile::Close()
+{
+ if (m_FileHandle)
+ {
+#if ZEN_PLATFORM_WINDOWS
+ // Mark file for deletion when final handle is closed
+
+ FILE_DISPOSITION_INFO Fdi{.DeleteFile = TRUE};
+
+ SetFileInformationByHandle(m_FileHandle, FileDispositionInfo, &Fdi, sizeof Fdi);
+#else
+ std::filesystem::path FilePath = zen::PathFromHandle(m_FileHandle);
+ unlink(FilePath.c_str());
+#endif
+
+ BasicFile::Close();
+ }
+}
+
+void
+TemporaryFile::CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec)
+{
+ StringBuilder<64> TempName;
+ Oid::NewOid().ToString(TempName);
+
+ m_TempPath = TempDirName / TempName.c_str();
+
+ Open(m_TempPath, BasicFile::Mode::kTruncateDelete, Ec);
+}
+
+void
+TemporaryFile::MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec)
+{
+ // We intentionally call the base class Close() since otherwise we'll end up
+ // deleting the temporary file
+ BasicFile::Close();
+
+ std::filesystem::rename(m_TempPath, FinalFileName, Ec);
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+LockFile::LockFile()
+{
+}
+
+LockFile::~LockFile()
+{
+#if ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Fd = int(intptr_t(m_FileHandle));
+ flock(Fd, LOCK_UN | LOCK_NB);
+#endif
+}
+
+void
+LockFile::Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec)
+{
+#if ZEN_PLATFORM_WINDOWS
+ Ec.clear();
+
+ const DWORD dwCreationDisposition = CREATE_ALWAYS;
+ DWORD dwDesiredAccess = GENERIC_READ | GENERIC_WRITE | DELETE;
+ const DWORD dwShareMode = FILE_SHARE_READ;
+ const DWORD dwFlagsAndAttributes = FILE_ATTRIBUTE_NORMAL | FILE_FLAG_DELETE_ON_CLOSE;
+ HANDLE hTemplateFile = nullptr;
+
+ HANDLE FileHandle = CreateFile(FileName.c_str(),
+ dwDesiredAccess,
+ dwShareMode,
+ /* lpSecurityAttributes */ nullptr,
+ dwCreationDisposition,
+ dwFlagsAndAttributes,
+ hTemplateFile);
+
+ if (FileHandle == INVALID_HANDLE_VALUE)
+ {
+ Ec = zen::MakeErrorCodeFromLastError();
+
+ return;
+ }
+#elif ZEN_PLATFORM_LINUX || ZEN_PLATFORM_MAC
+ int Fd = open(FileName.c_str(), O_RDWR | O_CREAT | O_CLOEXEC, 0666);
+ if (Fd < 0)
+ {
+ Ec = zen::MakeErrorCodeFromLastError();
+ return;
+ }
+ fchmod(Fd, 0666);
+
+ int LockRet = flock(Fd, LOCK_EX | LOCK_NB);
+ if (LockRet < 0)
+ {
+ Ec = zen::MakeErrorCodeFromLastError();
+ close(Fd);
+ return;
+ }
+
+ void* FileHandle = (void*)uintptr_t(Fd);
+#endif
+
+ m_FileHandle = FileHandle;
+
+ BasicFile::Write(Payload.GetBuffer(), 0, Ec);
+}
+
+void
+LockFile::Update(CbObject Payload, std::error_code& Ec)
+{
+ BasicFile::Write(Payload.GetBuffer(), 0, Ec);
+}
+
+/*
+ ___________ __
+ \__ ___/___ _______/ |_ ______
+ | |_/ __ \ / ___/\ __\/ ___/
+ | |\ ___/ \___ \ | | \___ \
+ |____| \___ >____ > |__| /____ >
+ \/ \/ \/
+*/
+
+#if ZEN_WITH_TESTS
+
+TEST_CASE("BasicFile")
+{
+ ScopedCurrentDirectoryChange _;
+
+ BasicFile File1;
+ CHECK_THROWS(File1.Open("zonk", BasicFile::Mode::kRead));
+ CHECK_NOTHROW(File1.Open("zonk", BasicFile::Mode::kTruncate));
+ CHECK_NOTHROW(File1.Write("abcd", 4, 0));
+ CHECK(File1.FileSize() == 4);
+ {
+ IoBuffer Data = File1.ReadAll();
+ CHECK(Data.Size() == 4);
+ CHECK_EQ(memcmp(Data.Data(), "abcd", 4), 0);
+ }
+ CHECK_NOTHROW(File1.Write("efgh", 4, 2));
+ CHECK(File1.FileSize() == 6);
+ {
+ IoBuffer Data = File1.ReadAll();
+ CHECK(Data.Size() == 6);
+ CHECK_EQ(memcmp(Data.Data(), "abefgh", 6), 0);
+ }
+}
+
+TEST_CASE("TemporaryFile")
+{
+ ScopedCurrentDirectoryChange _;
+
+ SUBCASE("DeleteOnClose")
+ {
+ TemporaryFile TmpFile;
+ std::error_code Ec;
+ TmpFile.CreateTemporary(std::filesystem::current_path(), Ec);
+ CHECK(!Ec);
+ CHECK(std::filesystem::exists(TmpFile.GetPath()));
+ TmpFile.Close();
+ CHECK(std::filesystem::exists(TmpFile.GetPath()) == false);
+ }
+
+ SUBCASE("MoveIntoPlace")
+ {
+ TemporaryFile TmpFile;
+ std::error_code Ec;
+ TmpFile.CreateTemporary(std::filesystem::current_path(), Ec);
+ CHECK(!Ec);
+ std::filesystem::path TempPath = TmpFile.GetPath();
+ std::filesystem::path FinalPath = std::filesystem::current_path() / "final";
+ CHECK(std::filesystem::exists(TempPath));
+ TmpFile.MoveTemporaryIntoPlace(FinalPath, Ec);
+ CHECK(!Ec);
+ CHECK(std::filesystem::exists(TempPath) == false);
+ CHECK(std::filesystem::exists(FinalPath));
+ }
+}
+
+void
+basicfile_forcelink()
+{
+}
+
+#endif
+
+} // namespace zen
diff --git a/src/zenutil/cache/cachekey.cpp b/src/zenutil/cache/cachekey.cpp
new file mode 100644
index 000000000..545b47f11
--- /dev/null
+++ b/src/zenutil/cache/cachekey.cpp
@@ -0,0 +1,9 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/cache/cachekey.h>
+
+namespace zen {
+
+const CacheKey CacheKey::Empty = CacheKey{.Bucket = std::string(), .Hash = IoHash()};
+
+} // namespace zen
diff --git a/src/zenutil/cache/cachepolicy.cpp b/src/zenutil/cache/cachepolicy.cpp
new file mode 100644
index 000000000..3bca363bb
--- /dev/null
+++ b/src/zenutil/cache/cachepolicy.cpp
@@ -0,0 +1,282 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/cache/cachepolicy.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/enumflags.h>
+#include <zencore/string.h>
+
+#include <algorithm>
+#include <unordered_map>
+
+namespace zen::Private {
+class CacheRecordPolicyShared;
+}
+
+namespace zen {
+
+using namespace std::literals;
+
+namespace DerivedData::Private {
+
+ constinit char CachePolicyDelimiter = ',';
+
+ struct CachePolicyToTextData
+ {
+ CachePolicy Policy;
+ std::string_view Text;
+ };
+
+ constinit CachePolicyToTextData CachePolicyToText[]{
+ // Flags with multiple bits are ordered by bit count to minimize token count in the text format.
+ {CachePolicy::Default, "Default"sv},
+ {CachePolicy::Remote, "Remote"sv},
+ {CachePolicy::Local, "Local"sv},
+ {CachePolicy::Store, "Store"sv},
+ {CachePolicy::Query, "Query"sv},
+ // Flags with only one bit can be in any order. Match the order in CachePolicy.
+ {CachePolicy::QueryLocal, "QueryLocal"sv},
+ {CachePolicy::QueryRemote, "QueryRemote"sv},
+ {CachePolicy::StoreLocal, "StoreLocal"sv},
+ {CachePolicy::StoreRemote, "StoreRemote"sv},
+ {CachePolicy::SkipMeta, "SkipMeta"sv},
+ {CachePolicy::SkipData, "SkipData"sv},
+ {CachePolicy::PartialRecord, "PartialRecord"sv},
+ {CachePolicy::KeepAlive, "KeepAlive"sv},
+ // None must be last because it matches every policy.
+ {CachePolicy::None, "None"sv},
+ };
+
+ constinit CachePolicy CachePolicyKnownFlags =
+ CachePolicy::Default | CachePolicy::SkipMeta | CachePolicy::SkipData | CachePolicy::PartialRecord | CachePolicy::KeepAlive;
+
+ StringBuilderBase& CachePolicyToString(StringBuilderBase& Builder, CachePolicy Policy)
+ {
+ // Mask out unknown flags. None will be written if no flags are known.
+ Policy &= CachePolicyKnownFlags;
+ for (const CachePolicyToTextData& Pair : CachePolicyToText)
+ {
+ if (EnumHasAllFlags(Policy, Pair.Policy))
+ {
+ EnumRemoveFlags(Policy, Pair.Policy);
+ Builder << Pair.Text << CachePolicyDelimiter;
+ if (Policy == CachePolicy::None)
+ {
+ break;
+ }
+ }
+ }
+ Builder.RemoveSuffix(1);
+ return Builder;
+ }
+
+ CachePolicy ParseCachePolicy(const std::string_view Text)
+ {
+ ZEN_ASSERT(!Text.empty()); // ParseCachePolicy requires a non-empty string
+ CachePolicy Policy = CachePolicy::None;
+ ForEachStrTok(Text, CachePolicyDelimiter, [&Policy, Index = int32_t(0)](const std::string_view& Token) mutable {
+ const int32_t EndIndex = Index;
+ for (; size_t(Index) < sizeof(CachePolicyToText) / sizeof(CachePolicyToText[0]); ++Index)
+ {
+ if (CachePolicyToText[Index].Text == Token)
+ {
+ Policy |= CachePolicyToText[Index].Policy;
+ ++Index;
+ return true;
+ }
+ }
+ for (Index = 0; Index < EndIndex; ++Index)
+ {
+ if (CachePolicyToText[Index].Text == Token)
+ {
+ Policy |= CachePolicyToText[Index].Policy;
+ ++Index;
+ return true;
+ }
+ }
+ return true;
+ });
+ return Policy;
+ }
+
+} // namespace DerivedData::Private
+
+StringBuilderBase&
+operator<<(StringBuilderBase& Builder, CachePolicy Policy)
+{
+ return DerivedData::Private::CachePolicyToString(Builder, Policy);
+}
+
+CachePolicy
+ParseCachePolicy(std::string_view Text)
+{
+ return DerivedData::Private::ParseCachePolicy(Text);
+}
+
+CachePolicy
+ConvertToUpstream(CachePolicy Policy)
+{
+ // Set Local flags equal to downstream's Remote flags.
+ // Delete Skip flags if StoreLocal is true, otherwise use the downstream value.
+ // Use the downstream value for all other flags.
+
+ CachePolicy UpstreamPolicy = CachePolicy::None;
+
+ if (EnumHasAllFlags(Policy, CachePolicy::QueryRemote))
+ {
+ UpstreamPolicy |= CachePolicy::QueryLocal;
+ }
+
+ if (EnumHasAllFlags(Policy, CachePolicy::StoreRemote))
+ {
+ UpstreamPolicy |= CachePolicy::StoreLocal;
+ }
+
+ if (!EnumHasAllFlags(Policy, CachePolicy::StoreLocal))
+ {
+ UpstreamPolicy |= (Policy & (CachePolicy::SkipData | CachePolicy::SkipMeta));
+ }
+
+ UpstreamPolicy |= Policy & ~(CachePolicy::Local | CachePolicy::SkipData | CachePolicy::SkipMeta);
+
+ return UpstreamPolicy;
+}
+
+class Private::CacheRecordPolicyShared final : public Private::ICacheRecordPolicyShared
+{
+public:
+ inline void AddValuePolicy(const CacheValuePolicy& Value) final
+ {
+ ZEN_ASSERT(Value.Id); // Failed to add value policy because the ID is null.
+ const auto Insert =
+ std::lower_bound(Values.begin(), Values.end(), Value, [](const CacheValuePolicy& Existing, const CacheValuePolicy& New) {
+ return Existing.Id < New.Id;
+ });
+ ZEN_ASSERT(
+ !(Insert < Values.end() &&
+ Insert->Id == Value.Id)); // Failed to add value policy with ID %s because it has an existing value policy with that ID. ")
+ Values.insert(Insert, Value);
+ }
+
+ inline std::span<const CacheValuePolicy> GetValuePolicies() const final { return Values; }
+
+private:
+ std::vector<CacheValuePolicy> Values;
+};
+
+CachePolicy
+CacheRecordPolicy::GetValuePolicy(const Oid& Id) const
+{
+ if (Shared)
+ {
+ const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies();
+ const auto Iter =
+ std::lower_bound(Values.begin(), Values.end(), Id, [](const CacheValuePolicy& A, const Oid& B) { return A.Id < B; });
+ if (Iter != Values.end() && Iter->Id == Id)
+ {
+ return Iter->Policy;
+ }
+ }
+ return DefaultValuePolicy;
+}
+
+void
+CacheRecordPolicy::Save(CbWriter& Writer) const
+{
+ Writer.BeginObject();
+ // The RecordPolicy is calculated from the ValuePolicies and does not need to be saved separately.
+ Writer.AddString("BasePolicy"sv, WriteToString<128>(GetBasePolicy()));
+ if (!IsUniform())
+ {
+ Writer.BeginArray("ValuePolicies"sv);
+ for (const CacheValuePolicy& Value : GetValuePolicies())
+ {
+ Writer.BeginObject();
+ Writer.AddObjectId("Id"sv, Value.Id);
+ Writer.AddString("Policy"sv, WriteToString<128>(Value.Policy));
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+}
+
+OptionalCacheRecordPolicy
+CacheRecordPolicy::Load(const CbObjectView Object)
+{
+ std::string_view BasePolicyText = Object["BasePolicy"sv].AsString();
+ if (BasePolicyText.empty())
+ {
+ return {};
+ }
+
+ CacheRecordPolicyBuilder Builder(ParseCachePolicy(BasePolicyText));
+ for (CbFieldView ValueField : Object["ValuePolicies"sv])
+ {
+ const CbObjectView Value = ValueField.AsObjectView();
+ const Oid Id = Value["Id"sv].AsObjectId();
+ const std::string_view PolicyText = Value["Policy"sv].AsString();
+ if (!Id || PolicyText.empty())
+ {
+ return {};
+ }
+ CachePolicy Policy = ParseCachePolicy(PolicyText);
+ if (EnumHasAnyFlags(Policy, ~CacheValuePolicy::PolicyMask))
+ {
+ return {};
+ }
+ Builder.AddValuePolicy(Id, Policy);
+ }
+
+ return Builder.Build();
+}
+
+CacheRecordPolicy
+CacheRecordPolicy::ConvertToUpstream() const
+{
+ CacheRecordPolicyBuilder Builder(zen::ConvertToUpstream(GetBasePolicy()));
+ for (const CacheValuePolicy& ValuePolicy : GetValuePolicies())
+ {
+ Builder.AddValuePolicy(ValuePolicy.Id, zen::ConvertToUpstream(ValuePolicy.Policy));
+ }
+ return Builder.Build();
+}
+
+void
+CacheRecordPolicyBuilder::AddValuePolicy(const CacheValuePolicy& Value)
+{
+ ZEN_ASSERT(!EnumHasAnyFlags(Value.Policy,
+ ~Value.PolicyMask)); // Value policy contains flags that only make sense on the record policy. Policy: %s
+ if (Value.Policy == (BasePolicy & Value.PolicyMask))
+ {
+ return;
+ }
+ if (!Shared)
+ {
+ Shared = new Private::CacheRecordPolicyShared;
+ }
+ Shared->AddValuePolicy(Value);
+}
+
+CacheRecordPolicy
+CacheRecordPolicyBuilder::Build()
+{
+ CacheRecordPolicy Policy(BasePolicy);
+ if (Shared)
+ {
+ const auto Add = [](const CachePolicy A, const CachePolicy B) {
+ return ((A | B) & ~CachePolicy::SkipData) | ((A & B) & CachePolicy::SkipData);
+ };
+ const std::span<const CacheValuePolicy> Values = Shared->GetValuePolicies();
+ Policy.RecordPolicy = BasePolicy;
+ for (const CacheValuePolicy& ValuePolicy : Values)
+ {
+ Policy.RecordPolicy = Add(Policy.RecordPolicy, ValuePolicy.Policy);
+ }
+ Policy.Shared = std::move(Shared);
+ }
+ return Policy;
+}
+
+} // namespace zen
diff --git a/src/zenutil/cache/cacherequests.cpp b/src/zenutil/cache/cacherequests.cpp
new file mode 100644
index 000000000..4c865ec22
--- /dev/null
+++ b/src/zenutil/cache/cacherequests.cpp
@@ -0,0 +1,1643 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/cache/cacherequests.h>
+
+#include <zencore/compactbinary.h>
+#include <zencore/compactbinarybuilder.h>
+#include <zencore/compactbinarypackage.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/zencore.h>
+
+#include <string>
+#include <string_view>
+
+#if ZEN_WITH_TESTS
+# include <zencore/testing.h>
+#endif
+
+namespace zen {
+
+namespace cacherequests {
+
+ namespace {
+ constinit AsciiSet ValidNamespaceNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789-_.ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+ constinit AsciiSet ValidBucketNameCharactersSet{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"};
+
+ std::optional<std::string> GetValidNamespaceName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Namespace is invalid, empty namespace is not allowed");
+ return {};
+ }
+
+ if (Name.length() > 64)
+ {
+ ZEN_WARN("Namespace '{}' is invalid, length exceeds 64 characters", Name);
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidNamespaceNameCharactersSet))
+ {
+ ZEN_WARN("Namespace '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<std::string> GetValidBucketName(std::string_view Name)
+ {
+ if (Name.empty())
+ {
+ ZEN_WARN("Bucket name is invalid, empty bucket name is not allowed");
+ return {};
+ }
+
+ if (!AsciiSet::HasOnly(Name, ValidBucketNameCharactersSet))
+ {
+ ZEN_WARN("Bucket name '{}' is invalid, invalid characters detected", Name);
+ return {};
+ }
+
+ return ToLower(Name);
+ }
+
+ std::optional<IoHash> GetValidIoHash(std::string_view Hash)
+ {
+ if (Hash.length() != IoHash::StringLength)
+ {
+ return {};
+ }
+
+ IoHash KeyHash;
+ if (!ParseHexBytes(Hash.data(), Hash.size(), KeyHash.Hash))
+ {
+ return {};
+ }
+ return KeyHash;
+ }
+
+ std::optional<CacheRecordPolicy> Convert(const OptionalCacheRecordPolicy& Policy)
+ {
+ return Policy.IsValid() ? Policy.Get() : std::optional<CacheRecordPolicy>{};
+ };
+ } // namespace
+
+ std::optional<std::string> GetRequestNamespace(const CbObjectView& Params)
+ {
+ CbFieldView NamespaceField = Params["Namespace"];
+ if (!NamespaceField)
+ {
+ return std::string("!default!"); // ZenCacheStore::DefaultNamespace);
+ }
+
+ if (NamespaceField.HasError())
+ {
+ return {};
+ }
+ if (!NamespaceField.IsString())
+ {
+ return {};
+ }
+ return GetValidNamespaceName(NamespaceField.AsString());
+ }
+
+ bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key)
+ {
+ CbFieldView BucketField = KeyView["Bucket"];
+ if (BucketField.HasError())
+ {
+ return false;
+ }
+ if (!BucketField.IsString())
+ {
+ return false;
+ }
+ std::optional<std::string> Bucket = GetValidBucketName(BucketField.AsString());
+ if (!Bucket.has_value())
+ {
+ return false;
+ }
+ CbFieldView HashField = KeyView["Hash"];
+ if (HashField.HasError())
+ {
+ return false;
+ }
+ if (!HashField.IsHash())
+ {
+ return false;
+ }
+ Key.Bucket = *Bucket;
+ Key.Hash = HashField.AsHash();
+ return true;
+ }
+
+ void WriteCacheRequestKey(CbObjectWriter& Writer, const CacheKey& Value)
+ {
+ Writer.BeginObject("Key");
+ {
+ Writer << "Bucket" << Value.Bucket;
+ Writer << "Hash" << Value.Hash;
+ }
+ Writer.EndObject();
+ }
+
+ void WriteOptionalCacheRequestPolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CacheRecordPolicy>& Policy)
+ {
+ if (Policy)
+ {
+ Writer.SetName(FieldName);
+ Policy->Save(Writer);
+ }
+ }
+
+ std::optional<CachePolicy> GetCachePolicy(CbObjectView ObjectView, std::string_view FieldName)
+ {
+ std::string_view DefaultPolicyText = ObjectView[FieldName].AsString();
+ if (DefaultPolicyText.empty())
+ {
+ return {};
+ }
+ return ParseCachePolicy(DefaultPolicyText);
+ }
+
+ void WriteCachePolicy(CbObjectWriter& Writer, std::string_view FieldName, const std::optional<CachePolicy>& Policy)
+ {
+ if (Policy)
+ {
+ Writer << FieldName << WriteToString<128>(*Policy);
+ }
+ }
+
+ bool PutCacheRecordsRequest::Parse(const CbPackage& Package)
+ {
+ CbObjectView BatchObject = Package.GetObject();
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheRecords");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestFieldArray = Params["Requests"].AsArrayView();
+ Requests.resize(RequestFieldArray.Num());
+ for (size_t RequestIndex = 0; CbFieldView RequestField : RequestFieldArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView RecordObject = RequestObject["Record"].AsObjectView();
+ CbObjectView KeyView = RecordObject["Key"].AsObjectView();
+
+ PutCacheRecordRequest& Request = Requests[RequestIndex++];
+
+ if (!GetRequestCacheKey(KeyView, Request.Key))
+ {
+ return false;
+ }
+
+ Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView()));
+
+ std::unordered_map<IoHash, size_t, IoHash::Hasher> RawHashToAttachmentIndex;
+
+ CbArrayView ValuesArray = RecordObject["Values"].AsArrayView();
+ Request.Values.resize(ValuesArray.Num());
+ RawHashToAttachmentIndex.reserve(ValuesArray.Num());
+ for (size_t Index = 0; CbFieldView Value : ValuesArray)
+ {
+ CbObjectView ObjectView = Value.AsObjectView();
+ IoHash AttachmentHash = ObjectView["RawHash"].AsHash();
+ RawHashToAttachmentIndex[AttachmentHash] = Index;
+ Request.Values[Index++] = {.Id = ObjectView["Id"].AsObjectId(), .RawHash = AttachmentHash};
+ }
+
+ RecordObject.IterateAttachments([&](CbFieldView HashView) {
+ const IoHash ValueHash = HashView.AsHash();
+ if (const CbAttachment* Attachment = Package.FindAttachment(ValueHash))
+ {
+ if (Attachment->IsCompressedBinary())
+ {
+ auto It = RawHashToAttachmentIndex.find(ValueHash);
+ ZEN_ASSERT(It != RawHashToAttachmentIndex.end());
+ PutCacheRecordRequestValue& Value = Request.Values[It->second];
+ ZEN_ASSERT(Value.RawHash == ValueHash);
+ Value.Body = Attachment->AsCompressedBinary();
+ ZEN_ASSERT_SLOW(Value.Body.DecodeRawHash() == Value.RawHash);
+ }
+ }
+ });
+ }
+
+ return true;
+ }
+
+ bool PutCacheRecordsRequest::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "PutCacheRecords";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("Requests");
+ for (const PutCacheRecordRequest& RecordRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ Writer.BeginObject("Record");
+ {
+ WriteCacheRequestKey(Writer, RecordRequest.Key);
+ Writer.BeginArray("Values");
+ for (const PutCacheRecordRequestValue& Value : RecordRequest.Values)
+ {
+ Writer.BeginObject();
+ {
+ Writer.AddObjectId("Id", Value.Id);
+ const CompressedBuffer& Buffer = Value.Body;
+ if (Buffer)
+ {
+ IoHash AttachmentHash = Buffer.DecodeRawHash(); // TODO: Slow!
+ Writer.AddBinaryAttachment("RawHash", AttachmentHash);
+ OutPackage.AddAttachment(CbAttachment(Buffer, AttachmentHash));
+ Writer.AddInteger("RawSize", Buffer.DecodeRawSize()); // TODO: Slow!
+ }
+ else
+ {
+ if (Value.RawHash == IoHash::Zero)
+ {
+ return false;
+ }
+ Writer.AddBinaryAttachment("RawHash", Value.RawHash);
+ }
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+ WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+ OutPackage.SetObject(Writer.Save());
+
+ return true;
+ }
+
+ bool PutCacheRecordsResult::Parse(const CbPackage& Package)
+ {
+ CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+ CbFieldViewIterator It = ResultsArray.CreateViewIterator();
+ while (It.HasValue())
+ {
+ Success.push_back(It.AsBool());
+ It++;
+ }
+ return true;
+ }
+
+ bool PutCacheRecordsResult::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result");
+ for (bool Value : Success)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ OutPackage.SetObject(ResponseObject.Save());
+ return true;
+ }
+
+ bool GetCacheRecordsRequest::Parse(const CbObjectView& RpcRequest)
+ {
+ ZEN_ASSERT(RpcRequest["Method"].AsString() == "GetCacheRecords");
+ AcceptMagic = RpcRequest["AcceptType"].AsUInt32(0);
+ AcceptOptions = RpcRequest["AcceptFlags"].AsUInt16(0);
+ ProcessPid = RpcRequest["Pid"].AsInt32(0);
+
+ CbObjectView Params = RpcRequest["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["Requests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ GetCacheRecordRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.Policy = Convert(CacheRecordPolicy::Load(RequestObject["Policy"].AsObjectView()));
+ }
+ return true;
+ }
+
+ bool GetCacheRecordsRequest::Parse(const CbPackage& RpcRequest) { return Parse(RpcRequest.GetObject()); }
+
+ bool GetCacheRecordsRequest::Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter) const
+ {
+ Writer << "Method"
+ << "GetCacheRecords";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+ if (AcceptOptions != 0)
+ {
+ Writer << "AcceptFlags" << AcceptOptions;
+ }
+ if (ProcessPid != 0)
+ {
+ Writer << "Pid" << ProcessPid;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+ Writer.BeginArray("Requests");
+ if (OptionalRecordFilter.empty())
+ {
+ for (const GetCacheRecordRequest& RecordRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, RecordRequest.Key);
+ WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ else
+ {
+ for (size_t Index : OptionalRecordFilter)
+ {
+ const GetCacheRecordRequest& RecordRequest = Requests[Index];
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, RecordRequest.Key);
+ WriteOptionalCacheRequestPolicy(Writer, "Policy", RecordRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ return true;
+ }
+
+ bool GetCacheRecordsRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter) const
+ {
+ CbObjectWriter Writer;
+ if (!Format(Writer, OptionalRecordFilter))
+ {
+ return false;
+ }
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool GetCacheRecordsResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes)
+ {
+ CbObject ResponseObject = Package.GetObject();
+ CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+
+ Results.reserve(ResultsArray.Num());
+ if (!OptionalRecordResultIndexes.empty() && ResultsArray.Num() != OptionalRecordResultIndexes.size())
+ {
+ return false;
+ }
+ for (size_t Index = 0; CbFieldView RecordView : ResultsArray)
+ {
+ size_t ResultIndex = OptionalRecordResultIndexes.empty() ? Index : OptionalRecordResultIndexes[Index];
+ Index++;
+
+ if (Results.size() <= ResultIndex)
+ {
+ Results.resize(ResultIndex + 1);
+ }
+ if (RecordView.IsNull())
+ {
+ continue;
+ }
+ Results[ResultIndex] = GetCacheRecordResult{};
+ GetCacheRecordResult& Request = Results[ResultIndex].value();
+ CbObjectView RecordObject = RecordView.AsObjectView();
+ CbObjectView KeyObject = RecordObject["Key"].AsObjectView();
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ CbArrayView ValuesArray = RecordObject["Values"].AsArrayView();
+ Request.Values.reserve(ValuesArray.Num());
+ for (CbFieldView Value : ValuesArray)
+ {
+ CbObjectView ValueObject = Value.AsObjectView();
+ IoHash RawHash = ValueObject["RawHash"].AsHash();
+ uint64_t RawSize = ValueObject["RawSize"].AsUInt64();
+ Oid Id = ValueObject["Id"].AsObjectId();
+ const CbAttachment* Attachment = Package.FindAttachment(RawHash);
+ if (!Attachment)
+ {
+ Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = {}});
+ continue;
+ }
+ if (!Attachment->IsCompressedBinary())
+ {
+ return false;
+ }
+ Request.Values.push_back({.Id = Id, .RawHash = RawHash, .RawSize = RawSize, .Body = Attachment->AsCompressedBinary()});
+ }
+ }
+ return true;
+ }
+
+ bool GetCacheRecordsResult::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+
+ Writer.BeginArray("Result");
+ for (const std::optional<GetCacheRecordResult>& RecordResult : Results)
+ {
+ if (!RecordResult.has_value())
+ {
+ Writer.AddNull();
+ continue;
+ }
+ Writer.BeginObject();
+ WriteCacheRequestKey(Writer, RecordResult->Key);
+
+ Writer.BeginArray("Values");
+ for (const GetCacheRecordResultValue& Value : RecordResult->Values)
+ {
+ IoHash AttachmentHash = Value.Body ? Value.Body.DecodeRawHash() : Value.RawHash;
+ Writer.BeginObject();
+ {
+ Writer.AddObjectId("Id", Value.Id);
+ Writer.AddHash("RawHash", AttachmentHash);
+ Writer.AddInteger("RawSize", Value.Body ? Value.Body.DecodeRawSize() : Value.RawSize);
+ }
+ Writer.EndObject();
+ if (Value.Body)
+ {
+ OutPackage.AddAttachment(CbAttachment(Value.Body, AttachmentHash));
+ }
+ }
+
+ Writer.EndArray();
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool PutCacheValuesRequest::Parse(const CbPackage& Package)
+ {
+ CbObjectView BatchObject = Package.GetObject();
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "PutCacheValues");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["Requests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ PutCacheValueRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.RawHash = RequestObject["RawHash"].AsBinaryAttachment();
+ Request.Policy = GetCachePolicy(RequestObject, "Policy");
+
+ if (const CbAttachment* Attachment = Package.FindAttachment(Request.RawHash))
+ {
+ if (!Attachment->IsCompressedBinary())
+ {
+ return false;
+ }
+ Request.Body = Attachment->AsCompressedBinary();
+ }
+ }
+ return true;
+ }
+
+ bool PutCacheValuesRequest::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "PutCacheValues";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("Requests");
+ for (const PutCacheValueRequest& ValueRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+ if (ValueRequest.Body)
+ {
+ IoHash AttachmentHash = ValueRequest.Body.DecodeRawHash();
+ if (ValueRequest.RawHash != IoHash::Zero && AttachmentHash != ValueRequest.RawHash)
+ {
+ return false;
+ }
+ Writer.AddBinaryAttachment("RawHash", AttachmentHash);
+ OutPackage.AddAttachment(CbAttachment(ValueRequest.Body, AttachmentHash));
+ }
+ else if (ValueRequest.RawHash != IoHash::Zero)
+ {
+ Writer.AddBinaryAttachment("RawHash", ValueRequest.RawHash);
+ }
+ else
+ {
+ return false;
+ }
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool PutCacheValuesResult::Parse(const CbPackage& Package)
+ {
+ CbArrayView ResultsArray = Package.GetObject()["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+ CbFieldViewIterator It = ResultsArray.CreateViewIterator();
+ while (It.HasValue())
+ {
+ Success.push_back(It.AsBool());
+ It++;
+ }
+ return true;
+ }
+
+ bool PutCacheValuesResult::Format(CbPackage& OutPackage) const
+ {
+ if (Success.empty())
+ {
+ return false;
+ }
+ CbObjectWriter ResponseObject;
+ ResponseObject.BeginArray("Result");
+ for (bool Value : Success)
+ {
+ ResponseObject.AddBool(Value);
+ }
+ ResponseObject.EndArray();
+
+ OutPackage.SetObject(ResponseObject.Save());
+ return true;
+ }
+
+ bool GetCacheValuesRequest::Parse(const CbObjectView& BatchObject)
+ {
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheValues");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+ AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0);
+ ProcessPid = BatchObject["Pid"].AsInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["Requests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ GetCacheValueRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.Policy = GetCachePolicy(RequestObject, "Policy");
+ }
+ return true;
+ }
+
+ bool GetCacheValuesRequest::Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "GetCacheValues";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+ if (AcceptOptions != 0)
+ {
+ Writer << "AcceptFlags" << AcceptOptions;
+ }
+ if (ProcessPid != 0)
+ {
+ Writer << "Pid" << ProcessPid;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("Requests");
+ if (OptionalValueFilter.empty())
+ {
+ for (const GetCacheValueRequest& ValueRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ else
+ {
+ for (size_t Index : OptionalValueFilter)
+ {
+ const GetCacheValueRequest& ValueRequest = Requests[Index];
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool CacheValuesResult::Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes)
+ {
+ CbObject ResponseObject = Package.GetObject();
+ CbArrayView ResultsArray = ResponseObject["Result"].AsArrayView();
+ if (!ResultsArray)
+ {
+ return false;
+ }
+ Results.reserve(ResultsArray.Num());
+ if (!OptionalValueResultIndexes.empty() && ResultsArray.Num() != OptionalValueResultIndexes.size())
+ {
+ return false;
+ }
+ for (size_t Index = 0; CbFieldView RecordView : ResultsArray)
+ {
+ size_t ResultIndex = OptionalValueResultIndexes.empty() ? Index : OptionalValueResultIndexes[Index];
+ Index++;
+
+ if (Results.size() <= ResultIndex)
+ {
+ Results.resize(ResultIndex + 1);
+ }
+ if (RecordView.IsNull())
+ {
+ continue;
+ }
+
+ CacheValueResult& ValueResult = Results[ResultIndex];
+ CbObjectView RecordObject = RecordView.AsObjectView();
+
+ CbFieldView RawHashField = RecordObject["RawHash"];
+ ValueResult.RawHash = RawHashField.AsHash();
+ bool Succeeded = !RawHashField.HasError();
+ if (Succeeded)
+ {
+ const CbAttachment* Attachment = Package.FindAttachment(ValueResult.RawHash);
+ ValueResult.Body = Attachment ? Attachment->AsCompressedBinary() : CompressedBuffer();
+ if (ValueResult.Body)
+ {
+ ValueResult.RawSize = ValueResult.Body.DecodeRawSize();
+ }
+ else
+ {
+ ValueResult.RawSize = RecordObject["RawSize"].AsUInt64(UINT64_MAX);
+ }
+ }
+ }
+ return true;
+ }
+
+ bool CacheValuesResult::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter ResponseObject;
+
+ ResponseObject.BeginArray("Result");
+ for (const CacheValueResult& ValueResult : Results)
+ {
+ ResponseObject.BeginObject();
+ if (ValueResult.RawHash != IoHash::Zero)
+ {
+ ResponseObject.AddHash("RawHash", ValueResult.RawHash);
+ if (ValueResult.Body)
+ {
+ OutPackage.AddAttachment(CbAttachment(ValueResult.Body, ValueResult.RawHash));
+ }
+ else
+ {
+ ResponseObject.AddInteger("RawSize", ValueResult.RawSize);
+ }
+ }
+ ResponseObject.EndObject();
+ }
+ ResponseObject.EndArray();
+
+ OutPackage.SetObject(ResponseObject.Save());
+ return true;
+ }
+
+ bool GetCacheChunksRequest::Parse(const CbObjectView& BatchObject)
+ {
+ ZEN_ASSERT(BatchObject["Method"].AsString() == "GetCacheChunks");
+ AcceptMagic = BatchObject["AcceptType"].AsUInt32(0);
+ AcceptOptions = BatchObject["AcceptFlags"].AsUInt16(0);
+ ProcessPid = BatchObject["Pid"].AsInt32(0);
+
+ CbObjectView Params = BatchObject["Params"].AsObjectView();
+ std::optional<std::string> RequestNamespace = cacherequests::GetRequestNamespace(Params);
+ if (!RequestNamespace)
+ {
+ return false;
+ }
+
+ Namespace = *RequestNamespace;
+ DefaultPolicy = GetCachePolicy(Params, "DefaultPolicy").value_or(CachePolicy::Default);
+
+ CbArrayView RequestsArray = Params["ChunkRequests"].AsArrayView();
+ Requests.reserve(RequestsArray.Num());
+ for (CbFieldView RequestField : RequestsArray)
+ {
+ CbObjectView RequestObject = RequestField.AsObjectView();
+ CbObjectView KeyObject = RequestObject["Key"].AsObjectView();
+
+ GetCacheChunkRequest& Request = Requests.emplace_back();
+
+ if (!GetRequestCacheKey(KeyObject, Request.Key))
+ {
+ return false;
+ }
+
+ Request.ValueId = RequestObject["ValueId"].AsObjectId();
+ Request.ChunkId = RequestObject["ChunkId"].AsHash();
+ Request.RawOffset = RequestObject["RawOffset"].AsUInt64();
+ Request.RawSize = RequestObject["RawSize"].AsUInt64(UINT64_MAX);
+
+ Request.Policy = GetCachePolicy(RequestObject, "Policy");
+ }
+ return true;
+ }
+
+ bool GetCacheChunksRequest::Format(CbPackage& OutPackage) const
+ {
+ CbObjectWriter Writer;
+ Writer << "Method"
+ << "GetCacheChunks";
+ if (AcceptMagic != 0)
+ {
+ Writer << "Accept" << AcceptMagic;
+ }
+ if (AcceptOptions != 0)
+ {
+ Writer << "AcceptFlags" << AcceptOptions;
+ }
+ if (ProcessPid != 0)
+ {
+ Writer << "Pid" << ProcessPid;
+ }
+
+ Writer.BeginObject("Params");
+ {
+ Writer << "DefaultPolicy" << WriteToString<128>(DefaultPolicy);
+ Writer << "Namespace" << Namespace;
+
+ Writer.BeginArray("ChunkRequests");
+ for (const GetCacheChunkRequest& ValueRequest : Requests)
+ {
+ Writer.BeginObject();
+ {
+ WriteCacheRequestKey(Writer, ValueRequest.Key);
+
+ Writer.AddObjectId("ValueId", ValueRequest.ValueId);
+ Writer.AddHash("ChunkId", ValueRequest.ChunkId);
+ Writer.AddInteger("RawOffset", ValueRequest.RawOffset);
+ Writer.AddInteger("RawSize", ValueRequest.RawSize);
+
+ WriteCachePolicy(Writer, "Policy", ValueRequest.Policy);
+ }
+ Writer.EndObject();
+ }
+ Writer.EndArray();
+ }
+ Writer.EndObject();
+
+ OutPackage.SetObject(Writer.Save());
+ return true;
+ }
+
+ bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data)
+ {
+ std::vector<std::string_view> Tokens;
+ uint32_t TokenCount = zen::ForEachStrTok(Key, '/', [&](const std::string_view& Token) {
+ Tokens.push_back(Token);
+ return true;
+ });
+
+ switch (TokenCount)
+ {
+ case 1:
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ return Data.Namespace.has_value();
+ case 2:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ case 3:
+ {
+ std::optional<IoHash> PossibleHashKey = GetValidIoHash(Tokens[1]);
+ if (PossibleHashKey.has_value())
+ {
+ // Legacy bucket/key/valueid request
+ Data.Bucket = GetValidBucketName(Tokens[0]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = PossibleHashKey;
+ Data.ValueContentId = GetValidIoHash(Tokens[2]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey)
+ {
+ return false;
+ }
+ return true;
+ }
+ case 4:
+ {
+ Data.Namespace = GetValidNamespaceName(Tokens[0]);
+ if (!Data.Namespace.has_value())
+ {
+ return false;
+ }
+
+ Data.Bucket = GetValidBucketName(Tokens[1]);
+ if (!Data.Bucket.has_value())
+ {
+ return false;
+ }
+
+ Data.HashKey = GetValidIoHash(Tokens[2]);
+ if (!Data.HashKey.has_value())
+ {
+ return false;
+ }
+
+ Data.ValueContentId = GetValidIoHash(Tokens[3]);
+ if (!Data.ValueContentId.has_value())
+ {
+ return false;
+ }
+ return true;
+ }
+ default:
+ return false;
+ }
+ }
+
+ // bool CacheRecord::Parse(CbObjectView& Reader)
+ // {
+ // CbObjectView KeyView = Reader["Key"].AsObjectView();
+ //
+ // if (!GetRequestCacheKey(KeyView, Key))
+ // {
+ // return false;
+ // }
+ // CbArrayView ValuesArray = Reader["Values"].AsArrayView();
+ // Values.reserve(ValuesArray.Num());
+ // for (CbFieldView Value : ValuesArray)
+ // {
+ // CbObjectView ObjectView = Value.AsObjectView();
+ // Values.push_back({.Id = ObjectView["Id"].AsObjectId(),
+ // .RawHash = ObjectView["RawHash"].AsHash(),
+ // .RawSize = ObjectView["RawSize"].AsUInt64()});
+ // }
+ // return true;
+ // }
+ //
+ // bool CacheRecord::Format(CbObjectWriter& Writer) const
+ // {
+ // WriteCacheRequestKey(Writer, Key);
+ // Writer.BeginArray("Values");
+ // for (const CacheRecordValue& Value : Values)
+ // {
+ // Writer.BeginObject();
+ // {
+ // Writer.AddObjectId("Id", Value.Id);
+ // Writer.AddHash("RawHash", Value.RawHash);
+ // Writer.AddInteger("RawSize", Value.RawSize);
+ // }
+ // Writer.EndObject();
+ // }
+ // Writer.EndArray();
+ // return true;
+ // }
+
+#if ZEN_WITH_TESTS
+
+ static bool operator==(const PutCacheRecordRequestValue& Lhs, const PutCacheRecordRequestValue& Rhs)
+ {
+ const IoHash LhsRawHash = Lhs.RawHash != IoHash::Zero ? Lhs.RawHash : Lhs.Body.DecodeRawHash();
+ const IoHash RhsRawHash = Rhs.RawHash != IoHash::Zero ? Rhs.RawHash : Rhs.Body.DecodeRawHash();
+ return Lhs.Id == Rhs.Id && LhsRawHash == RhsRawHash &&
+ Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView());
+ }
+
+ static bool operator==(const zen::CacheValuePolicy& Lhs, const zen::CacheValuePolicy& Rhs)
+ {
+ return (Lhs.Id == Rhs.Id) && (Lhs.Policy == Rhs.Policy);
+ }
+
+ static bool operator==(const std::span<const zen::CacheValuePolicy>& Lhs, const std::span<const zen::CacheValuePolicy>& Rhs)
+ {
+ if (Lhs.size() != Lhs.size())
+ {
+ return false;
+ }
+ for (size_t Idx = 0; Idx < Lhs.size(); ++Idx)
+ {
+ if (Lhs[Idx] != Rhs[Idx])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ static bool operator==(const zen::CacheRecordPolicy& Lhs, const zen::CacheRecordPolicy& Rhs)
+ {
+ return (Lhs.GetRecordPolicy() == Rhs.GetRecordPolicy()) && (Lhs.GetBasePolicy() == Rhs.GetBasePolicy()) &&
+ (Lhs.GetValuePolicies() == Rhs.GetValuePolicies());
+ }
+
+ static bool operator==(const std::optional<CacheRecordPolicy>& Lhs, const std::optional<CacheRecordPolicy>& Rhs)
+ {
+ return (Lhs.has_value() == Rhs.has_value()) && (!Lhs || (*Lhs == *Rhs));
+ }
+
+ static bool operator==(const PutCacheRecordRequest& Lhs, const PutCacheRecordRequest& Rhs)
+ {
+ return (Lhs.Key == Rhs.Key) && (Lhs.Values == Rhs.Values) && (Lhs.Policy == Rhs.Policy);
+ }
+
+ static bool operator==(const PutCacheRecordsRequest& Lhs, const PutCacheRecordsRequest& Rhs)
+ {
+ return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests);
+ }
+
+ static bool operator==(const PutCacheRecordsResult& Lhs, const PutCacheRecordsResult& Rhs) { return (Lhs.Success == Rhs.Success); }
+
+ static bool operator==(const GetCacheRecordRequest& Lhs, const GetCacheRecordRequest& Rhs)
+ {
+ return (Lhs.Key == Rhs.Key) && (Lhs.Policy == Rhs.Policy);
+ }
+
+ static bool operator==(const GetCacheRecordsRequest& Lhs, const GetCacheRecordsRequest& Rhs)
+ {
+ return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests);
+ }
+
+ static bool operator==(const GetCacheRecordResultValue& Lhs, const GetCacheRecordResultValue& Rhs)
+ {
+ if ((Lhs.Id != Rhs.Id) || (Lhs.RawHash != Rhs.RawHash) || (Lhs.RawSize != Rhs.RawSize))
+ {
+ return false;
+ }
+ if (bool(Lhs.Body) != bool(Rhs.Body))
+ {
+ return false;
+ }
+ if (bool(Lhs.Body) && Lhs.Body.DecodeRawHash() != Rhs.Body.DecodeRawHash())
+ {
+ return false;
+ }
+ return true;
+ }
+
+ static bool operator==(const GetCacheRecordResult& Lhs, const GetCacheRecordResult& Rhs)
+ {
+ return Lhs.Key == Rhs.Key && Lhs.Values == Rhs.Values;
+ }
+
+ static bool operator==(const std::optional<GetCacheRecordResult>& Lhs, const std::optional<GetCacheRecordResult>& Rhs)
+ {
+ if (Lhs.has_value() != Rhs.has_value())
+ {
+ return false;
+ }
+ return *Lhs == Rhs;
+ }
+
+ static bool operator==(const GetCacheRecordsResult& Lhs, const GetCacheRecordsResult& Rhs) { return Lhs.Results == Rhs.Results; }
+
+ static bool operator==(const PutCacheValueRequest& Lhs, const PutCacheValueRequest& Rhs)
+ {
+ if ((Lhs.Key != Rhs.Key) || (Lhs.RawHash != Rhs.RawHash))
+ {
+ return false;
+ }
+
+ if (bool(Lhs.Body) != bool(Rhs.Body))
+ {
+ return false;
+ }
+ if (!Lhs.Body)
+ {
+ return true;
+ }
+ return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView());
+ }
+
+ static bool operator==(const PutCacheValuesRequest& Lhs, const PutCacheValuesRequest& Rhs)
+ {
+ return (Lhs.DefaultPolicy == Rhs.DefaultPolicy) && (Lhs.Namespace == Rhs.Namespace) && (Lhs.Requests == Rhs.Requests);
+ }
+
+ static bool operator==(const PutCacheValuesResult& Lhs, const PutCacheValuesResult& Rhs) { return (Lhs.Success == Rhs.Success); }
+
+ static bool operator==(const GetCacheValueRequest& Lhs, const GetCacheValueRequest& Rhs)
+ {
+ return Lhs.Key == Rhs.Key && Lhs.Policy == Rhs.Policy;
+ }
+
+ static bool operator==(const GetCacheValuesRequest& Lhs, const GetCacheValuesRequest& Rhs)
+ {
+ return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests;
+ }
+
+ static bool operator==(const CacheValueResult& Lhs, const CacheValueResult& Rhs)
+ {
+ if (Lhs.RawHash != Rhs.RawHash)
+ {
+ return false;
+ };
+ if (Lhs.Body)
+ {
+ if (!Rhs.Body)
+ {
+ return false;
+ }
+ return Lhs.Body.GetCompressed().Flatten().GetView().EqualBytes(Rhs.Body.GetCompressed().Flatten().GetView());
+ }
+ return Lhs.RawSize == Rhs.RawSize;
+ }
+
+ static bool operator==(const CacheValuesResult& Lhs, const CacheValuesResult& Rhs) { return Lhs.Results == Rhs.Results; }
+
+ static bool operator==(const GetCacheChunkRequest& Lhs, const GetCacheChunkRequest& Rhs)
+ {
+ return Lhs.Key == Rhs.Key && Lhs.ValueId == Rhs.ValueId && Lhs.ChunkId == Rhs.ChunkId && Lhs.RawOffset == Rhs.RawOffset &&
+ Lhs.RawSize == Rhs.RawSize && Lhs.Policy == Rhs.Policy;
+ }
+
+ static bool operator==(const GetCacheChunksRequest& Lhs, const GetCacheChunksRequest& Rhs)
+ {
+ return Lhs.DefaultPolicy == Rhs.DefaultPolicy && Lhs.Namespace == Rhs.Namespace && Lhs.Requests == Rhs.Requests;
+ }
+
+ static CompressedBuffer MakeCompressedBuffer(size_t Size) { return CompressedBuffer::Compress(SharedBuffer(IoBuffer(Size))); };
+
+ TEST_CASE("cacherequests.put.cache.records")
+ {
+ PutCacheRecordsRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ PutCacheRecordsRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required
+
+ PutCacheRecordsRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::Remote,
+ .Namespace = "the_namespace",
+ .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}},
+ .Policy = CachePolicy::StoreLocal},
+ {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}},
+ .Policy = CachePolicy::Store},
+ {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ PutCacheRecordsRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage));
+ CHECK(FullRequest == FullRequestCopy);
+
+ PutCacheRecordsResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ PutCacheRecordsResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ PutCacheRecordsResult FullResult = {.Success = {true, false, true, true, false}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ PutCacheRecordsResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+ }
+
+ TEST_CASE("cacherequests.get.cache.records")
+ {
+ GetCacheRecordsRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ GetCacheRecordsRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required
+
+ GetCacheRecordsRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Policy = CachePolicy::Remote},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ GetCacheRecordsRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage));
+ CHECK(FullRequest == FullRequestCopy);
+
+ CbPackage PartialRequestPackage;
+ CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2}));
+ GetCacheRecordsRequest PartialRequest = FullRequest;
+ PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1);
+ GetCacheRecordsRequest PartialRequestCopy;
+ CHECK(PartialRequestCopy.Parse(PartialRequestPackage));
+ CHECK(PartialRequest == PartialRequestCopy);
+
+ GetCacheRecordsResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ GetCacheRecordsResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ PutCacheRecordsRequest FullPutRequest = {
+ .DefaultPolicy = CachePolicy::Remote,
+ .Namespace = "the_namespace",
+ .Requests = {{.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(2134)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(213)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(7)}},
+ .Policy = CachePolicy::StoreLocal},
+ {.Key = {.Bucket = "thebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1234)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(99)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(124)}},
+ .Policy = CachePolicy::Store},
+ {.Key = {.Bucket = "theotherbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .Values = {{.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(19)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(1248)},
+ {.Id = Oid::NewOid(), .Body = MakeCompressedBuffer(823)}}}}};
+
+ CbPackage FullPutRequestPackage;
+ CHECK(FullPutRequest.Format(FullPutRequestPackage));
+ PutCacheRecordsRequest FullPutRequestCopy;
+ CHECK(FullPutRequestCopy.Parse(FullPutRequestPackage));
+
+ GetCacheRecordsResult FullResult = {
+ {GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[0].Key,
+ .Values = {{.Id = FullPutRequestCopy.Requests[0].Values[0].Id,
+ .RawHash = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[0].Values[0].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[0].Values[0].Body},
+ {.Id = FullPutRequestCopy.Requests[0].Values[1].Id,
+
+ .RawHash = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[0].Values[1].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[0].Values[1].Body},
+ {.Id = FullPutRequestCopy.Requests[0].Values[2].Id,
+ .RawHash = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[0].Values[2].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[0].Values[2].Body}}},
+ {}, // Simulate not have!
+ GetCacheRecordResult{.Key = FullPutRequestCopy.Requests[2].Key,
+ .Values = {{.Id = FullPutRequestCopy.Requests[2].Values[0].Id,
+ .RawHash = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[2].Values[0].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[2].Values[0].Body},
+ {.Id = FullPutRequestCopy.Requests[2].Values[1].Id,
+ .RawHash = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[2].Values[1].Body.DecodeRawSize(),
+ .Body = {}}, // Simulate not have
+ {.Id = FullPutRequestCopy.Requests[2].Values[2].Id,
+ .RawHash = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawHash(),
+ .RawSize = FullPutRequestCopy.Requests[2].Values[2].Body.DecodeRawSize(),
+ .Body = FullPutRequestCopy.Requests[2].Values[2].Body}}}}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ GetCacheRecordsResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult.Results[0] == FullResultCopy.Results[0]);
+ CHECK(!FullResultCopy.Results[1]);
+ CHECK(FullResult.Results[2] == FullResultCopy.Results[2]);
+
+ GetCacheRecordsResult PartialResultCopy;
+ CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4}));
+ CHECK(FullResult.Results[0] == PartialResultCopy.Results[0]);
+ CHECK(!PartialResultCopy.Results[1]);
+ CHECK(!PartialResultCopy.Results[2]);
+ CHECK(!PartialResultCopy.Results[3]);
+ CHECK(FullResult.Results[2] == PartialResultCopy.Results[4]);
+ }
+
+ TEST_CASE("cacherequests.put.cache.values")
+ {
+ PutCacheValuesRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ PutCacheValuesRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage)); // Namespace is required
+
+ CompressedBuffer Buffers[3] = {MakeCompressedBuffer(969), MakeCompressedBuffer(3469), MakeCompressedBuffer(9)};
+ PutCacheValuesRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .RawHash = Buffers[0].DecodeRawHash(),
+ .Body = Buffers[0],
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .RawHash = Buffers[1].DecodeRawHash(),
+ .Body = Buffers[1],
+ .Policy = CachePolicy::Remote},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .RawHash = Buffers[2].DecodeRawHash()}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ PutCacheValuesRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage));
+ CHECK(FullRequest == FullRequestCopy);
+
+ PutCacheValuesResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(!EmptyResult.Format(EmptyResponsePackage));
+
+ PutCacheValuesResult FullResult = {.Success = {true, false, true}};
+
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ PutCacheValuesResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+ }
+
+ TEST_CASE("cacherequests.get.cache.values")
+ {
+ GetCacheValuesRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ GetCacheValuesRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required
+
+ GetCacheValuesRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .Policy = CachePolicy::Remote},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")}}}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ GetCacheValuesRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject()));
+ CHECK(FullRequest == FullRequestCopy);
+
+ CbPackage PartialRequestPackage;
+ CHECK(FullRequest.Format(PartialRequestPackage, std::initializer_list<size_t>{0, 2}));
+ GetCacheValuesRequest PartialRequest = FullRequest;
+ PartialRequest.Requests.erase(PartialRequest.Requests.begin() + 1);
+ GetCacheValuesRequest PartialRequestCopy;
+ CHECK(PartialRequestCopy.Parse(PartialRequestPackage.GetObject()));
+ CHECK(PartialRequest == PartialRequestCopy);
+
+ CacheValuesResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ CacheValuesResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)},
+ {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)},
+ {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}};
+ CacheValuesResult FullResult = {
+ .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]},
+ CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ CacheValuesResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+
+ CacheValuesResult PartialResultCopy;
+ CHECK(PartialResultCopy.Parse(FullResponsePackage, std::initializer_list<size_t>{0, 3, 4, 5, 6, 9}));
+ CHECK(PartialResultCopy.Results[0] == FullResult.Results[0]);
+ CHECK(PartialResultCopy.Results[1].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[2].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[3] == FullResult.Results[1]);
+ CHECK(PartialResultCopy.Results[4] == FullResult.Results[2]);
+ CHECK(PartialResultCopy.Results[5] == FullResult.Results[3]);
+ CHECK(PartialResultCopy.Results[6] == FullResult.Results[4]);
+ CHECK(PartialResultCopy.Results[7].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[8].RawHash == IoHash::Zero);
+ CHECK(PartialResultCopy.Results[9] == FullResult.Results[5]);
+ }
+
+ TEST_CASE("cacherequests.get.cache.chunks")
+ {
+ GetCacheChunksRequest EmptyRequest;
+ CbPackage EmptyRequestPackage;
+ CHECK(EmptyRequest.Format(EmptyRequestPackage));
+ GetCacheChunksRequest EmptyRequestCopy;
+ CHECK(!EmptyRequestCopy.Parse(EmptyRequestPackage.GetObject())); // Namespace is required
+
+ GetCacheChunksRequest FullRequest = {
+ .DefaultPolicy = CachePolicy::StoreLocal,
+ .Namespace = "other_namespace",
+ .Requests = {{.Key = {.Bucket = "finebucket", .Hash = IoHash::FromHexString("d1df59fcab06793a5f2c372d795bb907a15cab15")},
+ .ValueId = Oid::NewOid(),
+ .ChunkId = IoHash::FromHexString("ab3917854bfef7e7af2c372d795bb907a15cab15"),
+ .RawOffset = 77,
+ .RawSize = 33,
+ .Policy = CachePolicy::Local},
+ {.Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("177030568fdd461bf4fe5ddbf4d463e514e8178e")},
+ .ValueId = Oid::NewOid(),
+ .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"),
+ .Policy = CachePolicy::Remote},
+ {
+ .Key = {.Bucket = "badbucket", .Hash = IoHash::FromHexString("e1ce9e1ac8a6f5953dc14c1fa9512b804ed689df")},
+ .ChunkId = IoHash::FromHexString("372d795bb907a15cab15ab3917854bfef7e7af2c"),
+ }}};
+
+ CbPackage FullRequestPackage;
+ CHECK(FullRequest.Format(FullRequestPackage));
+ GetCacheChunksRequest FullRequestCopy;
+ CHECK(FullRequestCopy.Parse(FullRequestPackage.GetObject()));
+ CHECK(FullRequest == FullRequestCopy);
+
+ GetCacheChunksResult EmptyResult;
+ CbPackage EmptyResponsePackage;
+ CHECK(EmptyResult.Format(EmptyResponsePackage));
+ GetCacheChunksResult EmptyResultCopy;
+ CHECK(!EmptyResultCopy.Parse(EmptyResponsePackage));
+ CHECK(EmptyResult == EmptyResultCopy);
+
+ CompressedBuffer Buffers[3][3] = {{MakeCompressedBuffer(123), MakeCompressedBuffer(321), MakeCompressedBuffer(333)},
+ {MakeCompressedBuffer(6123), MakeCompressedBuffer(8321), MakeCompressedBuffer(7333)},
+ {MakeCompressedBuffer(5123), MakeCompressedBuffer(2321), MakeCompressedBuffer(2333)}};
+ GetCacheChunksResult FullResult = {
+ .Results = {CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][0].DecodeRawHash(), .Body = Buffers[0][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][1].DecodeRawHash(), .Body = Buffers[0][1]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[0][2].DecodeRawHash(), .Body = Buffers[0][2]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][0].DecodeRawHash(), .Body = Buffers[2][0]},
+ CacheValueResult{.RawSize = 0, .RawHash = Buffers[2][1].DecodeRawHash(), .Body = Buffers[2][1]},
+ CacheValueResult{.RawSize = Buffers[2][2].DecodeRawSize(), .RawHash = Buffers[2][2].DecodeRawHash()}}};
+ CbPackage FullResponsePackage;
+ CHECK(FullResult.Format(FullResponsePackage));
+ GetCacheChunksResult FullResultCopy;
+ CHECK(FullResultCopy.Parse(FullResponsePackage));
+ CHECK(FullResult == FullResultCopy);
+ }
+
+ TEST_CASE("z$service.parse.relative.Uri")
+ {
+ HttpRequestData LegacyBucketRequestBecomesNamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("test", LegacyBucketRequestBecomesNamespaceRequest));
+ CHECK(LegacyBucketRequestBecomesNamespaceRequest.Namespace == "test");
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.Bucket.has_value());
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.HashKey.has_value());
+ CHECK(!LegacyBucketRequestBecomesNamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData LegacyHashKeyRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", LegacyHashKeyRequest));
+ CHECK(!LegacyHashKeyRequest.Namespace);
+ CHECK(LegacyHashKeyRequest.Bucket == "test");
+ CHECK(LegacyHashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(!LegacyHashKeyRequest.ValueContentId.has_value());
+
+ HttpRequestData LegacyValueContentIdRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+ LegacyValueContentIdRequest));
+ CHECK(!LegacyValueContentIdRequest.Namespace);
+ CHECK(LegacyValueContentIdRequest.Bucket == "test");
+ CHECK(LegacyValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(LegacyValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"));
+
+ HttpRequestData V2DefaultNamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("ue4.ddc", V2DefaultNamespaceRequest));
+ CHECK(V2DefaultNamespaceRequest.Namespace == "ue4.ddc");
+ CHECK(!V2DefaultNamespaceRequest.Bucket.has_value());
+ CHECK(!V2DefaultNamespaceRequest.HashKey.has_value());
+ CHECK(!V2DefaultNamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData V2NamespaceRequest;
+ CHECK(HttpRequestParseRelativeUri("nicenamespace", V2NamespaceRequest));
+ CHECK(V2NamespaceRequest.Namespace == "nicenamespace");
+ CHECK(!V2NamespaceRequest.Bucket.has_value());
+ CHECK(!V2NamespaceRequest.HashKey.has_value());
+ CHECK(!V2NamespaceRequest.ValueContentId.has_value());
+
+ HttpRequestData V2BucketRequestWithDefaultNamespace;
+ CHECK(HttpRequestParseRelativeUri("ue4.ddc/test", V2BucketRequestWithDefaultNamespace));
+ CHECK(V2BucketRequestWithDefaultNamespace.Namespace == "ue4.ddc");
+ CHECK(V2BucketRequestWithDefaultNamespace.Bucket == "test");
+ CHECK(!V2BucketRequestWithDefaultNamespace.HashKey.has_value());
+ CHECK(!V2BucketRequestWithDefaultNamespace.ValueContentId.has_value());
+
+ HttpRequestData V2BucketRequestWithNamespace;
+ CHECK(HttpRequestParseRelativeUri("nicenamespace/test", V2BucketRequestWithNamespace));
+ CHECK(V2BucketRequestWithNamespace.Namespace == "nicenamespace");
+ CHECK(V2BucketRequestWithNamespace.Bucket == "test");
+ CHECK(!V2BucketRequestWithNamespace.HashKey.has_value());
+ CHECK(!V2BucketRequestWithNamespace.ValueContentId.has_value());
+
+ HttpRequestData V2HashKeyRequest;
+ CHECK(HttpRequestParseRelativeUri("test/0123456789abcdef12340123456789abcdef1234", V2HashKeyRequest));
+ CHECK(!V2HashKeyRequest.Namespace);
+ CHECK(V2HashKeyRequest.Bucket == "test");
+ CHECK(V2HashKeyRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(!V2HashKeyRequest.ValueContentId.has_value());
+
+ HttpRequestData V2ValueContentIdRequest;
+ CHECK(HttpRequestParseRelativeUri(
+ "nicenamespace/test/0123456789abcdef12340123456789abcdef1234/56789abcdef12345678956789abcdef123456789",
+ V2ValueContentIdRequest));
+ CHECK(V2ValueContentIdRequest.Namespace == "nicenamespace");
+ CHECK(V2ValueContentIdRequest.Bucket == "test");
+ CHECK(V2ValueContentIdRequest.HashKey == IoHash::FromHexString("0123456789abcdef12340123456789abcdef1234"));
+ CHECK(V2ValueContentIdRequest.ValueContentId == IoHash::FromHexString("56789abcdef12345678956789abcdef123456789"));
+
+ HttpRequestData Invalid;
+ CHECK(!HttpRequestParseRelativeUri("", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("/", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("bad\2_namespace", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("nice/\2\1bucket", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789a", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcdef1234", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/pppppppp89abcdef12340123456789abcdef1234", Invalid));
+ CHECK(!HttpRequestParseRelativeUri("namespace/bucket/0123456789abcdef12340123456789abcdef1234/56789abcd", Invalid));
+ CHECK(!HttpRequestParseRelativeUri(
+ "namespace/bucket/0123456789abcdef12340123456789abcdef1234/ppppppppdef12345678956789abcdef123456789",
+ Invalid));
+ }
+#endif
+} // namespace cacherequests
+
+void
+cacherequests_forcelink()
+{
+}
+
+} // namespace zen
diff --git a/src/zenutil/cache/rpcrecording.cpp b/src/zenutil/cache/rpcrecording.cpp
new file mode 100644
index 000000000..4958a27f6
--- /dev/null
+++ b/src/zenutil/cache/rpcrecording.cpp
@@ -0,0 +1,210 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenutil/basicfile.h>
+#include <zenutil/cache/rpcrecording.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <fmt/format.h>
+#include <gsl/gsl-lite.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+namespace zen::cache {
+struct RecordedRequest
+{
+ uint64_t Offset;
+ uint64_t Length;
+ ZenContentType ContentType;
+ ZenContentType AcceptType;
+};
+
+const uint64_t RecordedRequestBlockSize = 1ull << 31u;
+
+struct RecordedRequestsWriter
+{
+ void BeginWrite(const std::filesystem::path& BasePath)
+ {
+ m_BasePath = BasePath;
+ std::filesystem::create_directories(m_BasePath);
+ }
+
+ void EndWrite()
+ {
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_BlockFiles.clear();
+
+ IoBuffer IndexBuffer(IoBuffer::Wrap, m_Entries.data(), m_Entries.size() * sizeof(RecordedRequest));
+ BasicFile IndexFile;
+ IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kTruncate);
+ std::error_code Ec;
+ IndexFile.WriteAll(IndexBuffer, Ec);
+ IndexFile.Close();
+ m_Entries.clear();
+ }
+
+ uint64_t WriteRequest(ZenContentType ContentType, ZenContentType AcceptType, const IoBuffer& RequestBuffer)
+ {
+ RwLock::ExclusiveLockScope Lock(m_Lock);
+ uint64_t RequestIndex = m_Entries.size();
+ RecordedRequest& Entry = m_Entries.emplace_back(
+ RecordedRequest{.Offset = ~0ull, .Length = RequestBuffer.Size(), .ContentType = ContentType, .AcceptType = AcceptType});
+ if (Entry.Length < 1 * 1024 * 1024)
+ {
+ uint32_t BlockIndex = gsl::narrow<uint32_t>((m_ChunkOffset + Entry.Length) / RecordedRequestBlockSize);
+ if (BlockIndex == m_BlockFiles.size())
+ {
+ std::unique_ptr<BasicFile>& NewBlockFile = m_BlockFiles.emplace_back(std::make_unique<BasicFile>());
+ NewBlockFile->Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kTruncate);
+ m_ChunkOffset = BlockIndex * RecordedRequestBlockSize;
+ }
+ ZEN_ASSERT(BlockIndex < m_BlockFiles.size());
+ BasicFile* BlockFile = m_BlockFiles[BlockIndex].get();
+ ZEN_ASSERT(BlockFile != nullptr);
+
+ Entry.Offset = m_ChunkOffset;
+ m_ChunkOffset = RoundUp(m_ChunkOffset + Entry.Length, 1u << 4u);
+ Lock.ReleaseNow();
+
+ std::error_code Ec;
+ BlockFile->Write(RequestBuffer.Data(), RequestBuffer.Size(), Entry.Offset - BlockIndex * RecordedRequestBlockSize, Ec);
+ if (Ec)
+ {
+ Entry.Length = 0;
+ return ~0ull;
+ }
+ return RequestIndex;
+ }
+ Lock.ReleaseNow();
+
+ BasicFile RequestFile;
+ RequestFile.Open(m_BasePath / fmt::format("request{}.bin", RequestIndex), BasicFile::Mode::kTruncate);
+ std::error_code Ec;
+ RequestFile.WriteAll(RequestBuffer, Ec);
+ if (Ec)
+ {
+ Entry.Length = 0;
+ return ~0ull;
+ }
+ return RequestIndex;
+ }
+
+ std::filesystem::path m_BasePath;
+ mutable RwLock m_Lock;
+ std::vector<RecordedRequest> m_Entries;
+ std::vector<std::unique_ptr<BasicFile>> m_BlockFiles;
+ uint64_t m_ChunkOffset;
+};
+
+struct RecordedRequestsReader
+{
+ uint64_t BeginRead(const std::filesystem::path& BasePath, bool InMemory)
+ {
+ m_BasePath = BasePath;
+ BasicFile IndexFile;
+ IndexFile.Open(m_BasePath / "index.bin", BasicFile::Mode::kRead);
+ m_Entries.resize(IndexFile.FileSize() / sizeof(RecordedRequest));
+ IndexFile.Read(m_Entries.data(), IndexFile.FileSize(), 0);
+ uint64_t MaxChunkPosition = 0;
+ for (const RecordedRequest& R : m_Entries)
+ {
+ if (R.Offset != ~0ull)
+ {
+ MaxChunkPosition = Max(MaxChunkPosition, R.Offset + R.Length);
+ }
+ }
+ uint32_t BlockCount = gsl::narrow<uint32_t>(MaxChunkPosition / RecordedRequestBlockSize) + 1;
+ m_BlockFiles.resize(BlockCount);
+ for (uint32_t BlockIndex = 0; BlockIndex < BlockCount; ++BlockIndex)
+ {
+ if (InMemory)
+ {
+ BasicFile Chunk;
+ Chunk.Open(m_BasePath / fmt::format("chunks{}.bin", BlockIndex), BasicFile::Mode::kRead);
+ m_BlockFiles[BlockIndex] = Chunk.ReadAll();
+ continue;
+ }
+ m_BlockFiles[BlockIndex] = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("chunks{}.bin", BlockIndex));
+ }
+ return m_Entries.size();
+ }
+ void EndRead() { m_BlockFiles.clear(); }
+
+ std::pair<ZenContentType, ZenContentType> ReadRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) const
+ {
+ if (RequestIndex >= m_Entries.size())
+ {
+ return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType};
+ }
+ const RecordedRequest& Entry = m_Entries[RequestIndex];
+ if (Entry.Length == 0)
+ {
+ return {ZenContentType::kUnknownContentType, ZenContentType::kUnknownContentType};
+ }
+ if (Entry.Offset != ~0ull)
+ {
+ uint32_t BlockIndex = gsl::narrow<uint32_t>((Entry.Offset + Entry.Length) / RecordedRequestBlockSize);
+ uint64_t ChunkOffset = Entry.Offset - (BlockIndex * RecordedRequestBlockSize);
+ OutBuffer = IoBuffer(m_BlockFiles[BlockIndex], ChunkOffset, Entry.Length);
+ return {Entry.ContentType, Entry.AcceptType};
+ }
+ OutBuffer = IoBufferBuilder::MakeFromFile(m_BasePath / fmt::format("request{}.bin", RequestIndex));
+ return {Entry.ContentType, Entry.AcceptType};
+ }
+
+ std::filesystem::path m_BasePath;
+ std::vector<RecordedRequest> m_Entries;
+ std::vector<IoBuffer> m_BlockFiles;
+};
+
+class DiskRequestRecorder : public IRpcRequestRecorder
+{
+public:
+ DiskRequestRecorder(const std::filesystem::path& BasePath) { m_RecordedRequests.BeginWrite(BasePath); }
+ virtual ~DiskRequestRecorder() { m_RecordedRequests.EndWrite(); }
+
+private:
+ virtual uint64_t RecordRequest(const ZenContentType ContentType,
+ const ZenContentType AcceptType,
+ const IoBuffer& RequestBuffer) override
+ {
+ return m_RecordedRequests.WriteRequest(ContentType, AcceptType, RequestBuffer);
+ }
+ virtual void RecordResponse(uint64_t, const ZenContentType, const IoBuffer&) override {}
+ virtual void RecordResponse(uint64_t, const ZenContentType, const CompositeBuffer&) override {}
+ RecordedRequestsWriter m_RecordedRequests;
+};
+
+class DiskRequestReplayer : public IRpcRequestReplayer
+{
+public:
+ DiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory)
+ {
+ m_RequestCount = m_RequestBuffer.BeginRead(BasePath, InMemory);
+ }
+ virtual ~DiskRequestReplayer() { m_RequestBuffer.EndRead(); }
+
+private:
+ virtual uint64_t GetRequestCount() const override { return m_RequestCount; }
+
+ virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) override
+ {
+ return m_RequestBuffer.ReadRequest(RequestIndex, OutBuffer);
+ }
+ virtual ZenContentType GetResponse(uint64_t, IoBuffer&) override { return ZenContentType::kUnknownContentType; }
+
+ std::uint64_t m_RequestCount;
+ RecordedRequestsReader m_RequestBuffer;
+};
+
+std::unique_ptr<cache::IRpcRequestRecorder>
+MakeDiskRequestRecorder(const std::filesystem::path& BasePath)
+{
+ return std::make_unique<DiskRequestRecorder>(BasePath);
+}
+
+std::unique_ptr<cache::IRpcRequestReplayer>
+MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory)
+{
+ return std::make_unique<DiskRequestReplayer>(BasePath, InMemory);
+}
+
+} // namespace zen::cache
diff --git a/src/zenutil/include/zenutil/basicfile.h b/src/zenutil/include/zenutil/basicfile.h
new file mode 100644
index 000000000..877df0f92
--- /dev/null
+++ b/src/zenutil/include/zenutil/basicfile.h
@@ -0,0 +1,113 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iobuffer.h>
+
+#include <filesystem>
+#include <functional>
+
+namespace zen {
+
+class CbObject;
+
+/**
+ * Probably the most basic file abstraction in the universe
+ *
+ * One thing of note is that there is no notion of a "current file position"
+ * in this API -- all reads and writes are done from explicit offsets in
+ * the file. This avoids concurrency issues which can occur otherwise.
+ *
+ */
+
+class BasicFile
+{
+public:
+ BasicFile() = default;
+ ~BasicFile();
+
+ BasicFile(const BasicFile&) = delete;
+ BasicFile& operator=(const BasicFile&) = delete;
+
+ enum class Mode : uint32_t
+ {
+ kRead = 0, // Opens a existing file for read only
+ kWrite = 1, // Opens (or creates) a file for read and write
+ kTruncate = 2, // Opens (or creates) a file for read and write and sets the size to zero
+ kDelete = 3, // Opens (or creates) a file for read and write allowing .DeleteFile file disposition to be set
+ kTruncateDelete =
+ 4 // Opens (or creates) a file for read and write and sets the size to zero allowing .DeleteFile file disposition to be set
+ };
+
+ void Open(const std::filesystem::path& FileName, Mode Mode);
+ void Open(const std::filesystem::path& FileName, Mode Mode, std::error_code& Ec);
+ void Close();
+ void Read(void* Data, uint64_t Size, uint64_t FileOffset);
+ void StreamFile(std::function<void(const void* Data, uint64_t Size)>&& ChunkFun);
+ void StreamByteRange(uint64_t FileOffset, uint64_t Size, std::function<void(const void* Data, uint64_t Size)>&& ChunkFun);
+ void Write(MemoryView Data, uint64_t FileOffset);
+ void Write(MemoryView Data, uint64_t FileOffset, std::error_code& Ec);
+ void Write(const void* Data, uint64_t Size, uint64_t FileOffset);
+ void Write(const void* Data, uint64_t Size, uint64_t FileOffset, std::error_code& Ec);
+ void Flush();
+ uint64_t FileSize();
+ void SetFileSize(uint64_t FileSize);
+ IoBuffer ReadAll();
+ void WriteAll(IoBuffer Data, std::error_code& Ec);
+ void* Detach();
+
+ inline void* Handle() { return m_FileHandle; }
+
+protected:
+ void* m_FileHandle = nullptr; // This is either null or valid
+private:
+};
+
+/**
+ * Simple abstraction for a temporary file
+ *
+ * Works like a regular BasicFile but implements a simple mechanism to allow creating
+ * a temporary file for writing in a directory which may later be moved atomically
+ * into the intended location after it has been fully written to.
+ *
+ */
+
+class TemporaryFile : public BasicFile
+{
+public:
+ TemporaryFile() = default;
+ ~TemporaryFile();
+
+ TemporaryFile(const TemporaryFile&) = delete;
+ TemporaryFile& operator=(const TemporaryFile&) = delete;
+
+ void Close();
+ void CreateTemporary(std::filesystem::path TempDirName, std::error_code& Ec);
+ void MoveTemporaryIntoPlace(std::filesystem::path FinalFileName, std::error_code& Ec);
+ const std::filesystem::path& GetPath() const { return m_TempPath; }
+
+private:
+ std::filesystem::path m_TempPath;
+
+ using BasicFile::Open;
+};
+
+/** Lock file abstraction
+
+ */
+
+class LockFile : protected BasicFile
+{
+public:
+ LockFile();
+ ~LockFile();
+
+ void Create(std::filesystem::path FileName, CbObject Payload, std::error_code& Ec);
+ void Update(CbObject Payload, std::error_code& Ec);
+
+private:
+};
+
+ZENCORE_API void basicfile_forcelink();
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/cache.h b/src/zenutil/include/zenutil/cache/cache.h
new file mode 100644
index 000000000..1a1dd9386
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cache.h
@@ -0,0 +1,6 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zenutil/cache/cachekey.h>
+#include <zenutil/cache/cachepolicy.h>
diff --git a/src/zenutil/include/zenutil/cache/cachekey.h b/src/zenutil/include/zenutil/cache/cachekey.h
new file mode 100644
index 000000000..741375946
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cachekey.h
@@ -0,0 +1,86 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/iohash.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+
+#include <zenutil/cache/cachepolicy.h>
+
+namespace zen {
+
+struct CacheKey
+{
+ std::string Bucket;
+ IoHash Hash;
+
+ static CacheKey Create(std::string_view Bucket, const IoHash& Hash) { return {.Bucket = ToLower(Bucket), .Hash = Hash}; }
+
+ auto operator<=>(const CacheKey& that) const
+ {
+ if (auto b = caseSensitiveCompareStrings(Bucket, that.Bucket); b != std::strong_ordering::equal)
+ {
+ return b;
+ }
+ return Hash <=> that.Hash;
+ }
+
+ auto operator==(const CacheKey& that) const { return (*this <=> that) == std::strong_ordering::equal; }
+
+ static const CacheKey Empty;
+};
+
+struct CacheChunkRequest
+{
+ CacheKey Key;
+ IoHash ChunkId;
+ Oid ValueId;
+ uint64_t RawOffset = 0ull;
+ uint64_t RawSize = ~uint64_t(0);
+ CachePolicy Policy = CachePolicy::Default;
+};
+
+struct CacheKeyRequest
+{
+ CacheKey Key;
+ CacheRecordPolicy Policy;
+};
+
+struct CacheValueRequest
+{
+ CacheKey Key;
+ CachePolicy Policy = CachePolicy::Default;
+};
+
+inline bool
+operator<(const CacheChunkRequest& A, const CacheChunkRequest& B)
+{
+ if (A.Key < B.Key)
+ {
+ return true;
+ }
+ if (B.Key < A.Key)
+ {
+ return false;
+ }
+ if (A.ChunkId < B.ChunkId)
+ {
+ return true;
+ }
+ if (B.ChunkId < A.ChunkId)
+ {
+ return false;
+ }
+ if (A.ValueId < B.ValueId)
+ {
+ return true;
+ }
+ if (B.ValueId < A.ValueId)
+ {
+ return false;
+ }
+ return A.RawOffset < B.RawOffset;
+}
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/cachepolicy.h b/src/zenutil/include/zenutil/cache/cachepolicy.h
new file mode 100644
index 000000000..9a745e42c
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cachepolicy.h
@@ -0,0 +1,227 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compactbinary.h>
+#include <zencore/enumflags.h>
+#include <zencore/refcount.h>
+#include <zencore/string.h>
+#include <zencore/uid.h>
+
+#include <gsl/gsl-lite.hpp>
+#include <span>
+namespace zen::Private {
+class ICacheRecordPolicyShared;
+}
+namespace zen {
+
+class CbObjectView;
+class CbWriter;
+
+class OptionalCacheRecordPolicy;
+
+enum class CachePolicy : uint32_t
+{
+ /** A value with no flags. Disables access to the cache unless combined with other flags. */
+ None = 0,
+
+ /** Allow a cache request to query local caches. */
+ QueryLocal = 1 << 0,
+ /** Allow a cache request to query remote caches. */
+ QueryRemote = 1 << 1,
+ /** Allow a cache request to query any caches. */
+ Query = QueryLocal | QueryRemote,
+
+ /** Allow cache requests to query and store records and values in local caches. */
+ StoreLocal = 1 << 2,
+ /** Allow cache records and values to be stored in remote caches. */
+ StoreRemote = 1 << 3,
+ /** Allow cache records and values to be stored in any caches. */
+ Store = StoreLocal | StoreRemote,
+
+ /** Allow cache requests to query and store records and values in local caches. */
+ Local = QueryLocal | StoreLocal,
+ /** Allow cache requests to query and store records and values in remote caches. */
+ Remote = QueryRemote | StoreRemote,
+
+ /** Allow cache requests to query and store records and values in any caches. */
+ Default = Query | Store,
+
+ /** Skip fetching the data for values. */
+ SkipData = 1 << 4,
+
+ /** Skip fetching the metadata for record requests. */
+ SkipMeta = 1 << 5,
+
+ /**
+ * Partial output will be provided with the error status when a required value is missing.
+ *
+ * This is meant for cases when the missing values can be individually recovered, or rebuilt,
+ * without rebuilding the whole record. The cache automatically adds this flag when there are
+ * other cache stores that it may be able to recover missing values from.
+ *
+ * Missing values will be returned in the records, but with only the hash and size.
+ *
+ * Applying this flag for a put of a record allows a partial record to be stored.
+ */
+ PartialRecord = 1 << 6,
+
+ /**
+ * Keep records in the cache for at least the duration of the session.
+ *
+ * This is a hint that the record may be accessed again in this session. This is mainly meant
+ * to be used when subsequent accesses will not tolerate a cache miss.
+ */
+ KeepAlive = 1 << 7,
+};
+
+gsl_DEFINE_ENUM_BITMASK_OPERATORS(CachePolicy);
+/** Append a non-empty text version of the policy to the builder. */
+StringBuilderBase& operator<<(StringBuilderBase& Builder, CachePolicy Policy);
+/** Parse non-empty text written by operator<< into a policy. */
+CachePolicy ParseCachePolicy(std::string_view Text);
+/** Return input converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server. */
+CachePolicy ConvertToUpstream(CachePolicy Policy);
+
+inline CachePolicy
+Union(CachePolicy A, CachePolicy B)
+{
+ constexpr CachePolicy InvertedFlags = CachePolicy::SkipData | CachePolicy::SkipMeta;
+ return (A & ~(InvertedFlags)) | (B & ~(InvertedFlags)) | (A & B & InvertedFlags);
+}
+
+/** A value ID and the cache policy to use for that value. */
+struct CacheValuePolicy
+{
+ Oid Id;
+ CachePolicy Policy = CachePolicy::Default;
+
+ /** Flags that are valid on a value policy. */
+ static constexpr CachePolicy PolicyMask = CachePolicy::Default | CachePolicy::SkipData;
+};
+
+/** Interface for the private implementation of the cache record policy. */
+class Private::ICacheRecordPolicyShared : public RefCounted
+{
+public:
+ virtual ~ICacheRecordPolicyShared() = default;
+ virtual void AddValuePolicy(const CacheValuePolicy& Policy) = 0;
+ virtual std::span<const CacheValuePolicy> GetValuePolicies() const = 0;
+};
+
+/**
+ * Flags to control the behavior of cache record requests, with optional overrides by value.
+ *
+ * Examples:
+ * - A base policy of None with value policy overrides of Default will fetch those values if they
+ * exist in the record, and skip data for any other values.
+ * - A base policy of Default, with value policy overrides of (Query | SkipData), will skip those
+ * values, but still check if they exist, and will load any other values.
+ */
+class CacheRecordPolicy
+{
+public:
+ /** Construct a cache record policy that uses the default policy. */
+ CacheRecordPolicy() = default;
+
+ /** Construct a cache record policy with a uniform policy for the record and every value. */
+ inline CacheRecordPolicy(CachePolicy BasePolicy)
+ : RecordPolicy(BasePolicy)
+ , DefaultValuePolicy(BasePolicy & CacheValuePolicy::PolicyMask)
+ {
+ }
+
+ /** Returns true if the record and every value use the same cache policy. */
+ inline bool IsUniform() const { return !Shared; }
+
+ /** Returns the cache policy to use for the record. */
+ inline CachePolicy GetRecordPolicy() const { return RecordPolicy; }
+
+ /** Returns the base cache policy that this was constructed from. */
+ inline CachePolicy GetBasePolicy() const { return DefaultValuePolicy | (RecordPolicy & ~CacheValuePolicy::PolicyMask); }
+
+ /** Returns the cache policy to use for the value. */
+ CachePolicy GetValuePolicy(const Oid& Id) const;
+
+ /** Returns the array of cache policy overrides for values, sorted by ID. */
+ inline std::span<const CacheValuePolicy> GetValuePolicies() const
+ {
+ return Shared ? Shared->GetValuePolicies() : std::span<const CacheValuePolicy>();
+ }
+
+ /** Saves the cache record policy to a compact binary object. */
+ void Save(CbWriter& Writer) const;
+
+ /** Loads a cache record policy from an object. */
+ static OptionalCacheRecordPolicy Load(CbObjectView Object);
+
+ /** Return *this converted into the equivalent policy that the upstream should use when forwarding a put or get to an upstream server.
+ */
+ CacheRecordPolicy ConvertToUpstream() const;
+
+private:
+ friend class CacheRecordPolicyBuilder;
+ friend class OptionalCacheRecordPolicy;
+
+ CachePolicy RecordPolicy = CachePolicy::Default;
+ CachePolicy DefaultValuePolicy = CachePolicy::Default;
+ RefPtr<const Private::ICacheRecordPolicyShared> Shared;
+};
+
+/** A cache record policy builder is used to construct a cache record policy. */
+class CacheRecordPolicyBuilder
+{
+public:
+ /** Construct a policy builder that uses the default policy as its base policy. */
+ CacheRecordPolicyBuilder() = default;
+
+ /** Construct a policy builder that uses the provided policy for the record and values with no override. */
+ inline explicit CacheRecordPolicyBuilder(CachePolicy Policy) : BasePolicy(Policy) {}
+
+ /** Adds a cache policy override for a value. */
+ void AddValuePolicy(const CacheValuePolicy& Value);
+ inline void AddValuePolicy(const Oid& Id, CachePolicy Policy) { AddValuePolicy({Id, Policy}); }
+
+ /** Build a cache record policy, which makes this builder subsequently unusable. */
+ CacheRecordPolicy Build();
+
+private:
+ CachePolicy BasePolicy = CachePolicy::Default;
+ RefPtr<Private::ICacheRecordPolicyShared> Shared;
+};
+
+/**
+ * A cache record policy that can be null.
+ *
+ * @see CacheRecordPolicy
+ */
+class OptionalCacheRecordPolicy : private CacheRecordPolicy
+{
+public:
+ inline OptionalCacheRecordPolicy() : CacheRecordPolicy(~CachePolicy::None) {}
+
+ inline OptionalCacheRecordPolicy(CacheRecordPolicy&& InOutput) : CacheRecordPolicy(std::move(InOutput)) {}
+ inline OptionalCacheRecordPolicy(const CacheRecordPolicy& InOutput) : CacheRecordPolicy(InOutput) {}
+ inline OptionalCacheRecordPolicy& operator=(CacheRecordPolicy&& InOutput)
+ {
+ CacheRecordPolicy::operator=(std::move(InOutput));
+ return *this;
+ }
+ inline OptionalCacheRecordPolicy& operator=(const CacheRecordPolicy& InOutput)
+ {
+ CacheRecordPolicy::operator=(InOutput);
+ return *this;
+ }
+
+ /** Returns the cache record policy. The caller must check for null before using this accessor. */
+ inline const CacheRecordPolicy& Get() const& { return *this; }
+ inline CacheRecordPolicy Get() && { return std::move(*this); }
+
+ inline bool IsNull() const { return RecordPolicy == ~CachePolicy::None; }
+ inline bool IsValid() const { return !IsNull(); }
+ inline explicit operator bool() const { return !IsNull(); }
+
+ inline void Reset() { *this = OptionalCacheRecordPolicy(); }
+};
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/cacherequests.h b/src/zenutil/include/zenutil/cache/cacherequests.h
new file mode 100644
index 000000000..f1999ebfe
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/cacherequests.h
@@ -0,0 +1,279 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compress.h>
+
+#include "cachekey.h"
+#include "cachepolicy.h"
+
+#include <functional>
+
+namespace zen {
+
+class CbPackage;
+class CbObjectWriter;
+class CbObjectView;
+
+namespace cacherequests {
+ // I'd really like to get rid of std::optional<CacheRecordPolicy> (or really the class CacheRecordPolicy)
+ //
+ // CacheRecordPolicy has a record level policy but it can also contain policies for individual
+ // values inside the record.
+ //
+ // However, when we do a "PutCacheRecords" we already list the individual Values with their Id
+ // so we can just as well use an optional plain CachePolicy for each value.
+ //
+ // In "GetCacheRecords" we do not currently as for the individual values but you can add
+ // a policy on a per-value level in the std::optional<CacheRecordPolicy> Policy for each record.
+ //
+ // But as we already need to know the Ids of the values we want to set the policy for
+ // it would be simpler to add an array of requested values which each has an optional policy.
+ //
+ // We could add:
+ // struct GetCacheRecordValueRequest
+ // {
+ // Oid Id;
+ // std::optional<CachePolicy> Policy;
+ // };
+ //
+ // and change GetCacheRecordRequest to
+ // struct GetCacheRecordRequest
+ // {
+ // CacheKey Key = CacheKey::Empty;
+ // std::vector<GetCacheRecordValueRequest> ValueRequests;
+ // std::optional<CachePolicy> Policy;
+ // };
+ //
+ // This way we don't need the complex CacheRecordPolicy class and the request becomes
+ // more uniform and easier to understand.
+ //
+ // Would need to decide what the ValueRequests actually mean:
+ // Do they dictate which values to fetch or just a change of the policy?
+ // If they dictate the values to fetch you need to know all the value ids to set them
+ // and that is unlikely what we want - we want to be able to get a cache record with
+ // all its values without knowing all the Ids, right?
+ //
+
+ //////////////////////////////////////////////////////////////////////////
+ // Put 1..n structured cache records with optional attachments
+
+ struct PutCacheRecordRequestValue
+ {
+ Oid Id = Oid::Zero;
+ IoHash RawHash = IoHash::Zero; // If Body is not set, this must be set and the value must already exist in cache
+ CompressedBuffer Body = CompressedBuffer::Null;
+ };
+
+ struct PutCacheRecordRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::vector<PutCacheRecordRequestValue> Values;
+ std::optional<CacheRecordPolicy> Policy;
+ };
+
+ struct PutCacheRecordsRequest
+ {
+ uint32_t AcceptMagic = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<PutCacheRecordRequest> Requests;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ struct PutCacheRecordsResult
+ {
+ std::vector<bool> Success;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // Get 1..n structured cache records with optional attachments
+ // We can get requests for a cache record where we want care about a particular
+ // value id which we now of, but we don't know the ids of the other values and
+ // we still want them.
+ // Not sure if in that case we want different policies for the different attachemnts?
+
+ struct GetCacheRecordRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::optional<CacheRecordPolicy> Policy;
+ };
+
+ struct GetCacheRecordsRequest
+ {
+ uint32_t AcceptMagic = 0;
+ uint16_t AcceptOptions = 0;
+ int32_t ProcessPid = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<GetCacheRecordRequest> Requests;
+
+ bool Parse(const CbPackage& RpcRequest);
+ bool Parse(const CbObjectView& RpcRequest);
+ bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalRecordFilter = {}) const;
+ bool Format(CbObjectWriter& Writer, const std::span<const size_t> OptionalRecordFilter = {}) const;
+ };
+
+ struct GetCacheRecordResultValue
+ {
+ Oid Id = Oid::Zero;
+ IoHash RawHash = IoHash::Zero;
+ uint64_t RawSize = 0;
+ CompressedBuffer Body = CompressedBuffer::Null;
+ };
+
+ struct GetCacheRecordResult
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::vector<GetCacheRecordResultValue> Values;
+ };
+
+ struct GetCacheRecordsResult
+ {
+ std::vector<std::optional<GetCacheRecordResult>> Results;
+
+ bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalRecordResultIndexes = {});
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // Put 1..n unstructured cache objects
+
+ struct PutCacheValueRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ IoHash RawHash = IoHash::Zero;
+ CompressedBuffer Body = CompressedBuffer::Null; // If not set the value is expected to already exist in cache store
+ std::optional<CachePolicy> Policy;
+ };
+
+ struct PutCacheValuesRequest
+ {
+ uint32_t AcceptMagic = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<PutCacheValueRequest> Requests;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ struct PutCacheValuesResult
+ {
+ std::vector<bool> Success;
+
+ bool Parse(const CbPackage& Package);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // Get 1..n unstructured cache objects (stored data may be structured or unstructured)
+
+ struct GetCacheValueRequest
+ {
+ CacheKey Key = CacheKey::Empty;
+ std::optional<CachePolicy> Policy;
+ };
+
+ struct GetCacheValuesRequest
+ {
+ uint32_t AcceptMagic = 0;
+ uint16_t AcceptOptions = 0;
+ int32_t ProcessPid = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<GetCacheValueRequest> Requests;
+
+ bool Parse(const CbObjectView& BatchObject);
+ bool Format(CbPackage& OutPackage, const std::span<const size_t> OptionalValueFilter = {}) const;
+ };
+
+ struct CacheValueResult
+ {
+ uint64_t RawSize = 0;
+ IoHash RawHash = IoHash::Zero;
+ CompressedBuffer Body = CompressedBuffer::Null;
+ };
+
+ struct CacheValuesResult
+ {
+ std::vector<CacheValueResult> Results;
+
+ bool Parse(const CbPackage& Package, const std::span<const size_t> OptionalValueResultIndexes = {});
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ typedef CacheValuesResult GetCacheValuesResult;
+
+ //////////////////////////////////////////////////////////////////////////
+ // Get 1..n cache record values (attachments) for 1..n records
+
+ struct GetCacheChunkRequest
+ {
+ CacheKey Key;
+ Oid ValueId = Oid::Zero; // Set if ChunkId is not known at request time
+ IoHash ChunkId = IoHash::Zero;
+ uint64_t RawOffset = 0ull;
+ uint64_t RawSize = ~uint64_t(0);
+ std::optional<CachePolicy> Policy;
+ };
+
+ struct GetCacheChunksRequest
+ {
+ uint32_t AcceptMagic = 0;
+ uint16_t AcceptOptions = 0;
+ int32_t ProcessPid = 0;
+ CachePolicy DefaultPolicy = CachePolicy::Default;
+ std::string Namespace;
+ std::vector<GetCacheChunkRequest> Requests;
+
+ bool Parse(const CbObjectView& BatchObject);
+ bool Format(CbPackage& OutPackage) const;
+ };
+
+ typedef CacheValuesResult GetCacheChunksResult;
+
+ //////////////////////////////////////////////////////////////////////////
+
+ struct HttpRequestData
+ {
+ std::optional<std::string> Namespace;
+ std::optional<std::string> Bucket;
+ std::optional<IoHash> HashKey;
+ std::optional<IoHash> ValueContentId;
+ };
+
+ bool HttpRequestParseRelativeUri(std::string_view Key, HttpRequestData& Data);
+
+ // Temporarily public
+ std::optional<std::string> GetRequestNamespace(const CbObjectView& Params);
+ bool GetRequestCacheKey(const CbObjectView& KeyView, CacheKey& Key);
+
+ //////////////////////////////////////////////////////////////////////////
+
+ // struct CacheRecordValue
+ // {
+ // Oid Id = Oid::Zero;
+ // IoHash RawHash = IoHash::Zero;
+ // uint64_t RawSize = 0;
+ // };
+ //
+ // struct CacheRecord
+ // {
+ // CacheKey Key = CacheKey::Empty;
+ // std::vector<CacheRecordValue> Values;
+ //
+ // bool Parse(CbObjectView& Reader);
+ // bool Format(CbObjectWriter& Writer) const;
+ // };
+
+} // namespace cacherequests
+
+void cacherequests_forcelink(); // internal
+
+} // namespace zen
diff --git a/src/zenutil/include/zenutil/cache/rpcrecording.h b/src/zenutil/include/zenutil/cache/rpcrecording.h
new file mode 100644
index 000000000..6d65a532a
--- /dev/null
+++ b/src/zenutil/include/zenutil/cache/rpcrecording.h
@@ -0,0 +1,29 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/compositebuffer.h>
+#include <zencore/iobuffer.h>
+
+namespace zen::cache {
+class IRpcRequestRecorder
+{
+public:
+ virtual ~IRpcRequestRecorder() {}
+ virtual uint64_t RecordRequest(const ZenContentType ContentType, const ZenContentType AcceptType, const IoBuffer& RequestBuffer) = 0;
+ virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const IoBuffer& ResponseBuffer) = 0;
+ virtual void RecordResponse(uint64_t RequestIndex, const ZenContentType ContentType, const CompositeBuffer& ResponseBuffer) = 0;
+};
+class IRpcRequestReplayer
+{
+public:
+ virtual ~IRpcRequestReplayer() {}
+ virtual uint64_t GetRequestCount() const = 0;
+ virtual std::pair<ZenContentType, ZenContentType> GetRequest(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0;
+ virtual ZenContentType GetResponse(uint64_t RequestIndex, IoBuffer& OutBuffer) = 0;
+};
+
+std::unique_ptr<cache::IRpcRequestRecorder> MakeDiskRequestRecorder(const std::filesystem::path& BasePath);
+std::unique_ptr<cache::IRpcRequestReplayer> MakeDiskRequestReplayer(const std::filesystem::path& BasePath, bool InMemory);
+
+} // namespace zen::cache
diff --git a/src/zenutil/include/zenutil/zenserverprocess.h b/src/zenutil/include/zenutil/zenserverprocess.h
new file mode 100644
index 000000000..1c204c144
--- /dev/null
+++ b/src/zenutil/include/zenutil/zenserverprocess.h
@@ -0,0 +1,141 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#pragma once
+
+#include <zencore/enumflags.h>
+#include <zencore/logging.h>
+#include <zencore/thread.h>
+#include <zencore/uid.h>
+
+#include <atomic>
+#include <filesystem>
+#include <optional>
+
+namespace zen {
+
+class ZenServerEnvironment
+{
+public:
+ ZenServerEnvironment();
+ ~ZenServerEnvironment();
+
+ void Initialize(std::filesystem::path ProgramBaseDir);
+ void InitializeForTest(std::filesystem::path ProgramBaseDir, std::filesystem::path TestBaseDir, std::string_view ServerClass = "");
+
+ std::filesystem::path CreateNewTestDir();
+ std::filesystem::path ProgramBaseDir() const { return m_ProgramBaseDir; }
+ std::filesystem::path GetTestRootDir(std::string_view Path);
+ inline bool IsInitialized() const { return m_IsInitialized; }
+ inline bool IsTestEnvironment() const { return m_IsTestInstance; }
+ inline std::string_view GetServerClass() const { return m_ServerClass; }
+
+private:
+ std::filesystem::path m_ProgramBaseDir;
+ std::filesystem::path m_TestBaseDir;
+ bool m_IsInitialized = false;
+ bool m_IsTestInstance = false;
+ std::string m_ServerClass;
+};
+
+struct ZenServerInstance
+{
+ ZenServerInstance(ZenServerEnvironment& TestEnvironment);
+ ~ZenServerInstance();
+
+ void Shutdown();
+ void SignalShutdown();
+ void WaitUntilReady();
+ [[nodiscard]] bool WaitUntilReady(int Timeout);
+ void EnableTermination() { m_Terminate = true; }
+ void Detach();
+ inline int GetPid() { return m_Process.Pid(); }
+ inline void SetOwnerPid(int Pid) { m_OwnerPid = Pid; }
+
+ void SetTestDir(std::filesystem::path TestDir)
+ {
+ ZEN_ASSERT(!m_Process.IsValid());
+ m_TestDir = TestDir;
+ }
+
+ void SpawnServer(int BasePort = 0, std::string_view AdditionalServerArgs = std::string_view());
+
+ void AttachToRunningServer(int BasePort = 0);
+
+ std::string GetBaseUri() const;
+
+private:
+ ZenServerEnvironment& m_Env;
+ ProcessHandle m_Process;
+ NamedEvent m_ReadyEvent;
+ NamedEvent m_ShutdownEvent;
+ bool m_Terminate = false;
+ std::filesystem::path m_TestDir;
+ int m_BasePort = 0;
+ std::optional<int> m_OwnerPid;
+
+ void CreateShutdownEvent(int BasePort);
+};
+
+/** Shared system state
+ *
+ * Used as a scratchpad to identify running instances etc
+ *
+ * The state lives in a memory-mapped file backed by the swapfile
+ *
+ */
+
+class ZenServerState
+{
+public:
+ ZenServerState();
+ ~ZenServerState();
+
+ struct ZenServerEntry
+ {
+ // NOTE: any changes to this should consider backwards compatibility
+ // which means you should not rearrange members only potentially
+ // add something to the end or use a different mechanism for
+ // additional state. For example, you can use the session ID
+ // to introduce additional named objects
+ std::atomic<uint32_t> Pid;
+ std::atomic<uint16_t> DesiredListenPort;
+ std::atomic<uint16_t> Flags;
+ uint8_t SessionId[12];
+ std::atomic<uint32_t> SponsorPids[8];
+ std::atomic<uint16_t> EffectiveListenPort;
+ uint8_t Padding[10];
+
+ enum class FlagsEnum : uint16_t
+ {
+ kShutdownPlease = 1 << 0,
+ kIsReady = 1 << 1,
+ };
+
+ FRIEND_ENUM_CLASS_FLAGS(FlagsEnum);
+
+ Oid GetSessionId() const { return Oid::FromMemory(SessionId); }
+ void Reset();
+ void SignalShutdownRequest();
+ void SignalReady();
+ bool AddSponsorProcess(uint32_t Pid);
+ };
+
+ static_assert(sizeof(ZenServerEntry) == 64);
+
+ void Initialize();
+ [[nodiscard]] bool InitializeReadOnly();
+ [[nodiscard]] ZenServerEntry* Lookup(int DesiredListenPort);
+ ZenServerEntry* Register(int DesiredListenPort);
+ void Sweep();
+ void Snapshot(std::function<void(const ZenServerEntry&)>&& Callback);
+ inline bool IsReadOnly() const { return m_IsReadOnly; }
+
+private:
+ void* m_hMapFile = nullptr;
+ ZenServerEntry* m_Data = nullptr;
+ int m_MaxEntryCount = 65536 / sizeof(ZenServerEntry);
+ ZenServerEntry* m_OurEntry = nullptr;
+ bool m_IsReadOnly = true;
+};
+
+} // namespace zen
diff --git a/src/zenutil/xmake.lua b/src/zenutil/xmake.lua
new file mode 100644
index 000000000..e7d849bb2
--- /dev/null
+++ b/src/zenutil/xmake.lua
@@ -0,0 +1,9 @@
+-- Copyright Epic Games, Inc. All Rights Reserved.
+
+target('zenutil')
+ set_kind("static")
+ add_headerfiles("**.h")
+ add_files("**.cpp")
+ add_includedirs("include", {public=true})
+ add_deps("zencore")
+ add_packages("vcpkg::spdlog")
diff --git a/src/zenutil/zenserverprocess.cpp b/src/zenutil/zenserverprocess.cpp
new file mode 100644
index 000000000..5ecde343b
--- /dev/null
+++ b/src/zenutil/zenserverprocess.cpp
@@ -0,0 +1,677 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "zenutil/zenserverprocess.h"
+
+#include <zencore/except.h>
+#include <zencore/filesystem.h>
+#include <zencore/fmtutils.h>
+#include <zencore/logging.h>
+#include <zencore/session.h>
+#include <zencore/string.h>
+#include <zencore/thread.h>
+
+#include <atomic>
+
+#if ZEN_PLATFORM_WINDOWS
+# include <atlbase.h>
+# include <zencore/windows.h>
+#else
+# include <sys/mman.h>
+#endif
+
+//////////////////////////////////////////////////////////////////////////
+
+namespace zen {
+
+namespace zenutil {
+#if ZEN_PLATFORM_WINDOWS
+ class SecurityAttributes
+ {
+ public:
+ inline SECURITY_ATTRIBUTES* Attributes() { return &m_Attributes; }
+
+ protected:
+ SECURITY_ATTRIBUTES m_Attributes{};
+ SECURITY_DESCRIPTOR m_Sd{};
+ };
+
+ // Security attributes which allows any user access
+
+ class AnyUserSecurityAttributes : public SecurityAttributes
+ {
+ public:
+ AnyUserSecurityAttributes()
+ {
+ m_Attributes.nLength = sizeof m_Attributes;
+ m_Attributes.bInheritHandle = false; // Disable inheritance
+
+ const BOOL Success = InitializeSecurityDescriptor(&m_Sd, SECURITY_DESCRIPTOR_REVISION);
+
+ if (Success)
+ {
+ if (!SetSecurityDescriptorDacl(&m_Sd, TRUE, (PACL)NULL, FALSE))
+ {
+ ThrowLastError("SetSecurityDescriptorDacl failed");
+ }
+
+ m_Attributes.lpSecurityDescriptor = &m_Sd;
+ }
+ }
+ };
+#endif // ZEN_PLATFORM_WINDOWS
+
+} // namespace zenutil
+
+//////////////////////////////////////////////////////////////////////////
+
+ZenServerState::ZenServerState()
+{
+}
+
+ZenServerState::~ZenServerState()
+{
+ if (m_OurEntry)
+ {
+ // Clean up our entry now that we're leaving
+
+ m_OurEntry->Reset();
+ m_OurEntry = nullptr;
+ }
+
+#if ZEN_PLATFORM_WINDOWS
+ if (m_Data)
+ {
+ UnmapViewOfFile(m_Data);
+ }
+
+ if (m_hMapFile)
+ {
+ CloseHandle(m_hMapFile);
+ }
+#else
+ if (m_Data != nullptr)
+ {
+ munmap(m_Data, m_MaxEntryCount * sizeof(ZenServerEntry));
+ }
+
+ int Fd = int(intptr_t(m_hMapFile));
+ close(Fd);
+#endif
+
+ m_Data = nullptr;
+}
+
+void
+ZenServerState::Initialize()
+{
+ size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry);
+
+#if ZEN_PLATFORM_WINDOWS
+ // TODO: there's a small chance of a race here, this logic could be tightened up with a mutex to
+ // ensure only a single process at a time creates the mapping
+ // TODO: the fallback to Local instead of Global has a flaw where if you start a non-elevated instance
+ // first then start an elevated instance second you'll have the first instance with a local
+ // mapping and the second instance with a global mapping. This kind of elevated/non-elevated
+ // shouldn't be common, but handling for it should be improved in the future.
+
+ HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap");
+ if (hMap == NULL)
+ {
+ hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap");
+ }
+
+ if (hMap == NULL)
+ {
+ // Security attributes to enable any user to access state
+ zenutil::AnyUserSecurityAttributes Attrs;
+
+ hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file
+ Attrs.Attributes(), // allow anyone to access
+ PAGE_READWRITE, // read/write access
+ 0, // maximum object size (high-order DWORD)
+ DWORD(MapSize), // maximum object size (low-order DWORD)
+ L"Global\\ZenMap"); // name of mapping object
+
+ if (hMap == NULL)
+ {
+ hMap = CreateFileMapping(INVALID_HANDLE_VALUE, // use paging file
+ Attrs.Attributes(), // allow anyone to access
+ PAGE_READWRITE, // read/write access
+ 0, // maximum object size (high-order DWORD)
+ m_MaxEntryCount * sizeof(ZenServerEntry), // maximum object size (low-order DWORD)
+ L"Local\\ZenMap"); // name of mapping object
+ }
+
+ if (hMap == NULL)
+ {
+ ThrowLastError("Could not open or create file mapping object for Zen server state");
+ }
+ }
+
+ void* pBuf = MapViewOfFile(hMap, // handle to map object
+ FILE_MAP_ALL_ACCESS, // read/write permission
+ 0, // offset high
+ 0, // offset low
+ DWORD(MapSize));
+
+ if (pBuf == NULL)
+ {
+ ThrowLastError("Could not map view of Zen server state");
+ }
+#else
+ int Fd = shm_open("/UnrealEngineZen", O_RDWR | O_CREAT | O_CLOEXEC, 0666);
+ if (Fd < 0)
+ {
+ ThrowLastError("Could not open a shared memory object");
+ }
+ fchmod(Fd, 0666);
+ void* hMap = (void*)intptr_t(Fd);
+
+ int Result = ftruncate(Fd, MapSize);
+ ZEN_UNUSED(Result);
+
+ void* pBuf = mmap(nullptr, MapSize, PROT_READ | PROT_WRITE, MAP_SHARED, Fd, 0);
+ if (pBuf == MAP_FAILED)
+ {
+ ThrowLastError("Could not map view of Zen server state");
+ }
+#endif
+
+ m_hMapFile = hMap;
+ m_Data = reinterpret_cast<ZenServerEntry*>(pBuf);
+ m_IsReadOnly = false;
+}
+
+bool
+ZenServerState::InitializeReadOnly()
+{
+ size_t MapSize = m_MaxEntryCount * sizeof(ZenServerEntry);
+
+#if ZEN_PLATFORM_WINDOWS
+ HANDLE hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Global\\ZenMap");
+ if (hMap == NULL)
+ {
+ hMap = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, L"Local\\ZenMap");
+ }
+
+ if (hMap == NULL)
+ {
+ return false;
+ }
+
+ void* pBuf = MapViewOfFile(hMap, // handle to map object
+ FILE_MAP_READ, // read permission
+ 0, // offset high
+ 0, // offset low
+ MapSize);
+
+ if (pBuf == NULL)
+ {
+ ThrowLastError("Could not map view of Zen server state");
+ }
+#else
+ int Fd = shm_open("/UnrealEngineZen", O_RDONLY | O_CLOEXEC, 0666);
+ if (Fd < 0)
+ {
+ return false;
+ }
+ void* hMap = (void*)intptr_t(Fd);
+
+ void* pBuf = mmap(nullptr, MapSize, PROT_READ, MAP_PRIVATE, Fd, 0);
+ if (pBuf == MAP_FAILED)
+ {
+ ThrowLastError("Could not map read-only view of Zen server state");
+ }
+#endif
+
+ m_hMapFile = hMap;
+ m_Data = reinterpret_cast<ZenServerEntry*>(pBuf);
+
+ return true;
+}
+
+ZenServerState::ZenServerEntry*
+ZenServerState::Lookup(int DesiredListenPort)
+{
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ if (m_Data[i].DesiredListenPort == DesiredListenPort)
+ {
+ return &m_Data[i];
+ }
+ }
+
+ return nullptr;
+}
+
+ZenServerState::ZenServerEntry*
+ZenServerState::Register(int DesiredListenPort)
+{
+ if (m_Data == nullptr)
+ {
+ return nullptr;
+ }
+
+ // Allocate an entry
+
+ int Pid = GetCurrentProcessId();
+
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ ZenServerEntry& Entry = m_Data[i];
+
+ if (Entry.DesiredListenPort.load(std::memory_order_relaxed) == 0)
+ {
+ uint16_t Expected = 0;
+ if (Entry.DesiredListenPort.compare_exchange_strong(Expected, uint16_t(DesiredListenPort)))
+ {
+ // Successfully allocated entry
+
+ m_OurEntry = &Entry;
+
+ Entry.Pid = Pid;
+ Entry.EffectiveListenPort = 0;
+ Entry.Flags = 0;
+
+ const Oid SesId = GetSessionId();
+ memcpy(Entry.SessionId, &SesId, sizeof SesId);
+
+ return &Entry;
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+void
+ZenServerState::Sweep()
+{
+ if (m_Data == nullptr)
+ {
+ return;
+ }
+
+ ZEN_ASSERT(m_IsReadOnly == false);
+
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ ZenServerEntry& Entry = m_Data[i];
+
+ if (Entry.DesiredListenPort)
+ {
+ if (IsProcessRunning(Entry.Pid) == false)
+ {
+ ZEN_DEBUG("Sweep - pid {} not running, reclaiming entry (port {})", Entry.Pid, Entry.DesiredListenPort);
+
+ Entry.Reset();
+ }
+ }
+ }
+}
+
+void
+ZenServerState::Snapshot(std::function<void(const ZenServerEntry&)>&& Callback)
+{
+ if (m_Data == nullptr)
+ {
+ return;
+ }
+
+ for (int i = 0; i < m_MaxEntryCount; ++i)
+ {
+ ZenServerEntry& Entry = m_Data[i];
+
+ if (Entry.DesiredListenPort)
+ {
+ Callback(Entry);
+ }
+ }
+}
+
+void
+ZenServerState::ZenServerEntry::Reset()
+{
+ Pid = 0;
+ DesiredListenPort = 0;
+ Flags = 0;
+ EffectiveListenPort = 0;
+}
+
+void
+ZenServerState::ZenServerEntry::SignalShutdownRequest()
+{
+ Flags |= uint16_t(FlagsEnum::kShutdownPlease);
+}
+
+void
+ZenServerState::ZenServerEntry::SignalReady()
+{
+ Flags |= uint16_t(FlagsEnum::kIsReady);
+}
+
+bool
+ZenServerState::ZenServerEntry::AddSponsorProcess(uint32_t PidToAdd)
+{
+ for (std::atomic<uint32_t>& PidEntry : SponsorPids)
+ {
+ if (PidEntry.load(std::memory_order_relaxed) == 0)
+ {
+ uint32_t Expected = 0;
+ if (PidEntry.compare_exchange_strong(Expected, PidToAdd))
+ {
+ // Success!
+ return true;
+ }
+ }
+ else if (PidEntry.load(std::memory_order_relaxed) == PidToAdd)
+ {
+ // Success, the because pid is already in the list
+ return true;
+ }
+ }
+
+ return false;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+std::atomic<int> ZenServerTestCounter{0};
+
+ZenServerEnvironment::ZenServerEnvironment()
+{
+}
+
+ZenServerEnvironment::~ZenServerEnvironment()
+{
+}
+
+void
+ZenServerEnvironment::Initialize(std::filesystem::path ProgramBaseDir)
+{
+ m_ProgramBaseDir = ProgramBaseDir;
+
+ ZEN_DEBUG("Program base dir is '{}'", ProgramBaseDir);
+
+ m_IsInitialized = true;
+}
+
+void
+ZenServerEnvironment::InitializeForTest(std::filesystem::path ProgramBaseDir,
+ std::filesystem::path TestBaseDir,
+ std::string_view ServerClass)
+{
+ using namespace std::literals;
+
+ m_ProgramBaseDir = ProgramBaseDir;
+ m_TestBaseDir = TestBaseDir;
+
+ ZEN_INFO("Program base dir is '{}'", ProgramBaseDir);
+ ZEN_INFO("Cleaning test base dir '{}'", TestBaseDir);
+ DeleteDirectories(TestBaseDir.c_str());
+
+ m_IsTestInstance = true;
+ m_IsInitialized = true;
+
+ if (ServerClass.empty())
+ {
+#if ZEN_WITH_HTTPSYS
+ m_ServerClass = "httpsys"sv;
+#else
+ m_ServerClass = "asio"sv;
+#endif
+ }
+ else
+ {
+ m_ServerClass = ServerClass;
+ }
+}
+
+std::filesystem::path
+ZenServerEnvironment::CreateNewTestDir()
+{
+ using namespace std::literals;
+
+ ExtendableWideStringBuilder<256> TestDir;
+ TestDir << "test"sv << int64_t(++ZenServerTestCounter);
+
+ std::filesystem::path TestPath = m_TestBaseDir / TestDir.c_str();
+
+ ZEN_INFO("Creating new test dir @ '{}'", TestPath);
+
+ CreateDirectories(TestPath.c_str());
+
+ return TestPath;
+}
+
+std::filesystem::path
+ZenServerEnvironment::GetTestRootDir(std::string_view Path)
+{
+ std::filesystem::path Root = m_ProgramBaseDir.parent_path().parent_path();
+
+ std::filesystem::path Relative{Path};
+
+ return Root / Relative;
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+std::atomic<int> ChildIdCounter{0};
+
+ZenServerInstance::ZenServerInstance(ZenServerEnvironment& TestEnvironment) : m_Env(TestEnvironment)
+{
+ ZEN_ASSERT(TestEnvironment.IsInitialized());
+}
+
+ZenServerInstance::~ZenServerInstance()
+{
+ Shutdown();
+}
+
+void
+ZenServerInstance::SignalShutdown()
+{
+ m_ShutdownEvent.Set();
+}
+
+void
+ZenServerInstance::Shutdown()
+{
+ if (m_Process.IsValid())
+ {
+ if (m_Terminate)
+ {
+ ZEN_INFO("Terminating zenserver process");
+ m_Process.Terminate(111);
+ m_Process.Reset();
+ }
+ else
+ {
+ SignalShutdown();
+ m_Process.Wait();
+ m_Process.Reset();
+ }
+ }
+}
+
+void
+ZenServerInstance::SpawnServer(int BasePort, std::string_view AdditionalServerArgs)
+{
+ ZEN_ASSERT(!m_Process.IsValid()); // Only spawn once
+
+ const int MyPid = zen::GetCurrentProcessId();
+ const int ChildId = ++ChildIdCounter;
+
+ ExtendableStringBuilder<32> ChildEventName;
+ ChildEventName << "Zen_Child_" << ChildId;
+ NamedEvent ChildEvent{ChildEventName};
+
+ CreateShutdownEvent(BasePort);
+
+ ExtendableStringBuilder<32> LogId;
+ LogId << "Zen" << ChildId;
+
+ ExtendableStringBuilder<512> CommandLine;
+ CommandLine << "zenserver" ZEN_EXE_SUFFIX_LITERAL; // see CreateProc() call for actual binary path
+
+ const bool IsTest = m_Env.IsTestEnvironment();
+
+ if (IsTest)
+ {
+ if (!m_OwnerPid.has_value())
+ {
+ m_OwnerPid = MyPid;
+ }
+
+ CommandLine << " --test --log-id " << LogId;
+ }
+
+ if (m_OwnerPid.has_value())
+ {
+ CommandLine << " --owner-pid " << m_OwnerPid.value();
+ }
+
+ CommandLine << " --child-id " << ChildEventName;
+
+ if (std::string_view ServerClass = m_Env.GetServerClass(); ServerClass.empty() == false)
+ {
+ CommandLine << " --http " << ServerClass;
+ }
+
+ if (BasePort)
+ {
+ CommandLine << " --port " << BasePort;
+ m_BasePort = BasePort;
+ }
+
+ if (!m_TestDir.empty())
+ {
+ CommandLine << " --data-dir ";
+ PathToUtf8(m_TestDir.c_str(), CommandLine);
+ }
+
+ if (!AdditionalServerArgs.empty())
+ {
+ CommandLine << " " << AdditionalServerArgs;
+ }
+
+ std::filesystem::path CurrentDirectory = std::filesystem::current_path();
+
+ ZEN_DEBUG("Spawning server '{}'", LogId);
+
+ uint32_t CreationFlags = 0;
+ if (!IsTest)
+ {
+ CreationFlags |= CreateProcOptions::Flag_NewConsole;
+ }
+
+ const std::filesystem::path BaseDir = m_Env.ProgramBaseDir();
+ const std::filesystem::path Executable = BaseDir / "zenserver" ZEN_EXE_SUFFIX_LITERAL;
+ CreateProcOptions CreateOptions = {
+ .WorkingDirectory = &CurrentDirectory,
+ .Flags = CreationFlags,
+ };
+ CreateProcResult ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions);
+#if ZEN_PLATFORM_WINDOWS
+ if (!ChildPid && ::GetLastError() == ERROR_ELEVATION_REQUIRED)
+ {
+ ZEN_DEBUG("Regular spawn failed - spawning elevated server");
+ CreateOptions.Flags |= CreateProcOptions::Flag_Elevated;
+ ChildPid = CreateProc(Executable, CommandLine.ToView(), CreateOptions);
+ }
+#endif
+
+ if (!ChildPid)
+ {
+ ThrowLastError("Server spawn failed");
+ }
+
+ ZEN_DEBUG("Server '{}' spawned OK", LogId);
+
+ if (IsTest)
+ {
+ m_Process.Initialize(ChildPid);
+ }
+
+ m_ReadyEvent = std::move(ChildEvent);
+}
+
+void
+ZenServerInstance::CreateShutdownEvent(int BasePort)
+{
+ ExtendableStringBuilder<32> ChildShutdownEventName;
+ ChildShutdownEventName << "Zen_" << BasePort;
+ ChildShutdownEventName << "_Shutdown";
+ NamedEvent ChildShutdownEvent{ChildShutdownEventName};
+ m_ShutdownEvent = std::move(ChildShutdownEvent);
+}
+
+void
+ZenServerInstance::AttachToRunningServer(int BasePort)
+{
+ ZenServerState State;
+ if (!State.InitializeReadOnly())
+ {
+ // TODO: return success/error code instead?
+ throw std::runtime_error("No zen state found");
+ }
+
+ const ZenServerState::ZenServerEntry* Entry = nullptr;
+
+ if (BasePort)
+ {
+ Entry = State.Lookup(BasePort);
+ }
+ else
+ {
+ State.Snapshot([&](const ZenServerState::ZenServerEntry& InEntry) { Entry = &InEntry; });
+ }
+
+ if (!Entry)
+ {
+ // TODO: return success/error code instead?
+ throw std::runtime_error("No server found");
+ }
+
+ m_Process.Initialize(Entry->Pid);
+ CreateShutdownEvent(Entry->EffectiveListenPort);
+}
+
+void
+ZenServerInstance::Detach()
+{
+ if (m_Process.IsValid())
+ {
+ m_Process.Reset();
+ m_ShutdownEvent.Close();
+ }
+}
+
+void
+ZenServerInstance::WaitUntilReady()
+{
+ while (m_ReadyEvent.Wait(100) == false)
+ {
+ if (!m_Process.IsRunning() || !m_Process.IsValid())
+ {
+ ZEN_INFO("Wait abandoned by invalid process (running={})", m_Process.IsRunning());
+ return;
+ }
+ }
+}
+
+bool
+ZenServerInstance::WaitUntilReady(int Timeout)
+{
+ return m_ReadyEvent.Wait(Timeout);
+}
+
+std::string
+ZenServerInstance::GetBaseUri() const
+{
+ ZEN_ASSERT(m_BasePort);
+
+ return fmt::format("http://localhost:{}", m_BasePort);
+}
+
+} // namespace zen